mindspore.dataset.dataloader.get_worker_info

查看源文件
mindspore.dataset.dataloader.get_worker_info()[源代码]

获取当前 DataLoader 工作进程的信息。

信息包括:

  • id (int):当前工作进程的ID。

  • num_workers (int):工作进程的总数。

  • seed (int):当前工作进程使用的随机种子。此值由主进程生成的基础种子和当前工作进程的ID确定。

  • dataset (Dataset):从主进程复制到当前工作进程的数据集对象。

如果当前进程不是 DataLoader 工作进程,则返回 None

返回:

Union[WorkerInfo, None],当前 DataLoader 工作进程的信息。

样例:

>>> from mindspore.dataset.dataloader import DataLoader, IterableDataset, get_worker_info
>>>
>>> # Split workload according to the worker info in multi-process data loading
>>> class IterableStyleDataset(IterableDataset):
...     def __init__(self, num_samples):
...         self.start = 0
...         self.end = num_samples
...
...     def __iter__(self):
...         worker_info = get_worker_info()
...         if worker_info is None:  # single-process data loading
...             return iter(range(self.start, self.end))
...         else:  # multi-process data loading
...             return iter(range(worker_info.id, self.end, worker_info.num_workers))
>>>
>>> dataset = IterableStyleDataset(2)
>>> dataloader = DataLoader(dataset, num_workers=2)
>>> print(list(dataloader))
[Tensor(shape=[1], dtype=Int64, value= [0]), Tensor(shape=[1], dtype=Int64, value= [1])]