mindspore.dataset.dataloader.get_worker_info
mindspore.dataset.dataloader.get_worker_info()
Get the information about the current DataLoader worker process.

The information includes:

- id (int): The ID of the current worker process.
- num_workers (int): The total number of worker processes.
- seed (int): The random seed used by the current worker process. This value is determined by the base seed generated by the main process and the ID of the current worker process.
- dataset (Dataset): The dataset object copied from the main process to the current worker process.

If the current process is not a DataLoader worker process, return None.

Returns
Union[WorkerInfo, None], the information about the current DataLoader worker process.

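Because the return value may be None and each worker carries its own seed, code that needs randomness inside a dataset usually has to handle both cases. The following is a minimal sketch under stated assumptions, not part of the MindSpore API: `make_rng` and `fallback_seed` are hypothetical names, the `SimpleNamespace` objects are stand-ins for the value returned by `get_worker_info()`, and the seed values 1000 and 1001 are made up for illustration.

```python
import random
from types import SimpleNamespace


def make_rng(worker_info, fallback_seed=0):
    """Hypothetical helper: build a random generator for the current process."""
    # worker_info is None when the code is not running in a DataLoader worker.
    seed = fallback_seed if worker_info is None else worker_info.seed
    return random.Random(seed)


# Stand-ins for what get_worker_info() would return (seed values are assumed).
worker_a = SimpleNamespace(id=0, num_workers=2, seed=1000, dataset=None)
worker_b = SimpleNamespace(id=1, num_workers=2, seed=1001, dataset=None)

rng_main = make_rng(None)      # main process: falls back to fallback_seed
rng_a = make_rng(worker_a)     # worker 0: seeded with its own seed
rng_b = make_rng(worker_b)     # worker 1: a different seed, so a different stream
```

In a real dataset, the `worker_info` argument would come from calling `get_worker_info()` inside `__iter__` or `__getitem__`.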
Examples
>>> from mindspore.dataset.dataloader import DataLoader, IterableDataset, get_worker_info
>>>
>>> # Split workload according to the worker info in multi-process data loading
>>> class IterableStyleDataset(IterableDataset):
...     def __init__(self, num_samples):
...         self.start = 0
...         self.end = num_samples
...
...     def __iter__(self):
...         worker_info = get_worker_info()
...         if worker_info is None:  # single-process data loading
...             return iter(range(self.start, self.end))
...         else:  # multi-process data loading
...             return iter(range(worker_info.id, self.end, worker_info.num_workers))
>>>
>>> dataset = IterableStyleDataset(2)
>>> dataloader = DataLoader(dataset, num_workers=2)
>>> print(list(dataloader))
[Tensor(shape=[1], dtype=Int64, value= [0]), Tensor(shape=[1], dtype=Int64, value= [1])]
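
The example above gives each worker a strided slice of the range. A possible alternative is to give each worker one contiguous chunk. The sketch below shows only the index arithmetic in plain Python: `shard_bounds` is a hypothetical helper, not a MindSpore function, and the `SimpleNamespace` objects stand in for the value returned by `get_worker_info()`.

```python
import math
from types import SimpleNamespace


def shard_bounds(start, end, worker_info):
    """Hypothetical helper: return the [lo, hi) sub-range for this process."""
    if worker_info is None:
        # Single-process data loading: iterate over the whole range.
        return start, end
    # Multi-process data loading: one contiguous chunk per worker.
    per_worker = math.ceil((end - start) / worker_info.num_workers)
    lo = min(start + worker_info.id * per_worker, end)
    hi = min(lo + per_worker, end)
    return lo, hi


# Stand-ins for the worker info of three workers (assumed values).
workers = [SimpleNamespace(id=i, num_workers=3) for i in range(3)]
chunks = [shard_bounds(0, 10, w) for w in workers]
```

Inside `__iter__`, a dataset could then return `iter(range(lo, hi))`. Contiguous chunks keep neighbouring samples in the same worker, while the strided split spreads them across workers; which one fits better depends on the data source.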