mindscience.e3nn.nn.Scatter

class mindscience.e3nn.nn.Scatter(mode='add')[源代码]

易用的 scatter 操作封装:根据索引将源张量的值聚合到目标张量中。

参数:
  • mode (str,可选) - {'add', 'sum', 'div', 'max', 'min', 'mul'},scatter 模式。

    • 'add' 或 'sum':逐元素相加;

    • 'div':逐元素相除;

    • 'max':逐元素取最大值;

    • 'min':逐元素取最小值;

    • 'mul':逐元素相乘。

    默认值:'add'

输入:
  • src (Tensor) - 源张量。

  • index (Tensor) - 待聚合元素的索引,必须为整型张量。

  • out (Tensor,可选) - 目标张量;提供时在该张量上就地执行 scatter。默认值:None

  • dim_size (int,可选) - 当未提供 out 时,自动创建大小为 dim_size 的输出;若未提供,则返回最小尺寸的输出。默认值:None

输出:
  • output (Tensor) - scatter 操作结果的张量。

异常:
  • ValueError - 如果 mode 非法。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindscience.e3nn.nn import Scatter
>>> scatter = Scatter('add')
>>> src = Tensor([[1, 2], [3, 4], [5, 6]], ms.float32)
>>> index = Tensor([0, 0, 1], ms.int32)
>>> out = scatter(src, index)
>>> print(out.shape)
(3,2)