Source code for mindspore.nn.probability.bnn_layers.dense_variational

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator
from ...cell import Cell
from ...layer.activation import get_activation
from ..distribution.normal import Normal
from .layer_distribution import NormalPrior, normal_post_fn
from ._util import check_prior, check_posterior

__all__ = ['DenseReparam', 'DenseLocalReparam']


class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=normal_post_fn,
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=normal_post_fn):
        super(_DenseVariational, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)

        self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn")
        self.weight_posterior = check_posterior(weight_posterior_fn, shape=[self.out_channels, self.in_channels],
                                                param_name='bnn_weight', arg_name="weight_posterior_fn")

        if self.has_bias:
            self.bias_prior = check_prior(bias_prior_fn, "bias_prior_fn")
            self.bias_posterior = check_posterior(bias_posterior_fn, shape=[self.out_channels], param_name='bnn_bias',
                                                  arg_name="bias_posterior_fn")

        self.activation = activation
        if not self.activation:
            self.activation_flag = False
        else:
            self.activation_flag = True
            if isinstance(self.activation, str):
                self.activation = get_activation(activation)
            elif isinstance(self.activation, Cell):
                self.activation = activation
            else:
                raise ValueError('The type of `activation` must be str or Cell, but got {}.'.format(type(activation)))

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()

    def construct(self, x):
        outputs = self.apply_variational_weight(x)
        if self.has_bias:
            outputs = self.apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            s += ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        if self.activation_flag:
            s += ', activation={}'.format(self.activation)
        return s

    def apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss"""
        weight_args_list = self.weight_posterior("get_dist_args")
        weight_type = self.weight_posterior("get_dist_type")

        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_args_list = self.bias_posterior("get_dist_args")
            bias_type = self.bias_posterior("get_dist_type")

            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss
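
# A minimal sketch, not part of the public API: the per-layer KL terms returned by
# compute_kl_loss() are typically summed over all variational layers to form the KL
# part of the ELBO. The helper name `_example_total_kl_loss` is hypothetical and
# introduced only for illustration; it assumes every layer deriving from
# _DenseVariational exposes compute_kl_loss().
def _example_total_kl_loss(network):
    """Sum compute_kl_loss() over every dense variational layer in `network`."""
    kl = 0
    for _, cell in network.cells_and_names():
        if isinstance(cell, _DenseVariational):
            kl = kl + cell.compute_kl_loss()
    return kl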


class DenseReparam(_DenseVariational):
    r"""
    Dense variational layers with Reparameterization.

    For more details, refer to the paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`_.

    Applies a dense-connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same data
    type as the inputs, sampled from the posterior distribution of the weight, and
    :math:`\text{bias}` is a bias vector with the same data type as the inputs, sampled
    from the posterior distribution of the bias (only if has_bias is True).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str, Cell): The activation function applied to the output of the layer.
            The type of `activation` can be a string (e.g. 'relu') or a Cell (e.g. nn.ReLU()).
            Note that if the type of activation is Cell, it must be instantiated beforehand.
            Default: None.
        weight_prior_fn: The prior distribution for the weight.
            It must return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard normal distribution).
            The current version only supports normal distribution.
        weight_posterior_fn: The posterior distribution for sampling the weight.
            It must be a function handle which returns a mindspore distribution instance.
            Default: normal_post_fn. The current version only supports normal distribution.
        bias_prior_fn: The prior distribution for the bias vector.
            It must return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard normal distribution).
            The current version only supports normal distribution.
        bias_posterior_fn: The posterior distribution for sampling the bias vector.
            It must be a function handle which returns a mindspore distribution instance.
            Default: normal_post_fn. The current version only supports normal distribution.

    Inputs:
        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.

    Outputs:
        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore.nn.probability.bnn_layers import DenseReparam
        >>> net = DenseReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> output = net(input).shape
        >>> print(output)
        (2, 4)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=normal_post_fn,
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=normal_post_fn):
        super(DenseReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def apply_variational_weight(self, inputs):
        """Sample a weight matrix from the posterior and apply it to the inputs."""
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.matmul(inputs, weight_posterior_tensor)
        return outputs
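

# A minimal sketch of a training-style forward pass, assuming WithBNNLossCell from
# mindspore.nn.probability.bnn_layers (signature: backbone, loss_fn, dnn_factor,
# bnn_factor) adds the KL term of every variational layer to the task loss.
# The helper name and the factor values are illustrative, not recommended settings.
def _example_loss_with_kl():
    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.nn.probability.bnn_layers import WithBNNLossCell

    net = DenseReparam(3, 4)
    loss_fn = nn.MSELoss()
    net_with_loss = WithBNNLossCell(net, loss_fn, dnn_factor=1, bnn_factor=0.001)
    x = Tensor(np.random.randn(2, 3).astype(np.float32))
    label = Tensor(np.random.randn(2, 4).astype(np.float32))
    # Returns task loss plus the scaled KL term contributed by the Bayesian layer.
    return net_with_loss(x, label)

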
class DenseLocalReparam(_DenseVariational):
    r"""
    Dense variational layers with Local Reparameterization.

    For more details, refer to the paper `Variational Dropout and the Local
    Reparameterization Trick <https://arxiv.org/abs/1506.02557>`_.

    Applies a dense-connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same data
    type as the inputs, sampled from the posterior distribution of the weight, and
    :math:`\text{bias}` is a bias vector with the same data type as the inputs, sampled
    from the posterior distribution of the bias (only if has_bias is True).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str, Cell): The activation function applied to the output of the layer.
            The type of `activation` can be a string (e.g. 'relu') or a Cell (e.g. nn.ReLU()).
            Note that if the type of activation is Cell, it must be instantiated beforehand.
            Default: None.
        weight_prior_fn: The prior distribution for the weight.
            It must return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard normal distribution).
            The current version only supports normal distribution.
        weight_posterior_fn: The posterior distribution for sampling the weight.
            It must be a function handle which returns a mindspore distribution instance.
            Default: normal_post_fn. The current version only supports normal distribution.
        bias_prior_fn: The prior distribution for the bias vector.
            It must return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of standard normal distribution).
            The current version only supports normal distribution.
        bias_posterior_fn: The posterior distribution for sampling the bias vector.
            It must be a function handle which returns a mindspore distribution instance.
            Default: normal_post_fn. The current version only supports normal distribution.

    Inputs:
        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.

    Outputs:
        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore.nn.probability.bnn_layers import DenseLocalReparam
        >>> net = DenseLocalReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> output = net(input).shape
        >>> print(output)
        (2, 4)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=normal_post_fn,
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=normal_post_fn):
        super(DenseLocalReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )
        self.sqrt = P.Sqrt()
        self.square = P.Square()
        self.normal = Normal()

    def apply_variational_weight(self, inputs):
        """Sample the pre-activations from their induced Gaussian (local reparameterization)."""
        mean = self.matmul(inputs, self.weight_posterior("mean"))
        std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
        weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
        return weight_posterior_affine_tensor
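

# A minimal sketch of Monte Carlo prediction, assuming that repeated forward passes
# through a DenseLocalReparam layer draw fresh samples on every call, so averaging
# them estimates the predictive mean. The helper name and sample count are
# illustrative only.
def _example_mc_predictive_mean(n_samples=10):
    import numpy as np
    from mindspore import Tensor

    net = DenseLocalReparam(3, 4)
    x = Tensor(np.random.randn(2, 3).astype(np.float32))
    samples = [net(x).asnumpy() for _ in range(n_samples)]
    return np.mean(samples, axis=0)  # shape (2, 4)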