mindspore.nn.probability.distribution.TransformedDistribution

class mindspore.nn.probability.distribution.TransformedDistribution(bijector, distribution, seed=None, name='transformed_distribution')[source]

Transformed Distribution. This class contains a bijector and a distribution, and transforms the original distribution into a new distribution through the operation defined by the bijector. If X is a random variable following the underlying distribution, and g(x) is a function represented by the bijector, then Y = g(X) is a random variable following the transformed distribution.
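The relation Y = g(X) can be illustrated without MindSpore. The following is a minimal plain-Python sketch, not the library's implementation: it assumes a standard normal X and g(x) = exp(x), and the names exp_bijector and sample_transformed are hypothetical helpers introduced only for this illustration.

```python
import math
import random


def exp_bijector(x):
    """Forward map g(x) = exp(x); stands in for the bijector (hypothetical helper)."""
    return math.exp(x)


def sample_transformed(n, seed=0):
    """Draw n samples of Y = g(X), where X ~ Normal(0, 1) (hypothetical helper)."""
    rng = random.Random(seed)
    return [exp_bijector(rng.gauss(0.0, 1.0)) for _ in range(n)]


samples = sample_transformed(10000)

# exp maps the real line onto (0, inf), so every transformed sample is positive.
all_positive = all(y > 0.0 for y in samples)

# exp is increasing, so P(Y <= 1) = P(X <= 0) = 0.5; roughly half of the
# samples should therefore fall at or below 1.
fraction_below_one = sum(y <= 1.0 for y in samples) / len(samples)
```

Sampling from the transformed distribution amounts to sampling from the underlying distribution and pushing each draw through the bijector's forward map.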

Parameters
  • bijector (Bijector) – The transformation to perform.

  • distribution (Distribution) – The original distribution. Its dtype must be a float type.

  • seed (int) – The seed used in sampling. The global seed is used if it is None. Default: None. If a seed is given when a TransformedDistribution object is initialized, the object's sampling function uses that seed; otherwise, the underlying distribution's seed is used.

  • name (str) – The name of the transformed distribution. Default: 'transformed_distribution'.

Inputs and Outputs of APIs:

The accessible APIs of the transformed distribution are defined in the base class, including:

  • prob, log_prob, cdf, log_cdf, survival_function, and log_survival

  • mean

  • sample

For more details on these APIs, including their inputs and outputs, refer to mindspore.nn.probability.distribution.Distribution and the examples below.
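As a rough guide to what prob, log_prob, cdf and survival_function mean for a transformed distribution, the sketch below applies the standard change-of-variables rule in plain Python. It is an illustration under stated assumptions, not MindSpore code: the bijector is assumed to be the increasing map g(x) = exp(x), the underlying distribution is assumed to be Normal(0, 1), and all function names are hypothetical.

```python
import math


def normal_cdf(x, mean=0.0, sd=1.0):
    """CDF of the underlying Normal(mean, sd) distribution."""
    return 0.5 * (1.0 + math.erf((x - mean) / (sd * math.sqrt(2.0))))


def normal_log_prob(x, mean=0.0, sd=1.0):
    """Log-density of the underlying Normal(mean, sd) distribution."""
    z = (x - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def transformed_cdf(y):
    """For an increasing g, P(Y <= y) = P(X <= g^{-1}(y)); here g^{-1} = log."""
    return normal_cdf(math.log(y))


def transformed_survival(y):
    """Survival function: P(Y > y) = 1 - cdf(y)."""
    return 1.0 - transformed_cdf(y)


def transformed_log_prob(y):
    """log p_Y(y) = log p_X(g^{-1}(y)) + log |d g^{-1}(y) / dy|.

    With g^{-1}(y) = log(y), the log-Jacobian term is log(1 / y) = -log(y).
    """
    x = math.log(y)
    return normal_log_prob(x) - x


def transformed_prob(y):
    """Density of Y, obtained by exponentiating the log-density."""
    return math.exp(transformed_log_prob(y))
```

With these assumptions, transformed_cdf and transformed_prob are the CDF and density of a standard log-normal distribution, which is the case constructed in the Examples section.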

Supported Platforms:

Ascend GPU

Raises
  • TypeError – When the input bijector is not a Bijector instance.

  • TypeError – When the input distribution is not a Distribution instance.

Note

The arguments used to initialize the original distribution cannot be None. For example, mynormal = msd.Normal(dtype=mindspore.float32) cannot be used to initialize a TransformedDistribution, since mean and sd are not specified. batch_shape is the batch_shape of the original distribution. broadcast_shape is the broadcast shape between the original distribution and the bijector. is_scalar_batch is true only if both the original distribution and the bijector are scalar batches. default_parameters, parameter_names and parameter_type are set to be consistent with the original distribution. Derived classes can override default_parameters and parameter_names by calling reset_parameters followed by add_parameter.

Examples

>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn.probability.bijector as msb
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def __init__(self, shape, dtype=mindspore.float32, seed=0, name='transformed_distribution'):
...         super(Net, self).__init__()
...         # Create a TransformedDistribution: Exp applied to Normal(0, 1) gives a log-normal
...         self.exp = msb.Exp()
...         self.normal = msd.Normal(0.0, 1.0, dtype=dtype)
...         self.lognormal = msd.TransformedDistribution(self.exp, self.normal, seed=seed, name=name)
...         self.shape = shape
...
...     def construct(self, value):
...         cdf = self.lognormal.cdf(value)
...         sample = self.lognormal.sample(self.shape)
...         return cdf, sample
>>> shape = (2, 3)
>>> net = Net(shape=shape, name="LogNormal")
>>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
>>> tx = Tensor(x, dtype=mindspore.float32)
>>> cdf, sample = net(tx)
>>> print(sample.shape)
(2, 3)

property bijector

Return the bijector of the transformed distribution.

Output:

Bijector, the bijector of the transformed distribution.

property distribution

Return the underlying distribution of the transformed distribution.

Output:

Distribution, the underlying distribution of the transformed distribution.

property dtype

Return the dtype of the transformed distribution.

Output:

mindspore.dtype, the dtype of the transformed distribution.

property is_linear_transformation

Return whether the transformation is linear.

Output:

bool, True if the transformation is linear, and False otherwise.