mindspore.ops.IndexPut

class mindspore.ops.IndexPut(accumulate=0)

Writes the values in x2 into x1 at the positions given by indices. The indexed elements of x1 are either replaced by the values in x2 or, when accumulate is 1, incremented by them.

Parameters

accumulate (int) – Must be 0 or 1. If accumulate is 1, the elements of x2 are added to the indexed elements of x1; if it is 0, the elements of x2 replace the indexed elements of x1. Default: 0.
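The effect of accumulate can be modelled with NumPy advanced indexing. The sketch below is hypothetical: index_put_reference is an illustrative name, not part of MindSpore, and the code is not MindSpore's implementation. It also assumes, as PyTorch's index_put_ does, that repeated positions are summed when accumulate is 1 (via np.add.at), which this page does not confirm.

```python
import numpy as np


def index_put_reference(x1, x2, indices, accumulate=0):
    """Hypothetical NumPy model of IndexPut (not MindSpore's implementation).

    Returns a new array; x1 itself is left unchanged.
    """
    if accumulate not in (0, 1):
        raise ValueError("accumulate must be 0 or 1")
    out = np.array(x1, copy=True)
    idx = tuple(np.asarray(i) for i in indices)
    if accumulate == 1:
        # Add x2 at every indexed position. np.add.at also sums repeated
        # positions; treating duplicates this way is an assumption borrowed
        # from PyTorch's index_put_, not something this page states.
        np.add.at(out, idx, np.asarray(x2, dtype=out.dtype))
    else:
        # Overwrite every indexed position; a size-1 x2 broadcasts.
        out[idx] = x2
    return out


# Same data as the Examples section, using NumPy arrays instead of Tensors.
x1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
x2 = np.array([3], dtype=np.int32)
indices = [np.array([0, 0]), np.array([0, 1])]

print(index_put_reference(x1, x2, indices, accumulate=1).tolist())  # [[4, 5, 3], [4, 5, 6]]
print(index_put_reference(x1, x2, indices, accumulate=0).tolist())  # [[3, 3, 3], [4, 5, 6]]
```

With accumulate=1 the sketch reproduces the output printed in the Examples section; with accumulate=0 the size-1 x2 simply overwrites both selected positions.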

Inputs:
  • x1 (Tensor) - The target tensor to be updated, with rank 1 or higher.

  • x2 (Tensor) - 1-D tensor with the same dtype as x1. If x2 has size 1, it is broadcast to every indexed position.

  • indices (tuple[Tensor], list[Tensor]) - Index tensors of dtype int32 or int64, used to index into x1.

Each tensor in indices must be 1-D, the number of tensors in indices must not exceed the rank of x1, and the tensors in indices must be broadcastable against each other.
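To picture the rule on indices, the NumPy sketch below treats the 1-D index tensors the way NumPy advanced indexing does: they are broadcast against each other to a common length, and each position along that length addresses one element of x1 (when size(indices) == rank(x1)) or one trailing slice of x1 (when size(indices) < rank(x1)). The helper name indexed_positions is hypothetical, and mapping MindSpore's broadcasting rule onto NumPy's is an assumption.

```python
import numpy as np


def indexed_positions(x1_shape, indices):
    """Hypothetical helper: list the index tuples selected by 1-D index tensors."""
    if len(indices) > len(x1_shape):
        raise ValueError("size(indices) must be <= rank(x1)")
    arrays = [np.asarray(i) for i in indices]
    if not arrays:
        return []
    if any(a.ndim != 1 for a in arrays):
        raise ValueError("each tensor in indices must be 1-D")
    # Broadcast the index tensors to one common length; NumPy raises
    # ValueError when their lengths are incompatible (e.g. 2 and 3).
    cols = np.broadcast_arrays(*arrays)
    return [tuple(int(c[k]) for c in cols) for k in range(cols[0].shape[0])]


# Full indexing of a 2x3 tensor: [1] broadcasts against [0, 2].
print(indexed_positions((2, 3), [np.array([1]), np.array([0, 2])]))  # [(1, 0), (1, 2)]
# Partial indexing: one index tensor on a 2-D tensor selects whole rows.
print(indexed_positions((2, 3), [np.array([1, 0])]))  # [(1,), (0,)]
```

In the partial case each tuple addresses a trailing slice rather than a single element, which is consistent with the rule under Raises that size(x2) must then be 1 or x1.shape[-1].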

Outputs:

Tensor, has the same dtype and shape as x1.

Raises
  • TypeError – If the dtype of x1 is not equal to the dtype of x2.

  • TypeError – If indices is not tuple[Tensor] or list[Tensor].

  • TypeError – If the dtype of any tensor in indices is not int32 or int64.

  • TypeError – If the tensors in indices do not all have the same dtype.

  • TypeError – If accumulate is not an int.

  • ValueError – If x2 is not 1-D.

  • ValueError – If rank(x1) == size(indices) and size(x2) is neither 1 nor the maximum size of the tensors in indices.

  • ValueError – If rank(x1) > size(indices) and size(x2) is neither 1 nor x1.shape[-1].

  • ValueError – If any tensor in indices is not 1-D.

  • ValueError – If the tensors in indices are not broadcastable.

  • ValueError – If size(indices) > rank(x1).

  • ValueError – If accumulate is not equal to 0 or 1.
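The shape rules behind the ValueError cases above can be collected into one pre-flight check. The function below, check_index_put_args, is a hypothetical sketch that works on plain shape tuples rather than Tensors; it is not part of MindSpore, and the order in which it runs the checks is an assumption, since the page lists the conditions without an order.

```python
import numpy as np


def check_index_put_args(x1_shape, x2_shape, index_shapes, accumulate):
    """Hypothetical shape/value checker mirroring the Raises rules above.

    The order of the checks is an assumption; the page does not specify it.
    """
    if not isinstance(accumulate, int):
        raise TypeError("accumulate must be an int")
    if accumulate not in (0, 1):
        raise ValueError("accumulate must be 0 or 1")
    if len(x2_shape) != 1:
        raise ValueError("x2 must be 1-D")
    if any(len(s) != 1 for s in index_shapes):
        raise ValueError("each tensor in indices must be 1-D")
    if len(index_shapes) > len(x1_shape):
        raise ValueError("size(indices) must be <= rank(x1)")
    try:
        np.broadcast_shapes(*index_shapes)
    except ValueError:
        raise ValueError("tensors in indices must be broadcastable") from None
    if len(x1_shape) == len(index_shapes):
        # Full indexing: x2 needs one value per indexed position.
        expected = max(s[0] for s in index_shapes)
    else:
        # Partial indexing: x2 must match the last dimension of x1.
        expected = x1_shape[-1]
    if x2_shape[0] not in (1, expected):
        raise ValueError(f"size(x2) must be 1 or {expected}, got {x2_shape[0]}")


# Shapes from the Examples section: x1 is 2x3, x2 has size 1, two index tensors of size 2.
check_index_put_args((2, 3), (1,), [(2,), (2,)], accumulate=1)  # passes silently
```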

Supported Platforms:

Ascend CPU

Examples

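>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> x1 = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
>>> x2 = Tensor(np.array([3]).astype(np.int32))
>>> # Two 1-D index tensors select the positions (0, 0) and (0, 1) of x1.
>>> indices = [Tensor(np.array([0, 0]).astype(np.int32)), Tensor(np.array([0, 1]).astype(np.int32))]
>>> accumulate = 1
>>> op = ops.IndexPut(accumulate=accumulate)
>>> # With accumulate=1, the size-1 x2 is broadcast and added at both positions.
>>> output = op(x1, x2, indices)
>>> print(output)
[[4 5 3]
 [4 5 6]]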