mindspore.ops.IndexFill

查看源文件
class mindspore.ops.IndexFill[源代码]

index 中给定的顺序选择索引,将输入Tensor xdim 维度下的元素用 value 的值填充。

警告

这是一个实验性API,后续可能修改或删除。

更多参考详见 mindspore.ops.index_fill()

输入:
  • x (Tensor) - 输入Tensor。

  • dim (Union[int, Tensor]) - 填充输入Tensor的维度,要求是一个int或者数据类型为int32或int64的零维Tensor。

  • index (Tensor) - 填充输入Tensor的索引,数据类型为int32。

  • value (Union[bool, int, float, Tensor]) - 填充输入Tensor的值。

输出:

填充后的Tensor。shape和数据类型与输入 x 相同。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> index_fill = ops.IndexFill()
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32))
>>> index = Tensor([0, 2], mindspore.int32)
>>> value = Tensor(-2.0, mindspore.float32)
>>> y = index_fill(x, 1, index, value)
>>> print(y)
[[-2. 2. -2.]
 [-2. 5. -2.]
 [-2. 8. -2.]]