mindspore.ops.TopK

class mindspore.ops.TopK(sorted=True)

Finds values and indices of the k largest entries along the last dimension.

Warning

  • If sorted is set to False, the AICPU operator is used and performance may be reduced. In addition, because memory layout and traversal order differ across platforms, the order of the results may be inconsistent when sorted is False.

If input_x is a one-dimensional Tensor, finds the k largest entries in the Tensor and outputs their values and indices as Tensors. With sorted=True, values[i] is the (i+1)-th largest entry of input_x and indices[i] is its index, for 0 ≤ i < k.

For a multi-dimensional Tensor, computes the k largest entries of each row (the vector along the last dimension), so that:

\[\text{values.shape} = \text{indices.shape} = \text{input\_x.shape}[:-1] + [k].\]

If two compared elements are equal, the one with the smaller index is returned first.
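The semantics described above (top-k along the last dimension, descending order, ties resolved in favour of the smaller index) can be mirrored by a small pure-Python sketch on nested lists. It is an illustrative reference for sorted=True only, not MindSpore's implementation; the function name `topk_reference` is hypothetical, and raising `ValueError` when k is too large is a choice of this sketch.

```python
def topk_reference(x, k):
    """Illustrative pure-Python mirror of TopK(sorted=True) on nested lists.

    Top-k is taken along the last (innermost) dimension. Returns
    (values, indices) with the same nesting as x, innermost length k.
    """
    # Innermost level: x is a flat list of numbers.
    if not x or not isinstance(x[0], list):
        if k > len(x):
            raise ValueError("k exceeds the size of the last dimension")
        # Order positions by descending value; equal values keep the smaller index first.
        order = sorted(range(len(x)), key=lambda j: (-x[j], j))[:k]
        return [x[j] for j in order], order
    # Leading dimensions: apply the same rule to each row independently.
    results = [topk_reference(row, k) for row in x]
    return [v for v, _ in results], [i for _, i in results]


# 1-D case, same input as the Examples section.
values, indices = topk_reference([1, 2, 3, 4, 5], 3)
# values == [5, 4, 3], indices == [4, 3, 2]

# 2-D case: shape [2, 3] -> [2, 2]; the tie in the first row keeps index 0 before 2.
values, indices = topk_reference([[3, 1, 3], [0, 2, 1]], 2)
# values == [[3, 3], [2, 1]], indices == [[0, 2], [1, 2]]
```

The 2-D call shows the shape rule in action: each of the two rows contributes k = 2 entries, so the outputs have shape input_x.shape[:-1] + [k] = [2, 2].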

Parameters

sorted (bool, optional) – If True, the returned elements are sorted by value in descending order. If False, they are returned unsorted. Default: True.
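Because the element order is not guaranteed when sorted is False, results obtained that way are best compared order-insensitively. The hypothetical helper below does this for 1-D results already converted to Python lists (for example with `Tensor.asnumpy().tolist()`). It assumes both results select the same entries, which may not hold when input_x contains ties.

```python
def same_topk_unordered(values_a, indices_a, values_b, indices_b):
    """Compare two 1-D top-k results while ignoring element order.

    Hypothetical helper: intended for outputs produced with sorted=False,
    whose order may differ between platforms.
    """
    # Pair each index with its value, then compare the pairs as sorted lists.
    pairs_a = sorted(zip(indices_a, values_a))
    pairs_b = sorted(zip(indices_b, values_b))
    return pairs_a == pairs_b


# Same entries in a different order compare equal.
print(same_topk_unordered([5, 4, 3], [4, 3, 2], [3, 5, 4], [2, 4, 3]))  # True
# A different selection does not.
print(same_topk_unordered([5, 4, 3], [4, 3, 2], [5, 4, 2], [4, 3, 1]))  # False
```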

Inputs:
  • input_x (Tensor) - The Tensor to compute on. Its data type must be float16, float32 or int32 on CPU, and float16 or float32 on GPU.

  • k (int) - The number of top elements to compute along the last dimension. Must be a constant.

Outputs:

A tuple of two Tensors, values and indices.

  • values (Tensor) - The k largest elements in each slice of the last dimension.

  • indices (Tensor) - The indices of values within the last dimension of input_x.
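The two outputs are linked: gathering input_x along its last dimension with indices reproduces values. A minimal pure-Python check of that relationship follows, using the input and outputs from the Examples section; `gather_last_dim` is a hypothetical helper operating on nested lists, not part of MindSpore.

```python
def gather_last_dim(x, indices):
    """Return x[..., indices[..., j]] for nested lists (illustrative helper)."""
    # Innermost level: pick entries of the flat list x at the given positions.
    if not indices or not isinstance(indices[0], list):
        return [x[j] for j in indices]
    # Otherwise recurse row by row over the leading dimensions.
    return [gather_last_dim(row, idx) for row, idx in zip(x, indices)]


# Input and outputs from the TopK example (k = 3, sorted=True).
input_x = [1.0, 2.0, 3.0, 4.0, 5.0]
values = [5.0, 4.0, 3.0]
indices = [4, 3, 2]
assert gather_last_dim(input_x, indices) == values
```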

Raises
  • TypeError – If sorted is not a bool.

  • TypeError – If input_x is not a Tensor.

  • TypeError – If k is not an int.

  • TypeError – If the dtype of input_x is not float16, float32 or int32.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import Tensor
>>> from mindspore import ops
>>> import mindspore
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = ops.TopK(sorted=True)(input_x, k)
>>> print((values, indices))
(Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00,  4.0000e+00,  3.0000e+00]), Tensor(shape=[3],
  dtype=Int32, value= [4, 3, 2]))