mindspore.dataset.dataloader.get_worker_info

mindspore.dataset.dataloader.get_worker_info()

Get the information about the current DataLoader worker process.

The information includes:

  • id (int): The ID of the current worker process.

  • num_workers (int): The total number of worker processes.

  • seed (int): The random seed used by the current worker process. It is derived from a base seed generated by the main process and the ID of the current worker process.

  • dataset (Dataset): The dataset object copied from the main process to the current worker process.

If the current process is not a DataLoader worker process (for example, during single-process data loading), None is returned.
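The seed field is what a worker would use to seed its own randomness, so that workers do not all produce the same random stream. Below is a minimal sketch that uses only the Python standard library and runs without MindSpore. FakeWorkerInfo is a hypothetical stand-in that carries the four fields listed above. The seed values passed to it are made up for illustration and do not reflect how MindSpore actually derives seeds. Each worker takes a strided shard of the samples and shuffles it with a random.Random seeded from its seed. A None result falls back to a single shard that holds every sample.

```python
import random
from collections import namedtuple

# Hypothetical stand-in with the four fields listed above; the real WorkerInfo
# object is created by MindSpore inside each worker process.
FakeWorkerInfo = namedtuple("FakeWorkerInfo", ["id", "num_workers", "seed", "dataset"])


def shuffled_shard(num_samples, worker_info, fallback_seed=0):
    """Return this worker's strided shard of range(num_samples), shuffled per worker.

    fallback_seed is a hypothetical parameter used only when worker_info is None.
    """
    if worker_info is None:  # not a worker process: one shard holding every sample
        indices = list(range(num_samples))
        rng = random.Random(fallback_seed)
    else:  # worker k keeps samples k, k + n, k + 2n, ...
        indices = list(range(worker_info.id, num_samples, worker_info.num_workers))
        # A private RNG seeded from the worker's seed gives each worker its own stream
        rng = random.Random(worker_info.seed)
    rng.shuffle(indices)
    return indices


# Simulate two workers; the seeds 1234 and 1235 are illustrative only
shards = [shuffled_shard(10, FakeWorkerInfo(i, 2, 1234 + i, None)) for i in range(2)]
print(shards)
```

Because each worker's RNG depends only on its own seed, rerunning with the same seeds reproduces the same per-worker order, while the two shards stay disjoint.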

Returns

Union[WorkerInfo, None], the information about the current DataLoader worker process.

Examples

>>> from mindspore.dataset.dataloader import DataLoader, IterableDataset, get_worker_info
>>>
>>> # Split workload according to the worker info in multi-process data loading
>>> class IterableStyleDataset(IterableDataset):
...     def __init__(self, num_samples):
...         self.start = 0
...         self.end = num_samples
...
...     def __iter__(self):
...         worker_info = get_worker_info()
...         if worker_info is None:  # single-process data loading
...             return iter(range(self.start, self.end))
...         else:  # multi-process data loading
...             return iter(range(self.start + worker_info.id, self.end, worker_info.num_workers))
>>>
>>> dataset = IterableStyleDataset(2)
>>> dataloader = DataLoader(dataset, num_workers=2)
>>> print(list(dataloader))
[Tensor(shape=[1], dtype=Int64, value= [0]), Tensor(shape=[1], dtype=Int64, value= [1])]
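In the example above, worker k receives every num_workers-th sample, starting at start + k. A common alternative is to give each worker one contiguous block. That scheme is swapped in here for comparison and is not taken from the MindSpore source. The sketch below again avoids importing MindSpore: FakeWorkerInfo, strided_range, contiguous_range and covers_exactly_once are all hypothetical helpers. The check only confirms that each scheme hands out every index in [start, end) exactly once across all workers.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for the object returned by get_worker_info()
FakeWorkerInfo = namedtuple("FakeWorkerInfo", ["id", "num_workers", "seed", "dataset"])


def strided_range(start, end, info):
    """Split used in the example above: worker k takes start + k, start + k + n, ..."""
    if info is None:  # single-process loading: take everything
        return range(start, end)
    return range(start + info.id, end, info.num_workers)


def contiguous_range(start, end, info):
    """Alternative split: worker k takes one block of at most ceil((end - start) / n) items."""
    if info is None:
        return range(start, end)
    per_worker = math.ceil((end - start) / info.num_workers)
    lo = start + info.id * per_worker
    hi = min(lo + per_worker, end)  # when lo >= end this yields an empty range
    return range(lo, hi)


def covers_exactly_once(split, start, end, num_workers):
    """True if the workers' shards together contain each index in [start, end) once."""
    seen = []
    for k in range(num_workers):
        seen.extend(split(start, end, FakeWorkerInfo(k, num_workers, 0, None)))
    return sorted(seen) == list(range(start, end))


for split in (strided_range, contiguous_range):
    print(split.__name__, covers_exactly_once(split, 3, 13, 4))
```

The strided split interleaves samples across workers, whereas the contiguous split keeps neighbouring samples on the same worker. The contiguous form can suit sources that are cheaper to read sequentially; this is a design trade-off rather than a MindSpore requirement.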