Function Differences with torch.take
torch.take
torch.take(input, index)
For more information, see torch.take.
mindspore.Tensor.take
mindspore.Tensor.take(indices, axis=None, mode="clip")
For more information, see mindspore.Tensor.take.
Usage
Both APIs return the elements of the input Tensor at the positions given by the index Tensor passed in.
torch.take first flattens the input Tensor to 1-D and then gathers the elements according to index. Every value in index must be smaller than the number of elements in the input Tensor.
By default (axis=None), mindspore.Tensor.take also flattens (ravels) the Tensor first and then returns the elements selected by indices. Setting axis instead selects elements along the specified axis. Values in indices may exceed the number of Tensor elements; the mode parameter determines how such out-of-range values are handled (the default, "clip", maps them to the last valid position). Please refer to the API notes for details.
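The index handling described above can be sketched in plain Python. This is only an illustration under stated assumptions: `resolve_index`, `take_flat` and `take_axis1` are hypothetical helpers (not MindSpore or PyTorch APIs), nested lists stand in for Tensors, and the "clip", "wrap" and "raise" modes are assumed to behave like the corresponding NumPy modes.

```python
# Illustrative pure-Python sketch; resolve_index, take_flat and take_axis1
# are hypothetical helpers, not MindSpore or PyTorch APIs.

def resolve_index(i, size, mode="clip"):
    """Map a possibly out-of-range index to a valid position in [0, size)."""
    if mode == "clip":
        # Too-large indices map to the last position; negative indices map
        # to 0 (assumption: same as NumPy's clip mode).
        return max(0, min(i, size - 1))
    if mode == "wrap":
        # Indices wrap around modulo the size.
        return i % size
    if mode == "raise":
        # Out-of-range indices are rejected instead of being remapped.
        if not -size <= i < size:
            raise IndexError(f"index {i} is out of bounds for size {size}")
        return i % size
    raise ValueError(f"unknown mode: {mode!r}")


def take_flat(rows, indices, mode="clip"):
    """axis=None behaviour: flatten row by row, then pick elements."""
    flat = [x for row in rows for x in row]
    return [flat[resolve_index(i, len(flat), mode)] for i in indices]


def take_axis1(rows, indices, mode="clip"):
    """axis=1 behaviour for a 2-D input: pick columns from every row."""
    return [[row[resolve_index(i, len(row), mode)] for i in indices]
            for row in rows]


a = [[1, 2, 8], [3, 4, 6]]
print(take_flat(a, [1, 10]))          # index 10 is clipped to 5
print(take_flat(a, [1, 10], "wrap"))  # index 10 wraps to 10 % 6 = 4
print(take_axis1(a, [1, 10]))         # index 10 is clipped to column 2
```

Under these assumptions, the "raise" mode is the closest match to the torch.take requirement that every index stay within range, while "clip" and "wrap" are the extra return strategies available only through the mode parameter.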
Code Example
import mindspore as ms
import numpy as np
a = ms.Tensor([[1, 2, 8],[3, 4, 6]], ms.float32)
indices = ms.Tensor(np.array([1, 10]))
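# Signature: take(indices, axis=None, mode='clip')
# axis=None: a is flattened to [1, 2, 8, 3, 4, 6]; index 10 is clipped to 5.
print(a.take(indices))
# [2. 6.]
# axis=1: select columns; index 10 is clipped to 2, the last column.
print(a.take(indices, axis=1))
# [[2. 8.]
#  [4. 6.]]
# mode="wrap": index 10 wraps around to 10 % 6 = 4.
print(a.take(indices, mode="wrap"))
# [2. 4.]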
import torch
b = torch.tensor([[1, 2, 8],[3, 4, 6]])
# Every index must be smaller than the number of elements in b (6 here).
indices = torch.tensor([1, 5])
print(torch.take(b, indices))
# tensor([2, 6])
