mindspore.ops.AllGatherV
- class mindspore.ops.AllGatherV(group=GlobalComm.WORLD_COMM_GROUP)
Gathers tensors of uneven sizes from all ranks in the specified communication group and returns their concatenation.
Note
Only flattened (one-dimensional) tensors are supported as input. Flatten and concatenate the input tensors before calling this primitive.
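Because the primitive accepts only a single flattened tensor, a caller holding several arrays has to pack them into one buffer first and unpack them afterwards. The following is a minimal NumPy sketch of that bookkeeping, assuming all pieces share one dtype. `pack` and `unpack` are hypothetical helpers, not MindSpore APIs; on a device you would wrap the packed buffer in `mindspore.Tensor` before passing it to `AllGatherV`.

```python
import numpy as np

def pack(tensors):
    """Flatten each array and concatenate them into one 1-D buffer.

    Returns the buffer plus the original shapes so the caller can
    restore the arrays later. Assumes all inputs share one dtype;
    np.concatenate would otherwise silently promote them.
    """
    shapes = [t.shape for t in tensors]
    flat = np.concatenate([t.reshape(-1) for t in tensors])
    return flat, shapes

def unpack(flat, shapes):
    """Inverse of pack: slice the 1-D buffer and reshape each piece."""
    out, offset = [], 0
    for shape in shapes:
        n = int(np.prod(shape))  # np.prod(()) is 1.0, so scalars take one slot
        out.append(flat[offset:offset + n].reshape(shape))
        offset += n
    return out

a = np.arange(6).reshape(2, 3)
b = np.array([10, 11])
flat, shapes = pack([a, b])
print(flat)  # [ 0  1  2  3  4  5 10 11]
```

The length of `flat` is also the value this rank would contribute to `output_split_sizes`.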
- Parameters
group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP, which means "hccl_world_group" on Ascend and "nccl_world_group" on GPU.
- Inputs:
input_x (Tensor) - One-dimensional tensor to be gathered, with the shape \((x_1)\).
output_split_sizes (Union[tuple[int], list[int], Tensor]) - One-dimensional sequence whose i-th entry is the amount of data gathered from rank i. Sizes are counted in elements of the input tensor's data type, not in bytes.
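Since `output_split_sizes[i]` counts the elements contributed by rank i, rank i's data starts at the sum of all earlier sizes. A minimal NumPy sketch of cutting a gathered buffer back into per-rank pieces; `split_gathered` is a hypothetical helper, not part of MindSpore, and its size check is this sketch's own validation.

```python
import numpy as np

def split_gathered(output, output_split_sizes):
    """Cut a gathered 1-D buffer back into one piece per rank.

    output_split_sizes[i] is the number of elements (not bytes)
    contributed by rank i.
    """
    sizes = list(output_split_sizes)
    if sum(sizes) != output.size:
        raise ValueError("sum(output_split_sizes) must equal output.size")
    # np.split takes the interior cut points: running sums without the total.
    cut_points = np.cumsum(sizes)[:-1]
    return np.split(output, cut_points)

gathered = np.array([0, 1, 2, 0, 1, 2, 3])
for rank, piece in enumerate(split_gathered(gathered, [3, 4])):
    print(rank, piece)
# 0 [0 1 2]
# 1 [0 1 2 3]
```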
- Outputs:
Tensor, the flattened and concatenated tensor gathered from all ranks. If the gathered result is empty, a Tensor with shape () is returned; its value has no meaning.
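To make the input/output contract concrete, here is a single-process NumPy model of what every rank receives. It is a sketch of the documented semantics, not MindSpore's implementation: `allgatherv_reference` is a hypothetical name, the `ValueError` checks are assumptions of this sketch, and the shape-() placeholder for the empty case uses zero only because the documented value is meaningless.

```python
import numpy as np

def allgatherv_reference(per_rank_inputs, output_split_sizes):
    """Single-process model of the result every rank gets from AllGatherV.

    per_rank_inputs[i] is the 1-D input_x held by rank i. Each rank must
    contribute exactly output_split_sizes[i] elements; the result is the
    rank-ordered concatenation, identical on every rank.
    """
    if not per_rank_inputs or len(per_rank_inputs) != len(output_split_sizes):
        raise ValueError("need one input and one split size per rank")
    for rank, (x, n) in enumerate(zip(per_rank_inputs, output_split_sizes)):
        if x.ndim != 1 or x.size != n:
            raise ValueError(f"rank {rank}: expected a 1-D input with {n} elements")
    if sum(output_split_sizes) == 0:
        # Documented empty case: a shape-() tensor whose value has no meaning.
        return np.zeros((), dtype=per_rank_inputs[0].dtype)
    return np.concatenate(per_rank_inputs)

# Same data as the two-device example on this page: rank r holds range(r + 3).
inputs = [np.arange(r + 3) for r in range(2)]
print(allgatherv_reference(inputs, [3, 4]))  # [0 1 2 0 1 2 3]
```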
- Raises
RuntimeError – Device target is invalid, backend is invalid, or distributed initialization fails.
- Supported Platforms:
Ascend
GPU
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, the msrun startup method is recommended, as it has no third-party or configuration-file dependencies. See the msrun startup documentation for more details.
This example should be run with 2 devices.
>>> import mindspore as ms
>>> from mindspore.ops import AllGatherV
>>> import mindspore.nn as nn
>>> from mindspore.communication import init, get_rank
>>> from mindspore import Tensor
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.allgatherv = AllGatherV()
...
...     def construct(self, x, output_split_sizes):
...         return self.allgatherv(x, output_split_sizes)
...
>>> rank = get_rank()
>>> # rank 0 contributes [0, 1, 2]; rank 1 contributes [0, 1, 2, 3]
>>> data = [i for i in range(rank + 3)]
>>> input_x = Tensor(data)
>>> # number of elements sent by rank 0 and rank 1, in rank order
>>> output_split_sizes = [3, 4]
>>> net = Net()
>>> output = net(input_x, output_split_sizes)
>>> print(output)
[0 1 2 0 1 2 3]
- Tutorial Examples: