Source code for mindspore.nn.probability.bnn_layers.dense_variational

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['DenseReparam']


class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(_DenseVariational, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()
        for prior_name, prior_dist in self.weight_prior.name_cells().items():
            if prior_name != 'normal':
                raise TypeError("The type of distribution of `weight_prior_fn` should be `normal`")
            if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
                    isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
                raise TypeError("The input form of `weight_prior_fn` is incorrect")

        try:
            self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
        except TypeError:
            raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
        for posterior_name, _ in self.weight_posterior.name_cells().items():
            if posterior_name != 'normal':
                raise TypeError("The type of distribution of `weight_posterior_fn` should be `normal`")

        if self.has_bias:
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()
            for prior_name, prior_dist in self.bias_prior.name_cells().items():
                if prior_name != 'normal':
                    raise TypeError("The type of distribution of `bias_prior_fn` should be `normal`")
                if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
                        isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
                    raise TypeError("The input form of `bias_prior_fn` is incorrect")

            try:
                self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
            except TypeError:
                raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
            for posterior_name, _ in self.bias_posterior.name_cells().items():
                if posterior_name != 'normal':
                    raise TypeError("The type of distribution of `bias_posterior_fn` should be `normal`")

        self.activation = activation
        if not self.activation:
            self.activation_flag = False
        else:
            self.activation_flag = True
            if isinstance(self.activation, str):
                self.activation = get_activation(activation)
            elif isinstance(self.activation, Cell):
                self.activation = activation
            else:
                raise ValueError('The type of `activation` is wrong.')

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()

    def construct(self, x):
        outputs = self._apply_variational_weight(x)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)

        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss


[docs]class DenseReparam(_DenseVariational): r""" Dense variational layers with Reparameterization. See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_. Applies dense-connected layer for the input. This layer implements the operation as: .. math:: \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}), where :math:`\text{activation}` is the activation function passed as the activation argument (if passed in), :math:`\text{activation}` is a weight matrix with the same data type as the inputs created by the layer, :math:`\text{weight}` is a weight matrix sampling from posterior distribution of weight, and :math:`\text{bias}` is a bias vector with the same data type as the inputs created by the layer (only if has_bias is True). The bias vector is sampling from posterior distribution of :math:`\text{bias}`. Args: in_channels (int): The number of input channel. out_channels (int): The number of output channel . has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. activation (str, Cell): Regularizer function applied to the output of the layer. The type of activation can be str (eg. 'relu') or Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must have been instantiated. Default: None. weight_prior_fn: prior distribution for weight. It should return a mindspore distribution instance. Default: NormalPrior. (which creates an instance of standard normal distribution). The current version only supports normal distribution. weight_posterior_fn: posterior distribution for sampling weight. It should be a function handle which returns a mindspore distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape). The current version only supports normal distribution. bias_prior_fn: prior distribution for bias vector. It should return a mindspore distribution. Default: NormalPrior(which creates an instance of standard normal distribution). The current version only supports normal distribution. bias_posterior_fn: posterior distribution for sampling bias vector. It should be a function handle which returns a mindspore distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape). The current version only supports normal distribution. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. Outputs: Tensor of shape :math:`(N, out\_channels)`. Examples: >>> net = DenseReparam(3, 4) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> net(input).shape (2, 4) """ def __init__( self, in_channels, out_channels, activation=None, has_bias=True, weight_prior_fn=NormalPrior, weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), bias_prior_fn=NormalPrior, bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): super(DenseReparam, self).__init__( in_channels, out_channels, activation=activation, has_bias=has_bias, weight_prior_fn=weight_prior_fn, weight_posterior_fn=weight_posterior_fn, bias_prior_fn=bias_prior_fn, bias_posterior_fn=bias_posterior_fn ) def _apply_variational_weight(self, inputs): weight_posterior_tensor = self.weight_posterior("sample") outputs = self.matmul(inputs, weight_posterior_tensor) return outputs