mindspore.dataset.dataloader.get_worker_info
- mindspore.dataset.dataloader.get_worker_info()[源代码]
获取当前
DataLoader
工作进程的信息。信息包括:
id (
int
):当前工作进程的ID。num_workers (
int
):工作进程的总数。seed (
int
):当前工作进程使用的随机种子。此值由主进程生成的基础种子和当前工作进程的ID确定。dataset (
Dataset
):从主进程复制到当前工作进程的数据集对象。
如果当前进程不是
DataLoader
工作进程,则返回None
。- 返回:
Union[WorkerInfo, None],当前
DataLoader
工作进程的信息。
样例:
>>> from mindspore.dataset.dataloader import DataLoader, IterableDataset, get_worker_info >>> >>> # Split workload according to the worker info in multi-process data loading >>> class IterableStyleDataset(IterableDataset): ... def __init__(self, num_samples): ... self.start = 0 ... self.end = num_samples ... ... def __iter__(self): ... worker_info = get_worker_info() ... if worker_info is None: # single-process data loading ... return iter(range(self.start, self.end)) ... else: # multi-process data loading ... return iter(range(worker_info.id, self.end, worker_info.num_workers)) >>> >>> dataset = IterableStyleDataset(2) >>> dataloader = DataLoader(dataset, num_workers=2) >>> print(list(dataloader)) [Tensor(shape=[1], dtype=Int64, value= [0]), Tensor(shape=[1], dtype=Int64, value= [1])]