Function Differences with torch.norm


torch.norm

torch.norm(
    input,
    p='fro',
    dim=None,
    keepdim=False,
    out=None,
    dtype=None
)

For more information, see torch.norm.

mindspore.nn.Norm

class mindspore.nn.Norm(
    axis=(),
    keep_dims=False
)(input)

For more information, see mindspore.nn.Norm.

Usage

PyTorch: supports multiple norms, including the L2 norm, selected through the parameter p.

MindSpore: currently supports only the L2 norm.
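To make the difference concrete, here is a minimal plain-Python sketch of a column-wise p-norm, that is, a reduction over axis 0 of a 2-D input. The helper name column_p_norm is hypothetical and belongs to neither library. It only illustrates the definition (sum of |x|^p, raised to 1/p) and assumes p is a positive finite number. With p=2 it mirrors what nn.Norm(axis=0) computes. With p=1 it mirrors torch.norm(..., dim=0, p=1), which nn.Norm does not offer.

```python
def column_p_norm(rows, p=2):
    """Return the p-norm of each column of a 2-D list (reduction over axis 0).

    Hypothetical helper for illustration only; assumes p is a positive finite number.
    """
    columns = zip(*rows)  # transpose so that each item is one column
    return [sum(abs(v) ** p for v in col) ** (1.0 / p) for col in columns]


data = [[4, 4, 9, 1], [2, 1, 3, 6]]

l2 = column_p_norm(data, p=2)  # the only case covered by mindspore.nn.Norm
l1 = column_p_norm(data, p=1)  # available in torch.norm through p=1

print([round(v, 4) for v in l2])  # [4.4721, 4.1231, 9.4868, 6.0828]
print([round(v, 4) for v in l1])  # [6.0, 5.0, 12.0, 7.0]
```

These values match the outputs of the framework calls in the code example on this page.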

Code Example

import mindspore
from mindspore import Tensor, nn
import torch
import numpy as np

# In MindSpore, only L2 norm is supported.
net = nn.Norm(axis=0)
input_x = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
output = net(input_x)
print(output)
# Out:
# [4.4721 4.1231 9.4868 6.0828]

# In torch, you can set parameter p to implement the desired norm.
input_x = torch.tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), dtype=torch.float)
output1 = torch.norm(input_x, dim=0, p=2)
print(output1)
# Out:
# tensor([4.4721, 4.1231, 9.4868, 6.0828])

input_x = torch.tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), dtype=torch.float)
output2 = torch.norm(input_x, dim=0, p=1)
print(output2)
# Out:
# tensor([ 6.,  5., 12.,  7.])