mindspore.ops.AllGather

class mindspore.ops.AllGather(group=GlobalComm.WORLD_COMM_GROUP)

Gathers tensors from the specified communication group.

Note

The tensors must have the same shape and format in all processes of the collection. Before running the example below, set the communication environment variables; see the official MindSpore website for details.

Parameters

group (str) – The communication group to work on. Default: “GlobalComm.WORLD_COMM_GROUP”.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor. If the number of devices in the group is N, then the shape of output is \((N \times x_1, x_2, ..., x_R)\): the inputs from all devices are concatenated along the first dimension.

Raises
  • TypeError – If group is not a str.

  • ValueError – If the local rank id of the calling process in the group is larger than the group’s rank size.

Supported Platforms:

Ascend GPU

Examples

>>> # This example should be run with two devices. Refer to the Distributed Training tutorial on mindspore.cn.
>>> import numpy as np
>>> import mindspore.ops as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor, context
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.allgather = ops.AllGather()
...
...     def construct(self, x):
...         return self.allgather(x)
...
>>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_x)
>>> print(output)
[[1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1.]]
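The printed result above has shape (4, 8) for a (2, 8) input on two devices, which suggests that the gathered tensors are concatenated along the first dimension. The following is a minimal pure-Python sketch of that layout, assuming rank-ordered concatenation; it does not call MindSpore, and `simulate_all_gather` is a hypothetical helper introduced only for illustration.

```python
# Hypothetical single-process simulation of the AllGather result layout.
# Assumption: per-rank inputs are concatenated along dimension 0 in rank order.

def _shape_2d(rows):
    """Return (row count, column count) for a list of equal-length rows."""
    return (len(rows), len(rows[0]) if rows else 0)


def simulate_all_gather(per_rank_inputs):
    """Concatenate every rank's rows into one list, rank 0 first."""
    if not per_rank_inputs:
        raise ValueError("at least one rank is required")
    expected = _shape_2d(per_rank_inputs[0])
    gathered = []
    for rank, rows in enumerate(per_rank_inputs):
        # All ranks must contribute tensors of the same shape.
        if _shape_2d(rows) != expected:
            raise ValueError(f"rank {rank} has shape {_shape_2d(rows)}, expected {expected}")
        gathered.extend(list(row) for row in rows)
    return gathered


# Two simulated devices, each holding a (2, 8) block of ones.
rank0 = [[1.0] * 8 for _ in range(2)]
rank1 = [[1.0] * 8 for _ in range(2)]
output = simulate_all_gather([rank0, rank1])
print(len(output), len(output[0]))
```

In a real run, every process in the group receives the same gathered tensor; the sketch returns a single copy for brevity.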