比较与torch.scatter的差异

查看源文件

以下映射关系均可参考本文。

PyTorch APIs

MindSpore APIs

torch.scatter

mindspore.ops.scatter

torch.Tensor.scatter

mindspore.Tensor.scatter

torch.scatter

torch.scatter(input, dim, index, src)

更多内容详见torch.scatter

mindspore.ops.scatter

mindspore.ops.scatter(input, axis, index, src)

更多内容详见mindspore.ops.scatter

差异对比

MindSpore此API功能与PyTorch不一致。

PyTorch:在任意维度 d 上,要求 index.size(d) <= src.size(d) ,即 index 可以选择 src 的部分或全部数据分散到 input 里。

MindSpore: index 的shape必须和 src 的shape一致,即 src 的所有数据都会被 index 分散到 input 里。

分类

子类

PyTorch

MindSpore

差异

参数

参数 1

input

input

一致

参数 2

dim

axis

参数名不一致

参数 3

index

index

MindSpore的 index 的shape必须和 src 的shape一致,PyTorch要求在任意维度 d 上, index.size(d) <= src.size(d)

参数 4

src

src

一致

代码示例 1

src 的部分数据进行scatter操作。

# PyTorch
import torch
import numpy as np
input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 0., 0., 0.],
#         [4., 5., 0., 0., 0.],
#         [7., 8., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
# MindSpore不支持对部分数据进行scatter操作。

代码示例 2

src 的全部数据进行scatter操作。

import torch
import numpy as np
input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 3., 0., 0.],
#         [4., 5., 6., 0., 0.],
#         [7., 8., 9., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore
import mindspore as ms
import numpy as np
input = ms.Tensor(np.zeros((5, 5)), dtype=ms.float32)
src = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
index = ms.Tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=ms.int64)
out = ms.ops.scatter(input=input, axis=1, index=index, src=src)
print(out)
# [[1. 2. 3. 0. 0.]
#  [4. 5. 6. 0. 0.]
#  [7. 8. 9. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]