mindspore.dataset.BatchInfo

class mindspore.dataset.BatchInfo

When the batch_size or per_batch_map parameter of the batch operation is passed a callable, the methods provided by this class can be used inside that callable to obtain information about the dataset.
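
A per_batch_map callable receives the batched data of each input column plus the same BatchInfo object as its last argument. The snippet below is a minimal sketch of that usage, assuming a single input column named "column1"; the function name scale_by_epoch and the scaling rule are illustrative choices, not part of the API.

>>> # Sketch: scale every sample by (current epoch number + 1) inside per_batch_map.
>>> import numpy as np
>>> import mindspore.dataset as ds
>>>
>>> def scale_by_epoch(col1, batch_info):
...     # col1: list of numpy arrays holding this batch's samples for "column1".
...     # batch_info: BatchInfo instance supplied by the batch operation.
...     factor = batch_info.get_epoch_num() + 1
...     return ([np.array(x) * factor for x in col1],)
>>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False)
>>> dataset = dataset.batch(batch_size=2, per_batch_map=scale_by_epoch, input_columns=["column1"])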

get_batch_num()

Return the number of batches already processed in the current epoch; counting starts from 0.

Examples:

>>> # Create a dataset where its batch size is dynamic
>>> # Define a callable batch size function and let batch size increase 1 each time.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>>
>>> dataset = ds.GeneratorDataset([i for i in range(3)], "column1", shuffle=False)
>>> def add_one(batch_info):
...     return batch_info.get_batch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one)
>>> print(list(dataset))
[[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[2], dtype=Int64, value= [1, 2])]]

get_epoch_num()

Return the current epoch number; counting starts from 0.

Examples:

>>> # Create a dataset where its batch size is dynamic
>>> # Define a callable batch size function and let batch size increase 1 each epoch.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>>
>>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False)
>>> def add_one_by_epoch(batch_info):
...     return batch_info.get_epoch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one_by_epoch)
>>>
>>> result = []
>>> epoch = 2
>>> iterator = dataset.create_tuple_iterator(num_epochs=epoch)
>>> for _ in range(epoch):
...     result.extend(list(iterator))
>>> # result:
>>> # [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[1], dtype=Int64, value= [1])],
>>> #  [Tensor(shape=[1], dtype=Int64, value= [2])], [Tensor(shape=[1], dtype=Int64, value= [3])],
>>> #  [Tensor(shape=[2], dtype=Int64, value= [0, 1])], [Tensor(shape=[2], dtype=Int64, value= [2, 3])]]