mindspore.ops.ReduceScatterV
- class mindspore.ops.ReduceScatterV(op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP)
Reduces tensors element-wise across all ranks in the specified communication group, then scatters the reduced result to the ranks in uneven chunks. Each rank returns the chunk it receives.
Note
Only flattened tensors are supported as input. The input tensor must be flattened and concatenated before calling this primitive.
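The preparation step described in the note can be sketched with NumPy. This is only an illustration under assumptions: the names block_for_rank0 and block_for_rank1 are hypothetical, and the layout (one block per destination rank, concatenated in rank order) is inferred from the example further down rather than stated by the API.

```python
import numpy as np

# Hypothetical data: one block destined for each rank, with different shapes.
block_for_rank0 = np.array([[1.0, 2.0], [3.0, 4.0]])  # 4 elements
block_for_rank1 = np.array([5.0, 6.0, 7.0])           # 3 elements
blocks = [block_for_rank0, block_for_rank1]

# Flatten every block to 1-D, then concatenate them in rank order.
flat_input = np.concatenate([b.ravel() for b in blocks])

# The split sizes are the element counts of the blocks, in the same order.
input_split_sizes = [int(b.size) for b in blocks]

print(flat_input.shape)    # (7,)
print(input_split_sizes)   # [4, 3]
```

The resulting one-dimensional array can then be wrapped in a Tensor and passed as input_x together with input_split_sizes.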
- Parameters
op (str, optional) – Specifies the operation used for element-wise reduction, such as SUM, MIN and MAX. PROD is currently not supported. Default: ReduceOp.SUM.
group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP, which means "hccl_world_group" on Ascend and "nccl_world_group" on GPU.
- Inputs:
input_x (Tensor) - One-dimensional tensor to be distributed, with the shape \((x_1)\).
input_split_sizes (Union[tuple[int], list[int], Tensor]) - One-dimensional sequence giving the amount of data each rank receives, one entry per rank. Sizes are counted in elements of the tensor's data type. The values are not verified; the user is responsible for their correctness.
- Outputs:
Tensor. The chunk of the reduced result that this rank receives. If the result is empty, a Tensor with shape () is returned, and its value has no actual meaning.
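The relationship between the inputs and the per-rank output can be illustrated with a single-process simulation in plain Python. This is a sketch of the semantics only, not the MindSpore implementation: the function simulate_reduce_scatter_v is hypothetical, it covers only the SUM reduction, and it assumes that every rank contributes an input of the same length and that the split sizes add up to that length. Unlike the real operator, it checks those assumptions.

```python
from itertools import accumulate


def simulate_reduce_scatter_v(per_rank_inputs, input_split_sizes):
    """Return one output chunk per rank for a SUM reduce-scatter.

    per_rank_inputs: list of 1-D lists, one per rank (hypothetical stand-in
    for the input_x tensor held by each rank).
    """
    length = len(per_rank_inputs[0])
    if any(len(x) != length for x in per_rank_inputs):
        raise ValueError("every rank must contribute the same number of elements")
    if sum(input_split_sizes) != length:
        raise ValueError("split sizes must add up to the input length")

    # Step 1: reduce element-wise across ranks (SUM).
    reduced = [sum(values) for values in zip(*per_rank_inputs)]

    # Step 2: scatter the reduced vector in uneven chunks, one per rank.
    ends = list(accumulate(input_split_sizes))
    starts = [0] + ends[:-1]
    return [reduced[s:e] for s, e in zip(starts, ends)]


# Two ranks, each holding [0, 1, 2], with split sizes [2, 1].
inputs = [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
for rank, chunk in enumerate(simulate_reduce_scatter_v(inputs, [2, 1])):
    print(f"rank {rank}: {chunk}")
# rank 0: [0.0, 2.0]
# rank 1: [4.0]
```

With these inputs the reduced vector is [0, 2, 4]; rank 0 receives the first two elements and rank 1 the last one.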
- Raises
RuntimeError – Device target is invalid, backend is invalid, or distributed initialization fails.
- Supported Platforms:
Ascend
GPU
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration file dependencies. See the msrun startup documentation for more details.
This example should be run with 2 devices.
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init, get_rank
>>> from mindspore.ops import ReduceOp
>>> import mindspore.nn as nn
>>> from mindspore.ops.operations.comm_ops import ReduceScatterV
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.reducescatterv = ReduceScatterV(ReduceOp.SUM)
...
...     def construct(self, x, input_split_sizes):
...         return self.reducescatterv(x, input_split_sizes)
...
>>> rank = get_rank()
>>> input_x = Tensor([0, 1, 2.0])
>>> input_split_sizes = [2, 1]
>>> net = Net()
>>> output = net(input_x, input_split_sizes)
>>> print(output)
rank 0:
[0. 2.]
rank 1:
[4.]