mindspore.ops.Gather

class mindspore.ops.Gather(*args, **kwargs)

Returns a slice of the input tensor based on the specified indices and axis.

The elements selected by input_indices are taken along the specified axis; all other dimensions of the input tensor are preserved. See the examples below for details.

Inputs:
  • input_params (Tensor) - The original Tensor, with shape \((x_1, x_2, ..., x_R)\).

  • input_indices (Tensor) - The indices of the elements to gather from the original Tensor, with shape \((y_1, y_2, ..., y_S)\). Each value must be in the range [0, input_params.shape[axis]).

  • axis (int) - Specifies the dimension of input_params along which elements are gathered.
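The range constraint on input_indices can be checked before calling the operator. The following is a minimal sketch in plain NumPy, not part of MindSpore; the helper name check_gather_indices is hypothetical and introduced here only for illustration.

```python
import numpy as np

def check_gather_indices(params_shape, indices, axis):
    """Hypothetical helper: True if every index lies in [0, params_shape[axis])."""
    indices = np.asarray(indices)
    upper = params_shape[axis]  # exclusive upper bound for valid indices
    return bool(np.all((indices >= 0) & (indices < upper)))

params_shape = (3, 4)  # same shape as input_params in the Examples section
ok_axis1 = check_gather_indices(params_shape, [1, 2], axis=1)   # 1 and 2 are both < 4
bad_axis0 = check_gather_indices(params_shape, [1, 3], axis=0)  # 3 is not < 3
print(ok_axis1, bad_axis0)
```

Note that the same indices can be valid for one axis and out of range for another, because the bound depends on input_params.shape[axis].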

Outputs:

Tensor, with shape \(input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]\).
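The output shape rule can be illustrated with multi-dimensional indices. The sketch below uses NumPy's np.take as a stand-in that follows the same shape rule; it is an analogy under that assumption, not MindSpore's implementation, and the variable names are chosen for illustration.

```python
import numpy as np

params = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # shape (2, 3, 4)
indices = np.array([[0, 2], [1, 1]])            # shape (2, 2), values < params.shape[1]
axis = 1

# Gather along axis 1; the axis dimension is replaced by the shape of indices.
gathered = np.take(params, indices, axis=axis)

# Shape rule from the Outputs description, written with tuple concatenation.
expected_shape = params.shape[:axis] + indices.shape + params.shape[axis + 1:]
print(gathered.shape, expected_shape)
```

Here the dimension of size 3 at axis 1 is replaced by the (2, 2) shape of the indices, so gathered[i, j, k] equals params[i, indices[j, k]].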

Raises

TypeError – If axis is not an int.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> output = ops.Gather()(input_params, input_indices, axis)
>>> print(output)
[[ 2.  7.]
 [ 4. 54.]
 [ 2. 55.]]
>>> axis = 0
>>> output = ops.Gather()(input_params, input_indices, axis)
>>> print(output)
[[ 3.  4. 54. 22.]
 [ 2.  2. 55.  3.]]