mindspore.communication.comm_func.scatter_tensor

mindspore.communication.comm_func.scatter_tensor(tensor, src=0, group=GlobalComm.WORLD_COMM_GROUP)[source]

Scatter the input tensor evenly across the processes in the specified communication group.

Note

  • This interface only supports Tensor input and always scatters it evenly, which differs from the behavior of pytorch.distributed.scatter (a conceptual sketch follows this note).

  • Only the tensor on process src (global rank) is scattered.

  • Only PyNative mode is supported; Graph mode is not currently supported.
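
The even split can be pictured without any communication. The following sketch (assuming a group of 2 processes; no distributed backend is involved) shows which chunk each rank would receive:

>>> import numpy as np
>>> # Conceptual view only: scatter_tensor splits the input evenly along
>>> # dimension 0 and hands chunk i to rank i, whereas
>>> # pytorch.distributed.scatter expects an explicit list of tensors.
>>> full = np.arange(8, dtype=np.float32).reshape(4, 2)
>>> group_size = 2  # assumed number of processes in the group
>>> chunks = np.split(full, group_size, axis=0)
>>> print(chunks[0])  # what rank 0 would receive
[[0. 1.]
 [2. 3.]]
>>> print(chunks[1])  # what rank 1 would receive
[[4. 5.]
 [6. 7.]]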

Parameters
  • tensor (Tensor) – The input tensor to be scattered. The shape of the tensor is \((x_1, x_2, ..., x_R)\).

  • src (int, optional) – Specifies the rank (global rank) of the process that sends the tensor; only process src sends the tensor. Default: 0.

  • group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP, which means "hccl_world_group" on Ascend and "nccl_world_group" on GPU.

Returns

Tensor, the shape of the output is \((x_1/N, x_2, ..., x_R)\), where \(N\) is the number of processes in the communication group. Dimension 0 of the output equals dimension 0 of the input tensor divided by \(N\); the other dimensions stay the same.
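
A minimal sketch of the shape rule, assuming the job was launched with 4 processes (the input shape here is illustrative):

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication as comm
>>>
>>> comm.init()
>>> n = comm.get_group_size()  # number of processes in the world group, e.g. 4
>>> full = ms.Tensor(np.zeros((8, 3), np.float32))
>>> part = comm.comm_func.scatter_tensor(tensor=full, src=0)
>>> print(part.shape)  # (8 // n, 3), i.e. (2, 3) with 4 processes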

Raises
  • TypeError – If the type of the first input parameter is not Tensor, or group is not a str.

  • RuntimeError – If the device target is invalid, the backend is invalid, or distributed initialization fails.

Supported Platforms:

Ascend GPU

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration-file dependencies. Please see the msrun startup for more details.

This example should be run with 2 devices.
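
For instance, assuming the example below is saved as scatter_tensor_example.py (a placeholder name), a 2-device job can typically be launched with a command along the lines of msrun --worker_num=2 --local_worker_num=2 scatter_tensor_example.py.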

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication as comm
>>>
>>> # Launch 2 processes.
>>>
>>> comm.init()
>>> input = ms.Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
>>> out = comm.comm_func.scatter_tensor(tensor=input, src=0)
>>> print(out)
# rank_0
[[0. 1.]
 [2. 3.]]
# rank_1
[[4. 5.]
 [6. 7.]]