mindspore.Tensor.index_put

Tensor.index_put(indices, values, accumulate=False)

Returns a Tensor in which the elements of the “self Tensor” at the positions given by indices are replaced with the values in values, or, if accumulate is True, have those values added to them.

Parameters
  • indices (tuple[Tensor], list[Tensor]) – The indices, of type int32 or int64, used to index into the “self Tensor”. Each tensor in indices must be 1-D, the number of tensors in indices must be <= the rank of the “self Tensor”, and the tensors in indices must be broadcastable to one another.

  • values (Tensor) – 1-D Tensor of the same dtype as the “self Tensor”. If its size is 1, it is broadcast.

  • accumulate (bool) – If accumulate is True, the elements in values are added to “self Tensor”, else the elements in values replace the corresponding element in the “self Tensor”. Default: False.
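The difference between the two accumulate modes can be sketched outside MindSpore. The block below is a NumPy stand-in, not MindSpore code: the helper name index_put_reference is hypothetical, and copying the input before writing is an assumption made so that both modes can be compared on the same array.

```python
import numpy as np

def index_put_reference(x, indices, values, accumulate=False):
    """Hypothetical NumPy stand-in for the behaviour described above."""
    out = x.copy()  # assumption: work on a copy so the input stays unchanged
    idx = tuple(np.asarray(i) for i in indices)
    if accumulate:
        # Add values to the selected positions (a size-1 values is broadcast).
        np.add.at(out, idx, values)
    else:
        # Overwrite the selected positions (a size-1 values is broadcast).
        out[idx] = values
    return out

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
indices = [np.array([0, 1, 1]), np.array([1, 2, 1])]
values = np.array([3], dtype=np.int32)

replaced = index_put_reference(x, indices, values)
accumulated = index_put_reference(x, indices, values, accumulate=True)
print(replaced)     # positions (0, 1), (1, 2), (1, 1) set to 3
print(accumulated)  # the same positions increased by 3
```

The index tensors are paired element-wise, so the three positions touched here are (0, 1), (1, 2) and (1, 1).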

Returns

Tensor, with the same type and shape as the “self Tensor”.

Raises
  • TypeError – If the dtype of the “self Tensor” is not equal to the dtype of values.

  • TypeError – If the type of indices is not tuple[Tensor] or list[Tensor].

  • TypeError – If the dtype of any tensor in indices is not int32 or int64.

  • TypeError – If the dtypes of the tensors in indices are inconsistent.

  • TypeError – If the type of accumulate is not bool.

  • ValueError – If values is not 1-D.

  • ValueError – If size(values) is neither 1 nor the max size of the tensors in indices when rank(“self Tensor”) == size(indices).

  • ValueError – If size(values) is neither 1 nor “self Tensor”.shape[-1] when rank(“self Tensor”) > size(indices).

  • ValueError – If any tensor in indices is not 1-D.

  • ValueError – If the tensors in indices are not broadcastable.

  • ValueError – If size(indices) > rank(“self Tensor”).
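The size-related ValueError conditions above can be read as a small set of checks. The following is a plain-Python sketch of those rules only; check_index_put_shapes is a hypothetical helper, not a MindSpore API. It assumes indices is non-empty and that every index tensor is already 1-D, so each one is described by its length alone.

```python
def check_index_put_shapes(x_shape, index_sizes, values_size):
    """Hypothetical checker mirroring the documented size rules.

    x_shape:      shape of the "self Tensor"
    index_sizes:  length of each 1-D tensor in indices (assumed non-empty)
    values_size:  length of the 1-D values tensor
    Returns the size that values must have when it is not of size 1.
    """
    rank = len(x_shape)
    if len(index_sizes) > rank:
        raise ValueError("size(indices) must be <= rank of the self Tensor")

    # 1-D tensors broadcast together only if each length is 1 or the maximum.
    max_size = max(index_sizes)
    if any(size not in (1, max_size) for size in index_sizes):
        raise ValueError("tensors in indices are not broadcastable")

    # Full indexing: values matches the indices; partial: the last dimension.
    expected = max_size if len(index_sizes) == rank else x_shape[-1]
    if values_size not in (1, expected):
        raise ValueError(f"size(values) must be 1 or {expected}")
    return expected

full = check_index_put_shapes((2, 3), [3, 3], 1)   # rank == size(indices)
partial = check_index_put_shapes((2, 3), [2], 3)   # rank > size(indices)
print(full, partial)
```

In the partial case the expected size comes from the last dimension of the “self Tensor” rather than from the index tensors, which is why a values of size 2 would be rejected there.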

Supported Platforms:

Ascend CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
>>> values = Tensor(np.array([3]).astype(np.int32))
>>> indices = [Tensor(np.array([0, 1, 1]).astype(np.int32)), Tensor(np.array([1, 2, 1]).astype(np.int32))]
>>> accumulate = True
>>> output = x.index_put(indices, values, accumulate)
>>> print(output)
[[1 5 3]
 [4 8 9]]