mindspore.dataset.dataloader.BatchSampler

查看源文件
class mindspore.dataset.dataloader.BatchSampler(sampler, batch_size, drop_last)[源代码]

每次生成一个 mini-batch 索引的采样器。

参数:
  • sampler (Union[Sampler, Iterable]) - 用于生成单个索引的采样器。

  • batch_size (int) - mini-batch 的大小。

  • drop_last (bool) - 如果最后一批数据小于 batch_size ,是否舍弃该批次。

样例:

>>> from mindspore.dataset.dataloader import BatchSampler, SequentialSampler
>>>
>>> dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> sequential_sampler = SequentialSampler(dataset)
>>>
>>> batch_sampler = BatchSampler(sequential_sampler, 4, False)
>>> print(list(batch_sampler))
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
>>>
>>> batch_sampler = BatchSampler(sequential_sampler, 4, True)
>>> print(list(batch_sampler))
[[0, 1, 2, 3], [4, 5, 6, 7]]