比较与torch.nn.GaussianNLLLoss的功能差异

查看源文件

torch.nn.GaussianNLLLoss

class torch.nn.GaussianNLLLoss(
    *,
    full=False,
    eps=1e-06,
    reduction='mean'
)(input, target, var) -> Tensor/Scalar

更多内容详见torch.nn.GaussianNLLLoss

mindspore.nn.GaussianNLLLoss

class mindspore.nn.GaussianNLLLoss(
    *,
    full=False,
    eps=1e-06,
    reduction='mean'
)(logits, labels, var) -> Tensor/Scalar

更多内容详见mindspore.nn.GaussianNLLLoss

差异对比

PyTorch:服从高斯分布的负对数似然损失。

MindSpore:与PyTorch实现同样的功能。如果var中存在小于0的数字,PyTorch会直接报错,而MindSpore则会计算max(var, eps) 之后,将结果传给log进行计算。

分类

子类

PyTorch

MindSpore

差异

参数

参数1

full

full

功能一致

参数2

eps

eps

功能一致

参数3

reduction

reduction

功能一致

输入

输入1

input

logits

功能一致,参数名不同

输入2

target

labels

功能一致,参数名不同

输入3

var

var

功能一致

代码示例

两API实现功能和使用方法基本相同,但PyTorch和MindSpore针对输入 var<0 的情况做了不同处理。

# PyTorch
import torch
from torch import nn
import numpy as np

arr1 = np.arange(8).reshape((4, 2))
arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
logits = torch.tensor(arr1, dtype=torch.float32)
labels = torch.tensor(arr2, dtype=torch.float32)
loss = nn.GaussianNLLLoss(reduction='mean')
var = torch.tensor(np.ones((4, 1)), dtype=torch.float32)
output = loss(logits, labels, var)
# tensor(1.4375)

# 如果var中有小于0的元素,PyTorch会直接报错
var[0] = -1
output2 = loss(logits, labels, var)
# ValueError: var has negative entry/entries

# MindSpore
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
from mindspore import dtype as mstype

arr1 = np.arange(8).reshape((4, 2))
arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
logits = Tensor(arr1, mstype.float32)
labels = Tensor(arr2, mstype.float32)
loss = nn.GaussianNLLLoss(reduction='mean')
var = Tensor(np.ones((4, 1)), mstype.float32)
output = loss(logits, labels, var)
print(output)
# 1.4374993

# 如果var中有小于0的元素,MindSpore会使用max(var, eps)的结果
var[0] = -1
output2 = loss(logits, labels, var)
print(output2)
# 499999.22