# Source code for mindspore.nn.probability.distribution.transformed_distribution
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformed Distribution"""
import numpy as np
from mindspore import _checkparam as validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import raise_not_impl_error
from ._utils.custom_ops import exp_generic, log_generic
class TransformedDistribution(Distribution):
    """
    Transformed Distribution.

    This class contains a bijector and a distribution and transforms the original distribution
    to a new distribution through the operation defined by the bijector.
    If :math:`X` is a random variable following the underlying distribution,
    and :math:`g(x)` is a function represented by the bijector,
    then :math:`Y = g(X)` is a random variable following the transformed distribution.

    Args:
        bijector (Bijector): The transformation to perform.
        distribution (Distribution): The original distribution. Must be a float dtype.
        seed (int): The seed is used in sampling. The global seed is used if it is None. Default: ``None`` .
          If this seed is given when a TransformedDistribution object is initialized, the object's sampling function
          will use this seed; otherwise, the underlying distribution's seed will be used.
        name (str): The name of the transformed distribution. Default: ``'transformed_distribution'`` .

    Note:
        The arguments used to initialize the original distribution cannot be None.
        For example, mynormal = msd.Normal(dtype=mindspore.float32) cannot be used to initialize a
        TransformedDistribution since `mean` and `sd` are not specified.
        `batch_shape` is the batch_shape of the original distribution.
        `broadcast_shape` is the broadcast shape between the original distribution and bijector.
        `is_scalar_batch` is only true if both the original distribution and the bijector are scalar batches.
        `default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original
        distribution. Derived class can overwrite `default_parameters` and `parameter_names` by calling
        `reset_parameters` followed by `add_parameter`.

    Raises:
        TypeError: When the input `bijector` is not a Bijector instance.
        TypeError: When the input `distribution` is not a Distribution instance.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.nn.probability.distribution as msd
        >>> import mindspore.nn.probability.bijector as msb
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def __init__(self, shape, dtype=mindspore.float32, seed=0, name='transformed_distribution'):
        ...         super(Net, self).__init__()
        ...         # create TransformedDistribution distribution
        ...         self.exp = msb.Exp()
        ...         self.normal = msd.Normal(0.0, 1.0, dtype=dtype)
        ...         self.lognormal = msd.TransformedDistribution(self.exp, self.normal, seed=seed, name=name)
        ...         self.shape = shape
        ...
        ...     def construct(self, value):
        ...         cdf = self.lognormal.cdf(value)
        ...         sample = self.lognormal.sample(self.shape)
        ...         return cdf, sample
        >>> shape = (2, 3)
        >>> net = Net(shape=shape, name="LogNormal")
        >>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
        >>> tx = Tensor(x, dtype=mindspore.float32)
        >>> cdf, sample = net(tx)
        >>> print(sample.shape)
        (2, 3)
    """

    def __init__(self,
                 bijector,
                 distribution,
                 seed=None,
                 name="transformed_distribution"):
        """
        Constructor of transformed_distribution class.
        """
        # Must remain the first statement: it snapshots the constructor
        # arguments before any other local variable exists. The snapshot is
        # handed to the base class constructor below.
        param = dict(locals())
        validator.check_value_type('bijector', bijector,
                                   [nn.probability.bijector.Bijector], type(self).__name__)
        validator.check_value_type('distribution', distribution,
                                   [Distribution], type(self).__name__)
        validator.check_type_name(
            "dtype", distribution.dtype, mstype.float_type, type(self).__name__)
        # The explicit two-argument super() is kept on purpose: the zero-argument
        # form would add `__class__` to locals() and therefore to `param`.
        super(TransformedDistribution, self).__init__(
            seed, distribution.dtype, name, param)

        self._bijector = bijector
        self._distribution = distribution

        # set attributes
        self._is_linear_transformation = self.bijector.is_constant_jacobian
        self._dtype = self.distribution.dtype
        self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch
        self._batch_shape = self.distribution.batch_shape

        self.default_parameters = self.distribution.default_parameters
        self.parameter_names = self.distribution.parameter_names
        # by default, set the parameter_type to be the distribution's parameter_type
        self.parameter_type = self.distribution.parameter_type

        self.exp = exp_generic
        self.log = log_generic
        self.isnan = P.IsNan()
        self.cast_base = P.Cast()
        self.equal_base = P.Equal()
        self.select_base = P.Select()

        # broadcast bijector batch_shape and distribution batch_shape
        self._broadcast_shape = self._broadcast_bijector_dist()

    @property
    def bijector(self):
        """
        Return the bijector of the transformed distribution.

        Output:
            Bijector, the bijector of the transformed distribution.
        """
        return self._bijector

    @property
    def distribution(self):
        """
        Return the underlying distribution of the transformed distribution.

        Output:
            Distribution, the underlying distribution of the transformed distribution.
        """
        return self._distribution

    @property
    def dtype(self):
        """
        Return the dtype of the transformed distribution.

        Output:
            Mindspore.dtype, the dtype of the transformed distribution.
        """
        return self._dtype

    @property
    def is_linear_transformation(self):
        """
        Return whether the transformation is linear.

        Output:
            Bool, true if the transformation is linear, and false otherwise.
        """
        return self._is_linear_transformation

    def _broadcast_bijector_dist(self):
        """
        Compute the broadcast shape of the base distribution's and the bijector's batch shapes.

        Returns:
            None if either batch shape is None; otherwise the shape obtained by adding
            two zero-filled tensors that have the two batch shapes.
        """
        if self.batch_shape is None or self.bijector.batch_shape is None:
            return None
        # Adding two placeholder tensors lets the framework's own broadcasting
        # rules determine the resulting shape.
        # NOTE(review): presumably the addition fails when the two shapes are
        # not broadcastable -- confirm.
        bijector_shape_tensor = F.fill(
            self.dtype, self.bijector.batch_shape, 0.0)
        dist_shape_tensor = F.fill(self.dtype, self.batch_shape, 0.0)
        return (bijector_shape_tensor + dist_shape_tensor).shape

    def _cdf(self, value, *args, **kwargs):
        r"""
        Evaluate the cdf of the transformed distribution at `value`.

        .. math::
            Y = g(X)
            P(Y <= a) = P(X <= g^{-1}(a))
        """
        inverse_value = self.bijector("inverse", value)
        return self.distribution("cdf", inverse_value, *args, **kwargs)

    def _log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log of the cdf at `value`.
        """
        return self.log(self._cdf(value, *args, **kwargs))

    def _survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function, i.e. one minus the cdf, at `value`.
        """
        return 1.0 - self._cdf(value, *args, **kwargs)

    def _log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log of the survival function at `value`.
        """
        return self.log(self._survival_function(value, *args, **kwargs))

    def _log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate the log probability of the transformed distribution at `value`.

        .. math::
            Y = g(X)
            Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
            \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))

        Where the log probability of the underlying distribution is ``-inf``,
        ``-inf`` is returned without the Jacobian adjustment.
        """
        inverse_value = self.bijector("inverse", value)
        unadjust_prob = self.distribution(
            "log_prob", inverse_value, *args, **kwargs)
        log_jacobian = self.bijector("inverse_log_jacobian", value)
        adjusted_prob = unadjust_prob + log_jacobian
        # Keep -inf wherever the base log-probability is -inf: adding the
        # Jacobian term there could only produce -inf again or NaN (e.g. when
        # the term is +inf). An element-wise equality test against NaN is never
        # true, so NaN cannot be detected with `equal_base`; a single select on
        # the -inf mask is all that is required.
        isneginf = self.equal_base(unadjust_prob, -np.inf)
        return self.select_base(isneginf, unadjust_prob, adjusted_prob)

    def _prob(self, value, *args, **kwargs):
        """
        Evaluate the probability at `value` as the exponential of the log probability.
        """
        return self.exp(self._log_prob(value, *args, **kwargs))

    def _sample(self, *args, **kwargs):
        """
        Draw samples from the underlying distribution and map them through the bijector's forward transformation.
        """
        org_sample = self.distribution("sample", *args, **kwargs)
        return self.bijector("forward", org_sample)

    def _mean(self, *args, **kwargs):
        """
        Mean of the transformed distribution, available only for linear transformations.

        For a non-linear transformation `raise_not_impl_error("mean")` is called.

        Note:
            This function may be overridden by derived class.
        """
        if not self.is_linear_transformation:
            raise_not_impl_error("mean")
        return self.bijector("forward", self.distribution("mean", *args, **kwargs))