比较与torch.scatter_add的差异

查看源文件

torch.scatter_add

torch.scatter_add(input, dim, index, src)

更多内容详见torch.scatter_add

mindspore.ops.tensor_scatter_elements

mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)

更多内容详见mindspore.ops.tensor_scatter_elements

差异对比

PyTorch:在任意维度 d 上,要求 index.size(d) <= src.size(d) ,即 index 可以选择 src 的部分或全部数据分散到 input 里。

MindSpore: indices 的shape必须和 updates 的shape一致,即 updates 的所有数据都会被 indices 分散到 input_x 里。

功能上无差异。

分类

子类

PyTorch

MindSpore

差异

参数

参数 1

input

input_x

功能一致,参数名不同

参数 2

dim

axis

功能一致,参数名不同

参数 3

index

indices

MindSpore的 indices 的shape必须和 updates 的shape一致,PyTorch要求在任意维度 d 上, index.size(d) <= src.size(d)

参数 4

src

updates

功能一致

参数 5

reduction

MindSpore的 reduction 必须设置为 “add”

代码示例

# PyTorch
import torch
import numpy as np
x = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter_add(x=x, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 0., 0., 0.],
#         [4., 5., 0., 0., 0.],
#         [7., 8., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore
import mindspore as ms
import numpy as np
x = ms.Tensor(np.zeros((5, 5)), dtype=ms.float32)
src = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
index = ms.Tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=ms.int64)
out = ms.ops.tensor_scatter_elements(input_x=x, axis=1, indices=index, updates=src, reduction="add")
print(out)
# [[1. 2. 3. 0. 0.]
#  [4. 5. 6. 0. 0.]
#  [7. 8. 9. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]