mindspore.dataset.Dataset.send

查看源文件
mindspore.dataset.Dataset.send(tensor=None, dst=0, group=None)[源代码]

数据集通信接口,将数据发送至目标 Dataset ,可以通过 mindspore.dataset.Dataset.recv 接收。

每调用一次,仅发送一条数据。

说明

这是一个实验性API,后续可能修改或删除。

参数:
  • tensor (Union[Tensor, list[Tensor]], 可选) - 待发送的Tensor/Tensor列表。默认值: None ,表示从当前数据集中获取数据并发送。

  • dst (Union[int, list[int]], 可选) - 目标 Dataset 对应的Rank ID或ID列表。不能是当前Rank ID或包含当前Rank ID。默认值: 0 ,表示将数据发送到Rank 0。

  • group (str, 可选) - 指定通信组实例(由 mindspore.mint.distributed.init_process_group() 或者 mindspore.mint.distributed.new_group() 方法创建)的名称。默认值: None ,使用 GlobalComm.WORLD_COMM_GROUP

样例:

>>> import mindspore as ms
>>> from mindspore.mint.distributed import init_process_group
>>> from mindspore.mint.distributed import get_rank
>>> from mindspore import Tensor
>>> import mindspore.dataset as ds
>>> import numpy as np
>>>
>>> # Launch 8 processes by msrun --worker_num=8 --local_worker_num=8 script.py
>>> init_process_group()
>>> this_rank = get_rank()
>>>
>>> # Create a dataset with 3 columns
>>> input_columns = ["column1", "column2", "column3"]
>>> dataset = ds.GeneratorDataset([(1, 2, 3), (3, 4, 5), (5, 6, 7)], column_names=input_columns)
>>>
>>> # Send a data from the current dataset to the dst rank: 0
>>> if this_rank == 2:
>>>     dataset.send()
>>> if this_rank == 0:
>>>     data = dataset.recv(2)
>>>
>>> # Send the data "send_tensor" to the dst rank: 7
>>> if this_rank == 0:
>>>     send_tensor = Tensor(np.zeros([2, 2, 3]), ms.float32)
>>>     dataset.send(send_tensor, 7)
>>> if this_rank == 7:
>>>     recv_tensor = dataset.recv(0)
>>>
>>> # Send the list of data to dst rank [0, 2, 4, 6]
>>> if this_rank in [1, 3, 5, 7]:
>>>     send_data = Tensor(np.zeros([2, 2, 3]), ms.float32)
>>>     send_label = Tensor(np.zeros([3,]), ms.bool)
>>>     dataset.send([send_data, send_label], [0, 2, 4, 6])
>>> if this_rank in [0, 2, 4, 6]:
>>>     recv_tensors = dataset.recv([1, 3, 5, 7])