mindspore.nn.probability.bnn_layers.WithBNNLossCell

查看源文件
class mindspore.nn.probability.bnn_layers.WithBNNLossCell(backbone, loss_fn, dnn_factor=1, bnn_factor=1)[源代码]

为 BNN 生成一个合适的 WithLossCell,用损失函数包装贝叶斯网络。

参数:
  • backbone (Cell) - 目标网络。

  • loss_fn (Cell) - 用于计算损失的损失函数。

  • dnn_factor (int, float) - backbone 的损失系数,由损失函数计算。默认值:1。

  • bnn_factor (int, float) - KL 损失系数,即贝叶斯层的 KL 散度。默认值:1。

输入:
  • data (Tensor) - data 的 shape \((N, \ldots)\)

  • label (Tensor) - label 的 shape \((N, \ldots)\)

输出:

Tensor,任意 shape 的标量 Tensor。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn.probability import bnn_layers
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = bnn_layers.DenseReparam(16, 1)
...     def construct(self, x):
...         return self.dense(x)
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> net_with_criterion = bnn_layers.WithBNNLossCell(net, loss_fn)
>>>
>>> batch_size = 2
>>> data = Tensor(np.ones([batch_size, 16]).astype(np.float32) * 0.01)
>>> label = Tensor(np.ones([batch_size, 1]).astype(np.float32))
>>> output = net_with_criterion(data, label)
>>> print(output.shape)
(2,)
property backbone_network

返回backbone_network。

返回:

Cell,backbone_network。