mindspore.dataset.dataloader.DistributedSampler

查看源文件
class mindspore.dataset.dataloader.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[源代码]

将数据集进行分片用于分布式训练的采样器。

参数:
  • dataset (Dataset) - 用于采样的数据集。

  • num_replicas (int, 可选) - 参与分布式训练的总分片数量。默认值: None

  • rank (int, 可选) - 当前分片在 num_replicas 中的序列号。默认值: None

  • shuffle (bool, 可选) - 采样器是否对样本进行随机排序,默认值: True

  • seed (int, 可选) - 当设置 shuffleTrue 时,用于随机排序采样器的种子值。默认值: 0

  • drop_last (bool, 可选) - 采样器是否舍弃尾部数据。如果为 True ,采样器将舍弃尾部数据,使其能被等分到所有分片中;如果为 False ,采样器将添加额外索引,使数据能被等分到分片中。默认值: False

样例:

>>> from mindspore.dataset.dataloader import DistributedSampler
>>>
>>> dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> sampler = DistributedSampler(dataset, num_replicas=3, rank=0)