mindspore.dataset.dataloader.DistributedSampler

class mindspore.dataset.dataloader.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]

A sampler that partitions the dataset for distributed training.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • num_replicas (int, optional) – Number of shards participating in distributed training. Default: None.

  • rank (int, optional) – The sequence number of the current shard within num_replicas. Default: None.

  • shuffle (bool, optional) – Whether the sampler shuffles the samples randomly. Default: True.

  • seed (int, optional) – The seed used to randomize the sampler when shuffle is True. Default: 0.

  • drop_last (bool, optional) – Whether the sampler discards trailing data. If True, the sampler drops the tail of the data so that it divides evenly across all shards; if False, the sampler adds extra indices so that every shard receives the same number of samples (see the sketch below). Default: False.
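To make the drop_last semantics concrete, here is a minimal arithmetic sketch of the documented behavior (an illustration only, not the library's internal implementation), using 10 samples split across 3 shards:

>>> import math
>>> num_samples, num_replicas = 10, 3
>>> # drop_last=False: extra indices are added, so every shard gets the same count.
>>> math.ceil(num_samples / num_replicas)
4
>>> # drop_last=True: the trailing remainder (one sample here) is discarded instead.
>>> num_samples // num_replicas
3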

Examples

>>> from mindspore.dataset.dataloader import DistributedSampler
>>>
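>>> # Partition a 10-element dataset across 3 shards; this process is shard (rank) 0.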
>>> dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> sampler = DistributedSampler(dataset, num_replicas=3, rank=0)
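
Assuming the sampler is iterable over the sample indices assigned to its shard (a reasonable assumption for a sampler class, though not stated in the signature above), the rank-0 shard could then be consumed like this:

>>> # Assumption: iterating the sampler yields the indices this shard should read.
>>> indices = list(sampler)
>>> # With drop_last=False (the default), indices are padded so that each of the
>>> # 3 shards receives math.ceil(10 / 3) == 4 samples.
>>> len(indices)
4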