比较与torch.max的功能差异

torch.max

torch.max(
    input,
    dim,
    keepdim=False,
    out=None
)

更多内容详见torch.max

mindspore.ops.ArgMaxWithValue

class mindspore.ops.ArgMaxWithValue(
    axis=0,
    keep_dims=False
)(input_x)

更多内容详见mindspore.ops.ArgMaxWithValue

使用方式

PyTorch:输出为元组(最大值, 最大值的索引)。

MindSpore:输出为元组(最大值的索引, 最大值)。

代码示例

import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np

# Output tuple(index of max, max).
input_x = ms.Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), ms.float32)
argmax = ops.ArgMaxWithValue()
index, output = argmax(input_x)
print(index)
print(output)
# Out:
# 3
# 0.7

# Output tuple(max, index of max).
input_x = torch.tensor([0.0, 0.4, 0.6, 0.7, 0.1])
output, index = torch.max(input_x, 0)
print(index)
print(output)
# Out:
# tensor(3)
# tensor(0.7000)