mindspore.mint.index_select

查看源文件
mindspore.mint.index_select(input, dim, index)[源代码]

返回一个新的tensor,该tensor沿维度 dimindex 中给定的索引对 input 进行选择。

返回的tensor和输入 input 的维度数量相同,其第 dim 维度的大小和 index 的长度相同;其他维度和 input 相同。

说明

index 的值必须在 [0, input.shape[dim]) 范围内,超出该范围结果未定义。

参数:
  • input (Tensor) - 输入tensor。

  • dim (int) - 根据索引进行选择的维度。

  • index (Tensor) - 包含索引的一维tensor。

返回:

Tensor,数据类型与输入 input 相同。

异常:
  • TypeError - inputindex 的类型不是Tensor。

  • TypeError - dim 的类型不是int。

  • ValueError - dim 值超出范围[-input.ndim, input.ndim - 1]。

  • ValueError - index 不是一维tensor。

支持平台:

Ascend

样例:

>>> import mindspore
>>> from mindspore import Tensor, mint
>>> import numpy as np
>>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]]
[[ 8.  9. 10. 11.]
[12. 13. 14. 15.]]]
>>> index = Tensor([0,], mindspore.int32)
>>> y = mint.index_select(input, 1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
[[ 8.  9. 10. 11.]]]