mindspore.communication.comm_func.reduce_scatter_tensor
- mindspore.communication.comm_func.reduce_scatter_tensor(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False)[source]
- Reduces and scatters tensors across the specified communication group and returns the reduced and scattered tensor.
- Note
- The tensors must have the same shape and format in all processes of the collective.
- Parameters
- tensor (Tensor) – The input tensor to be reduced and scattered. Suppose it has shape \((N, *)\), where * means any number of additional dimensions. N must be divisible by rank_size, the number of cards in the communication group.
- op (str, optional) – Specifies an operation used for element-wise reductions, like SUM and MAX. Default: ReduceOp.SUM.
- group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP, which means "hccl_world_group" on Ascend and "nccl_world_group" on GPU.
- async_op (bool, optional) – Whether this operator should be an async operator. Default: False.
 
- Returns
- Tuple(Tensor, CommHandle). The output tensor has the same dtype as tensor, with shape \((N/rank\_size, *)\). CommHandle is an async work handle if async_op is set to True; if async_op is False, CommHandle is None.
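- For intuition, the following single-process NumPy sketch mimics the reduce-scatter semantics described above (the per-rank inputs are summed element-wise, then rank r keeps the r-th chunk along the first dimension); it only illustrates the shape and value relationship and is not the distributed call itself.
>>> import numpy as np
>>> rank_size = 2
>>> # One (8, 8) input per rank, matching the example further below.
>>> inputs = [np.ones([8, 8], dtype=np.float32) for _ in range(rank_size)]
>>> reduced = np.sum(inputs, axis=0)               # element-wise SUM across ranks
>>> chunks = np.split(reduced, rank_size, axis=0)  # scatter along the first dimension
>>> chunks[0].shape                                # each rank keeps one (N/rank_size, *) chunk
(4, 8)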
- Raises
- TypeError – If the first input parameter is not a Tensor, or op or group is not a str.
- ValueError – If the first dimension of the input is not divisible by rank_size.
- RuntimeError – If device target is invalid, or backend is invalid, or distributed initialization fails. 
 
 - Supported Platforms:
- Ascend
- Examples
- Note
- Before running the following examples, you need to configure the communication environment variables.
- For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party or configuration file dependencies. Please see the msrun startup for more details.
- This example should be run with 2 devices.
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication as comm
>>>
>>> comm.init()
>>> input_tensor = ms.Tensor(np.ones([8, 8]).astype(np.float32))
>>> output, _ = comm.comm_func.reduce_scatter_tensor(input_tensor)
>>> print(output)
[[2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]]
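- When async_op is True, the call returns immediately together with a CommHandle. The following sketch shows that usage with 2 devices, assuming the returned handle exposes a wait() method that blocks until the communication has completed.
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication as comm
>>>
>>> comm.init()
>>> input_tensor = ms.Tensor(np.ones([8, 8]).astype(np.float32))
>>> # Launch the collective asynchronously; the second element of the tuple is the CommHandle.
>>> output, handle = comm.comm_func.reduce_scatter_tensor(input_tensor, async_op=True)
>>> if handle is not None:
...     handle.wait()  # assumed to block until the reduce-scatter has finished
>>> print(output.shape)
(4, 8)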