mindspore.Tensor.gather_nd

mindspore.Tensor.gather_nd(indices)[源代码]

按索引从输入Tensor中获取切片。 使用给定的索引从具有指定形状的输入Tensor中搜集切片。 输入Tensor的shape是 \((N,*)\) ,其中 \(*\) 表示任意数量的附加维度。下文中的 input_x 代指输入Tensor本身。 indices 是一个K维的整数张量,假定它的K-1维张量中的每一个元素是输入Tensor的切片,那么有:

\[output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]\]

indices 的最后一维不能超过输入Tensor的秩: \(indices.shape[-1] <= input\_x.rank\)

参数:
  • indices (Tensor) - 获取收集元素的索引张量,其数据类型包括:int32,int64。

返回:

Tensor,具有与输入Tensor相同的数据类型,shape维度为 \(indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]\)

异常:
  • ValueError - 如果输入Tensor的shape长度小于 indices 的最后一个维度。

支持平台:

Ascend GPU CPU

样例:

>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> output = input_x.gather_nd(indices)
>>> print(output)
[-0.1  0.5]