"""Convolutional variational layers."""
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore._checkparam import twice
from ...layer.conv import _Conv
from ...cell import Cell
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['ConvReparam']

class _ConvVariational(_Conv):
    """
    Base class for all convolutional variational layers.

    The weight (and, when `has_bias` is True, the bias) is represented by a
    prior and a posterior distribution Cell. The deterministic `self.weight` /
    `self.bias` parameters created by `_Conv` are frozen
    (`requires_grad = False`). `construct` calls `_apply_variational_weight`,
    which subclasses must implement, and then optionally adds a bias sampled
    from the bias posterior.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (Union[int, tuple[int]]): Kernel height and width; a single int is used for both.
        stride (Union[int, tuple[int]]): Stride; a single int is used for both dimensions. Default: 1.
        pad_mode (str): One of 'same', 'valid' or 'pad'. Default: 'same'.
        padding (Union[int, tuple[int]]): Implicit padding on the input borders. Default: 0.
        dilation (Union[int, tuple[int]]): Dilation rate; a single int is used for both dimensions. Default: 1.
        group (int): Number of groups the filter is split into. Default: 1.
        has_bias (bool): Whether a variational bias is added to the output. Default: False.
        weight_prior_fn: Prior for the weight. Either a distribution Cell, used as is, or a
            zero-argument callable that returns one. Only a `normal` distribution is accepted.
            Default: NormalPrior.
        weight_posterior_fn: Callable invoked as `fn(shape=..., name=...)` that returns the
            weight posterior Cell. Only a `normal` distribution is accepted.
        bias_prior_fn: Same contract as `weight_prior_fn`, for the bias. Default: NormalPrior.
        bias_posterior_fn: Same contract as `weight_posterior_fn`, for the bias.

    Raises:
        ValueError: If `pad_mode` is not one of 'valid', 'same', 'pad'.
        TypeError: If a prior or posterior does not wrap a `normal` distribution, if a prior's
            `_mean_value` / `_sd_value` are not Tensors, or if calling a posterior function
            with `shape=` and `name=` raises TypeError.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        dilation = twice(dilation)
        # NOTE(review): assumes `_Conv.__init__(in_channels, out_channels, kernel_size, stride,
        # pad_mode, padding, dilation, group, has_bias, weight_init, bias_init)`; confirm against
        # the targeted MindSpore version. The parameters it creates are frozen below, because the
        # forward pass uses posterior samples instead.
        super(_ConvVariational, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init='normal',
            bias_init='zeros'
        )
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')

        # convolution args
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = dilation
        self.group = group
        self.has_bias = has_bias

        # Weight tensor shape: [out_channels, in_channels // group, kernel_h, kernel_w].
        self.shape = [self.out_channels,
                      self.in_channels // self.group, *self.kernel_size]

        # The deterministic weight is not trained; the posterior's parameters are.
        self.weight.requires_grad = False
        self.weight_prior = self._build_prior(weight_prior_fn, 'weight_prior_fn')
        self.weight_posterior = self._build_posterior(
            weight_posterior_fn, self.shape, 'bnn_weight', 'weight_posterior_fn',
            'The input form of `weight_posterior_fn` is incorrect')

        if self.has_bias:
            self.bias.requires_grad = False
            self.bias_prior = self._build_prior(bias_prior_fn, 'bias_prior_fn')
            self.bias_posterior = self._build_posterior(
                bias_posterior_fn, [self.out_channels], 'bnn_bias', 'bias_posterior_fn',
                'The type of `bias_posterior_fn` should be `NormalPosterior`')

        # mindspore operations
        self.bias_add = P.BiasAdd()
        # NOTE(review): keyword names assume the P.Conv2D signature
        # (out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group); confirm.
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)

        self.log = P.Log()
        self.sum = P.ReduceSum()

    @staticmethod
    def _build_prior(prior_fn, arg_name):
        """Return the prior Cell (`prior_fn` itself, or `prior_fn()`) after checking it is a normal distribution."""
        prior = prior_fn if isinstance(prior_fn, Cell) else prior_fn()
        for dist_name, dist in prior.name_cells().items():
            if dist_name != 'normal':
                raise TypeError("The type of distribution of `{}` should be `normal`".format(arg_name))
            # A `None` default turns a missing attribute into the documented TypeError.
            if not (isinstance(getattr(dist, '_mean_value', None), Tensor) and
                    isinstance(getattr(dist, '_sd_value', None), Tensor)):
                raise TypeError("The input form of `{}` is incorrect".format(arg_name))
        return prior

    @staticmethod
    def _build_posterior(posterior_fn, shape, name, arg_name, call_error_msg):
        """Build the posterior via `posterior_fn(shape=, name=)` and check it is a normal distribution."""
        try:
            posterior = posterior_fn(shape=shape, name=name)
        except TypeError as err:
            raise TypeError(call_error_msg) from err
        for dist_name in posterior.name_cells():
            if dist_name != 'normal':
                raise TypeError("The type of distribution of `{}` should be `normal`".format(arg_name))
        return posterior

    def construct(self, inputs):
        """Apply the variational weight (subclass hook), then the variational bias if enabled."""
        outputs = self._apply_variational_weight(inputs)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        return outputs

    def extend_repr(self):
        """Describe the convolution arguments and posterior parameters for the Cell's repr."""
        str_info = ('in_channels={}, out_channels={}, kernel_size={}, stride={},  pad_mode={}, '
                    'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}').format(
                        self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode,
                        self.padding, self.dilation, self.group, self.weight_posterior.mean,
                        self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}'\
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        return str_info

    def _apply_variational_bias(self, inputs):
        """Add one bias sample drawn from the bias posterior to `inputs`."""
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss.

        Returns the reduced-sum KL term reported by the weight prior Cell for the weight
        posterior's mean and sd. When `has_bias` is True, the same term for the bias is added.
        """
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal",
                               weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal",
                                 bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss

class ConvReparam(_ConvVariational):
    r"""
    Convolutional variational layers with Reparameterization.

    For more details, refer to the paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is an integer or a tuple of 2 integers. The kernel
            size specifies the height and width of the 2D convolution window. A single integer means the
            value is used for both the height and the width of the kernel. With the `kernel_size` being a
            tuple of 2 integers, the first value is the height and the second is the width of the kernel.
        stride (Union[int, tuple[int]]): The distance of kernel moving. An integer means the height and
            width of movement are both `stride`; a tuple of two integers gives the height and width of
            movement respectively. Default: 1.
        pad_mode (str): Specifies the padding mode. The optional values are "same", "valid", and "pad".
            Default: "same".

            - same: Adopts the way of completion. Output height and width will be the same as the input.
              The total amount of padding is calculated in the horizontal and vertical directions and
              distributed evenly to top and bottom, left and right if possible. Otherwise, the last extra
              padding is added at the bottom and the right side. If this mode is set, `padding` must be 0.

            - valid: Adopts the way of discarding. The largest possible height and width of the output are
              returned without padding. Extra pixels are discarded. If this mode is set, `padding` must be 0.

            - pad: Implicit paddings on both sides of the input. The number of `padding` is padded to the
              input Tensor borders. `padding` must be greater than or equal to 0.

        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. Default: 0.
        dilation (Union[int, tuple[int]]): The data type is an integer or a tuple of 2 integers. This
            parameter specifies the dilation rate of the dilated convolution. If set to :math:`k > 1`,
            :math:`k - 1` pixels are skipped for each sampling location. Its value must be greater than or
            equal to 1 and bounded by the height and width of the input. Default: 1.
        group (int): Splits the filter into groups; `in_channels` and `out_channels` must be divisible by
            the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_prior_fn: The prior distribution for the weight. It must be a mindspore distribution Cell
            or a callable that returns one. Default: NormalPrior (which creates an instance of the standard
            normal distribution). The current version only supports the normal distribution.
        weight_posterior_fn: The posterior distribution for sampling the weight. It must be a function
            handle which returns a mindspore distribution instance.
            Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
            The current version only supports the normal distribution.
        bias_prior_fn: The prior distribution for the bias vector. It must be a mindspore distribution Cell
            or a callable that returns one. Default: NormalPrior (which creates an instance of the standard
            normal distribution). The current version only supports the normal distribution.
        bias_posterior_fn: The posterior distribution for sampling the bias vector. It must be a function
            handle which returns a mindspore distribution instance.
            Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
            The current version only supports the normal distribution.

    Inputs:
        - **input** (Tensor) - The shape of the tensor is :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = ConvReparam(120, 240, 4, has_bias=False)
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 1024, 640)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            pad_mode='same',
            padding=0,
            dilation=1,
            group=1,
            has_bias=False,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(ConvReparam, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            pad_mode=pad_mode,
            padding=padding,
            dilation=dilation,
            group=group,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def _apply_variational_weight(self, inputs):
        """Draw one weight sample from the weight posterior and convolve `inputs` with it."""
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.conv2d(inputs, weight_posterior_tensor)
        return outputs