Source code for mindspore.nn.probability.distribution.distribution

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
    raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


class Distribution(Cell):
    """
    Base class for all mathematical distributions.

    Args:
        seed (int): The seed is used in sampling. 0 is used if it is None.
        dtype (mindspore.dtype): The type of the event samples.
        name (str): The name of the distribution.
        param (dict): The parameters used to initialize the distribution.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        Derived class must override operations such as `_mean`, `_prob`,
        and `_log_prob`. Required arguments, such as `value` for `_prob`,
        must be passed in through `args` or `kwargs`. `dist_spec_args` which
        specifies a new distribution are optional.

        `dist_spec_args` is unique for each type of distribution. For example,
        `mean` and `sd` are the `dist_spec_args` for a Normal distribution,
        while `rate` is the `dist_spec_args` for an Exponential distribution.

        For all functions, passing in `dist_spec_args` is optional.
        Function calls with the additional `dist_spec_args` passed in will
        evaluate the result with a new distribution specified by the
        `dist_spec_args`. However, it will not change the original
        distribution.
    """

    def __init__(self,
                 seed,
                 dtype,
                 name,
                 param):
        """
        Constructor of distribution class.

        Validates `name` and `seed`, records the constructor parameters,
        selects the implementation behind every public evaluation function
        and creates the operators used by the base class.
        """
        super(Distribution, self).__init__()
        if seed is None:
            seed = 0
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_non_negative_int(seed, 'seed', name)

        self._name = name
        self._seed = seed
        self._dtype = cast_type_for_device(dtype)
        self._parameters = {}

        # parsing parameters: keep every entry except `self` and private names
        for key, value in param.items():
            if key != 'self' and not key.startswith('_'):
                self._parameters[key] = value

        # if not a transformed distribution, set the following attributes;
        # all of them are derived from `param_dict`
        # NOTE(review): presumably a transformed distribution sets the batch
        # attributes itself -- confirm against the subclass.
        if 'distribution' not in self.parameters:
            self.parameter_type = set_param_type(
                self.parameters['param_dict'], dtype)
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()
            self._broadcast_shape = self._batch_shape

        # set the function to call according to the derived class's attributes
        # (`_set_cdf` must run before `_set_survival`/`_set_log_cdf`, and
        # `_set_survival` before `_set_log_survival`, because the later ones
        # fall back on `_call_cdf`/`_call_survival`)
        self._set_prob()
        self._set_log_prob()
        self._set_sd()
        self._set_var()
        self._set_cdf()
        self._set_survival()
        self._set_log_cdf()
        self._set_log_survival()
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()
        self.checktensor = CheckTensor()
        self.broadcast = broadcast_to

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.exp_base = exp_generic
        self.fill_base = P.Fill()
        self.log_base = log_generic
        self.sametypeshape_base = P.SameTypeShape()
        self.sq_base = P.Square()
        self.sqrt_base = P.Sqrt()
        self.shape_base = P.Shape()

    @property
    def name(self):
        """The name of the distribution passed to the constructor."""
        return self._name

    @property
    def dtype(self):
        """The event dtype after `cast_type_for_device` has been applied."""
        return self._dtype

    @property
    def seed(self):
        """The seed used in sampling."""
        return self._seed

    @property
    def parameters(self):
        """The dict of constructor parameters recorded in `__init__`."""
        return self._parameters

    @property
    def is_scalar_batch(self):
        """Whether all non-None initial parameters are Python scalars."""
        return self._is_scalar_batch

    @property
    def batch_shape(self):
        """The broadcast shape of the initial parameters (may be None)."""
        return self._batch_shape

    @property
    def broadcast_shape(self):
        """The broadcast shape; initialized to the batch shape."""
        return self._broadcast_shape

    def _reset_parameters(self):
        """
        Clear the registered default parameters and their names.
        """
        self.default_parameters = []
        self.parameter_names = []

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and add it to `self.default_parameters`.
        Add `name` to `self.parameter_names`.

        Returns:
            The tensor built from `value`, or None if `value` is None.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        # cast value to a tensor if it is not None
        value_t = None if value is None else cast_to_tensor(
            value, self.parameter_type)
        self.default_parameters += [value_t,]
        self.parameter_names += [name,]
        return value_t

    def _check_param_type(self, *args):
        """
        Check the availability and validity of default parameters and
        `dist_spec_args`.
        `dist_spec_args` passed in must be tensors. If default parameters of
        the distribution are None, the parameters must be passed in through
        `args`.

        Returns:
            A single tensor when exactly one parameter was processed,
            otherwise a tuple of tensors, all broadcast to a common shape and
            cast to `self.parameter_type`.
        """
        broadcast_shape = None
        broadcast_shape_tensor = None
        common_dtype = None
        out = []

        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
            # check if the argument is a Tensor
            if arg is not None:
                self.checktensor(arg, name)
            else:
                arg = default if default is not None else raise_none_error(
                    name)

            # broadcast if the number of args > 1
            if broadcast_shape is None:
                broadcast_shape = self.shape_base(arg)
                common_dtype = self.dtype_base(arg)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
            else:
                # adding to the running all-ones tensor yields the shape
                # accumulated so far broadcast with this argument
                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
                arg = self.broadcast(arg, broadcast_shape_tensor)
                # check if the arguments have the same dtype
                self.sametypeshape_base(arg, broadcast_shape_tensor)

            arg = self.cast_base(arg, self.parameter_type)
            out.append(arg)

        if len(out) == 1:
            return out[0]

        # broadcast all args to broadcast_shape
        result = ()
        for arg in out:
            arg = self.broadcast(arg, broadcast_shape_tensor)
            result = result + (arg,)
        return result

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.

        When `self.context_mode` is 0 the check is run and `value` itself is
        returned; otherwise the result of the check is returned.
        """
        if self.context_mode == 0:
            self.checktensor(value, name)
            return value
        return self.checktensor(value, name)

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        Returns:
            bool, False as soon as a non-None parameter is neither an int nor
            a float; True otherwise.
        """
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _calc_batch_shape(self):
        """
        Calculate the broadcast shape of the parameters used during
        initialization.

        Returns:
            The shape obtained by adding all parameters together, or None if
            any parameter is None or if there are no parameters at all.
        """
        broadcast_shape_tensor = None
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        # an empty `param_dict` leaves nothing to take the shape of
        if broadcast_shape_tensor is None:
            return None
        return broadcast_shape_tensor.shape

    def _set_prob(self):
        """
        Set probability function based on the availability of `_prob` and
        `_log_prob`.
        """
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = self._calc_prob_from_log_prob
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _set_sd(self):
        """
        Set standard deviation based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_sd'):
            self._call_sd = self._sd
        elif hasattr(self, '_var'):
            self._call_sd = self._calc_sd_from_var
        else:
            self._call_sd = self._raise_not_implemented_error('sd')

    def _set_var(self):
        """
        Set variance based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_var'):
            self._call_var = self._var
        elif hasattr(self, '_sd'):
            self._call_var = self._calc_var_from_sd
        else:
            self._call_var = self._raise_not_implemented_error('var')

    def _set_log_prob(self):
        """
        Set log probability based on the availability of `_prob` and
        `_log_prob`.
        """
        if hasattr(self, '_log_prob'):
            self._call_log_prob = self._log_prob
        elif hasattr(self, '_prob'):
            self._call_log_prob = self._calc_log_prob_from_prob
        else:
            self._call_log_prob = self._raise_not_implemented_error('log_prob')

    def _set_cdf(self):
        """
        Set cumulative distribution function (cdf) based on the availability
        of `_cdf`, `_log_cdf`, `_survival_function` and `_log_survival`.
        """
        if hasattr(self, '_cdf'):
            self._call_cdf = self._cdf
        elif hasattr(self, '_log_cdf'):
            self._call_cdf = self._calc_cdf_from_log_cdf
        elif hasattr(self, '_survival_function'):
            self._call_cdf = self._calc_cdf_from_survival
        elif hasattr(self, '_log_survival'):
            self._call_cdf = self._calc_cdf_from_log_survival
        else:
            self._call_cdf = self._raise_not_implemented_error('cdf')

    def _set_survival(self):
        """
        Set survival function based on the availability of
        `_survival_function`, `_log_survival` and `_call_cdf`.
        """
        if not (hasattr(self, '_survival_function') or
                hasattr(self, '_log_survival') or
                hasattr(self, '_cdf') or
                hasattr(self, '_log_cdf')):
            self._call_survival = self._raise_not_implemented_error(
                'survival_function')
        elif hasattr(self, '_survival_function'):
            self._call_survival = self._survival_function
        elif hasattr(self, '_log_survival'):
            self._call_survival = self._calc_survival_from_log_survival
        elif hasattr(self, '_call_cdf'):
            self._call_survival = self._calc_survival_from_call_cdf

    def _set_log_cdf(self):
        """
        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
        """
        if not (hasattr(self, '_log_cdf') or
                hasattr(self, '_cdf') or
                hasattr(self, '_survival_function') or
                hasattr(self, '_log_survival')):
            self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
        elif hasattr(self, '_log_cdf'):
            self._call_log_cdf = self._log_cdf
        elif hasattr(self, '_call_cdf'):
            self._call_log_cdf = self._calc_log_cdf_from_call_cdf

    def _set_log_survival(self):
        """
        Set log survival based on the availability of `_log_survival` and
        `_call_survival`.
        """
        if not (hasattr(self, '_log_survival') or
                hasattr(self, '_survival_function') or
                hasattr(self, '_log_cdf') or
                hasattr(self, '_cdf')):
            # report the function that is actually missing
            self._call_log_survival = self._raise_not_implemented_error(
                'log_survival')
        elif hasattr(self, '_log_survival'):
            self._call_log_survival = self._log_survival
        elif hasattr(self, '_call_survival'):
            self._call_log_survival = self._calc_log_survival_from_call_survival

    def _set_cross_entropy(self):
        """
        Set cross entropy based on the availability of `_cross_entropy`.
        """
        if hasattr(self, '_cross_entropy'):
            self._call_cross_entropy = self._cross_entropy
        else:
            self._call_cross_entropy = self._raise_not_implemented_error(
                'cross_entropy')

    def _get_dist_args(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)

    def get_dist_args(self, *args, **kwargs):
        """
        Check the availability and validity of default parameters and
        `dist_spec_args`.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` must be passed in through list or dictionary. The
            order of `dist_spec_args` should follow the initialization order
            of default parameters through `_add_parameter`. If some
            `dist_spec_args` is None, the corresponding default parameter is
            returned.
        """
        return self._get_dist_args(*args, **kwargs)

    def _get_dist_type(self):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('get_dist_type', self.name)

    def get_dist_type(self):
        """
        Return the type of the distribution.
        """
        return self._get_dist_type()

    def _raise_not_implemented_error(self, func_name):
        """
        Build a stand-in for an evaluation function that is not available.

        Args:
            func_name (str): The name of the missing function.

        Returns:
            A function accepting any arguments that forwards them, together
            with `func_name` and the distribution name, to
            `raise_not_implemented_util`.
        """
        name = self.name

        def raise_error(*args, **kwargs):
            return raise_not_implemented_util(func_name, name, *args, **kwargs)
        return raise_error

    def log_prob(self, value, *args, **kwargs):
        """
        Evaluate the log probability(pdf or pmf) at the given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_log_prob(value, *args, **kwargs)

    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate prob from log probability.

        .. math::
            probability(x) = \exp(log_prob(x))
        """
        return self.exp_base(self._log_prob(value, *args, **kwargs))

    def prob(self, value, *args, **kwargs):
        """
        Evaluate the probability (pdf or pmf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_prob(value, *args, **kwargs)

    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
        r"""
        Evaluate log probability from probability.

        .. math::
            log_prob(x) = \log(prob(x))
        """
        return self.log_base(self._prob(value, *args, **kwargs))

    def cdf(self, value, *args, **kwargs):
        """
        Evaluate the cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_cdf(value, *args, **kwargs)

    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log_cdf.

        .. math::
            cdf(x) = \exp(log_cdf(x))
        """
        return self.exp_base(self._log_cdf(value, *args, **kwargs))

    def _calc_cdf_from_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from survival function.

        .. math::
            cdf(x) = 1 - (survival_function(x))
        """
        return 1.0 - self._survival_function(value, *args, **kwargs)

    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log survival function.

        .. math::
            cdf(x) = 1 - (\exp(log_survival(x)))
        """
        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_log_cdf(value, *args, **kwargs)

    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate log cdf from cdf.

        .. math::
            log_cdf(x) = \log(cdf(x))
        """
        return self.log_base(self._call_cdf(value, *args, **kwargs))

    def survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_survival(value, *args, **kwargs)

    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from cdf.

        .. math::
            survival_function(x) = 1 - (cdf(x))
        """
        return 1.0 - self._call_cdf(value, *args, **kwargs)

    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from log survival function.

        .. math::
            survival_function(x) = \exp(log_survival(x))
        """
        return self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its dist_spec_args through `args` or `kwargs`.
        """
        return self._call_log_survival(value, *args, **kwargs)

    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
        r"""
        Evaluate log survival function from survival function.

        .. math::
            log_survival(x) = \log(survival_function(x))
        """
        return self.log_base(self._call_survival(value, *args, **kwargs))

    def _kl_loss(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

    def kl_loss(self, dist, *args, **kwargs):
        """
        Evaluate the KL divergence, i.e. KL(a||b).

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function
            through `args` or `kwargs`. Passing in dist_spec_args of
            distribution a is optional.
        """
        return self._kl_loss(dist, *args, **kwargs)

    def _mean(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('mean', self.name, *args, **kwargs)

    def mean(self, *args, **kwargs):
        """
        Evaluate the mean.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._mean(*args, **kwargs)

    def _mode(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('mode', self.name, *args, **kwargs)

    def mode(self, *args, **kwargs):
        """
        Evaluate the mode.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._mode(*args, **kwargs)

    def sd(self, *args, **kwargs):
        """
        Evaluate the standard deviation.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._call_sd(*args, **kwargs)

    def var(self, *args, **kwargs):
        """
        Evaluate the variance.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._call_var(*args, **kwargs)

    def _calc_sd_from_var(self, *args, **kwargs):
        r"""
        Evaluate standard deviation from variance.

        .. math::
            STD(x) = \sqrt(VAR(x))
        """
        return self.sqrt_base(self._var(*args, **kwargs))

    def _calc_var_from_sd(self, *args, **kwargs):
        r"""
        Evaluate variance from standard deviation.

        .. math::
            VAR(x) = STD(x) ^ 2
        """
        return self.sq_base(self._sd(*args, **kwargs))

    def _entropy(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

    def entropy(self, *args, **kwargs):
        """
        Evaluate the entropy.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._entropy(*args, **kwargs)

    def cross_entropy(self, dist, *args, **kwargs):
        """
        Evaluate the cross_entropy between distribution a and b.

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function
            through `args` or `kwargs`. Passing in dist_spec_args of
            distribution a is optional.
        """
        return self._call_cross_entropy(dist, *args, **kwargs)

    def _calc_cross_entropy(self, dist, *args, **kwargs):
        r"""
        Evaluate cross_entropy from entropy and kl divergence.

        .. math::
            H(X, Y) = H(X) + KL(X||Y)
        """
        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

    def _sample(self, *args, **kwargs):
        """Default implementation: delegate to `raise_not_implemented_util`."""
        return raise_not_implemented_util('sample', self.name, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """
        Sampling function.

        Args:
            shape (tuple): shape of the sample.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing
            its *dist_spec_args* through *args* or *kwargs*.
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function',
            'log_survival', 'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss',
            'cross_entropy', 'sample', 'get_dist_args', and 'get_dist_type'.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dictionary): A dictionary of keyword arguments that the function needs.
        """
        # NOTE(review): kept as a plain `if` chain rather than a dispatch
        # table -- this method may be compiled in graph mode; confirm the
        # supported syntax before restructuring.
        if name == 'log_prob':
            return self._call_log_prob(*args, **kwargs)
        if name == 'prob':
            return self._call_prob(*args, **kwargs)
        if name == 'cdf':
            return self._call_cdf(*args, **kwargs)
        if name == 'log_cdf':
            return self._call_log_cdf(*args, **kwargs)
        if name == 'survival_function':
            return self._call_survival(*args, **kwargs)
        if name == 'log_survival':
            return self._call_log_survival(*args, **kwargs)
        if name == 'kl_loss':
            return self._kl_loss(*args, **kwargs)
        if name == 'mean':
            return self._mean(*args, **kwargs)
        if name == 'mode':
            return self._mode(*args, **kwargs)
        if name == 'sd':
            return self._call_sd(*args, **kwargs)
        if name == 'var':
            return self._call_var(*args, **kwargs)
        if name == 'entropy':
            return self._entropy(*args, **kwargs)
        if name == 'cross_entropy':
            return self._call_cross_entropy(*args, **kwargs)
        if name == 'sample':
            return self._sample(*args, **kwargs)
        if name == 'get_dist_args':
            return self._get_dist_args(*args, **kwargs)
        if name == 'get_dist_type':
            return self._get_dist_type()
        # any other name is reported as not implemented
        return raise_not_implemented_util(name, self.name, *args, **kwargs)