比较与torch.take的功能差异

查看源文件

torch.take

torch.take(input, index)

更多内容详见torch.take

mindspore.Tensor.take

mindspore.Tensor.take(indices, axis=None, mode="clip")

更多内容详见mindspore.Tensor.take

使用方式

基础功能为根据传入的索引从输入Tensor中获取对应的元素。

torch.take首先将原始Tensor拉长,然后根据index获取元素,index设置值需小于输入Tensor的元素数。

mindspore.Tensor.take默认状态下(axis=None)同样先对Tensor做ravel操作,再按照indices返回元素。除此之外,可以通过axis设定按照指定axis选取元素。indices数值可以超出Tensor元素数目,此时可以通过入参mode设置不同的返回策略,具体说明请参考API注释。

代码示例

import mindspore as ms
import numpy as np

a = ms.Tensor([[1, 2, 8],[3, 4, 6]], ms.float32)
indices = ms.Tensor(np.array([1, 10]))
# take(self, indices, axis=None, mode='clip'):
print(a.take(indices))
# [2. 6.]
print(a.take(indices, axis=1))
# [[2. 8.]
#  [4. 6.]]
print(a.take(indices, mode="wrap"))
# [2. 4.]

import torch
b = torch.tensor([[1, 2, 8],[3, 4, 6]])
indices = torch.tensor([1, 5])
print(torch.take(b, indices))
# tensor([2, 6])