mindspore.ops.Multinomial

class mindspore.ops.Multinomial(seed=0, seed2=0, dtype=mstype.int32)

Returns a tensor of indices sampled from the multinomial probability distribution defined by each row of the input tensor x.

Note

  • The rows of x do not need to sum to one (in that case the values are used as weights), but they must be non-negative, finite, and have a non-zero sum.

  • Random seed: random numbers are produced by a deterministic pseudo-random algorithm, and the random seed sets its initial state. Two separate calls that use the same random seed therefore generate the same random numbers.

  • The Philox algorithm is used to scramble seed and seed2 into the final random seed, so users do not need to worry about which of the two seeds is more important.
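To make the Note concrete, here is a small pure-Python model of the documented semantics. It is a sketch, not MindSpore's implementation: the function name `multinomial_reference` is hypothetical, and it uses Python's `random` module instead of Philox, so it reproduces the distribution of the results but not the actual values MindSpore returns for a given seed. It checks the row constraints, treats each row as unnormalized weights, draws with replacement, and shows that reusing a seed reproduces the same draws.

```python
import math
import random


def multinomial_reference(x, num_samples, seed=0):
    """Hypothetical model of Multinomial's semantics (not MindSpore's RNG).

    x is a 2-D list of weights; each row is sampled independently.
    Returns a list of rows, each holding num_samples category indices.
    """
    rng = random.Random(seed)  # one generator per call: same seed -> same draws
    out = []
    for row in x:
        # Rows must be non-negative, finite, and have a non-zero sum.
        if any(w < 0 or not math.isfinite(w) for w in row):
            raise ValueError("weights must be non-negative and finite")
        if sum(row) == 0:
            raise ValueError("weights must have a non-zero sum")
        # random.choices normalizes the weights itself and samples with replacement.
        out.append(rng.choices(range(len(row)), weights=row, k=num_samples))
    return out


weights = [[0.0, 9.0, 4.0, 0.0]]
first = multinomial_reference(weights, 5, seed=10)
second = multinomial_reference(weights, 5, seed=10)
assert first == second                     # same seed, same samples
assert all(i in (1, 2) for i in first[0])  # zero-weight indices are never drawn
```

Because sampling is with replacement, the same index can appear several times in one row, as in the `[[1 1]]` output of the example further down.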

Parameters
  • seed (int, optional) – The operator-level random seed, used to generate random numbers, must be non-negative. Default: 0.

  • seed2 (int, optional) – The global random seed, which combines with the operator-level random seed to determine the final generated random number, must be non-negative. Default: 0.

  • dtype (mindspore.dtype, optional) – The type of output, must be mstype.int32 or mstype.int64. Default: mstype.int32.

Inputs:
  • x (Tensor) - the input tensor whose rows hold the non-negative weights (unnormalized probabilities) of each category; must be 1-D or 2-D.

  • num_samples (int) - the number of samples to draw; must be non-negative.
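Each row of x acts as a set of unnormalized weights. One common way to turn such a row into samples normalizes it, builds the cumulative distribution, and inverts it with a binary search (inverse transform sampling). The sketch below illustrates that technique only; it is not necessarily how MindSpore's kernels work, and the helper names `row_to_cdf` and `sample_index` are hypothetical.

```python
import bisect
import itertools


def row_to_cdf(row):
    """Hypothetical helper: normalize a weight row and return (probs, cdf)."""
    total = sum(row)
    probs = [w / total for w in row]
    cdf = list(itertools.accumulate(probs))  # running sum of probabilities
    return probs, cdf


def sample_index(cdf, u):
    """Map a uniform number u in [0, 1) to a category index (inverse transform).

    Scaling by cdf[-1] guards against a floating-point total slightly off 1.0.
    A zero-weight index has the same cdf value as its predecessor (or 0.0 at
    position 0), so bisect_right never returns it.
    """
    return bisect.bisect_right(cdf, u * cdf[-1])


probs, cdf = row_to_cdf([0.0, 9.0, 4.0, 0.0])
# probs is approximately [0.0, 0.692, 0.308, 0.0]
print(sample_index(cdf, 0.5))  # 1, since 0.5 falls in [0.0, 0.692)
print(sample_index(cdf, 0.9))  # 2, since 0.9 falls in [0.692, 1.0)
```

Drawing num_samples indices then amounts to calling `sample_index` with num_samples independent uniform numbers per row.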

Outputs:

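A tensor of type dtype with the same number of rows as x, where each row contains num_samples indices sampled from the corresponding row of x. For a 2-D x of shape (batch_size, num_classes), the output shape is (batch_size, num_samples).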

Raises
  • TypeError – If seed or seed2 is not an int.

  • TypeError – If num_samples is not an int.

  • TypeError – If dtype is not mstype.int32 or mstype.int64.

  • ValueError – If seed or seed2 is less than 0.
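The Raises list can be read as a set of argument checks. The sketch below restates them in plain Python. The function name `check_multinomial_args` and the string stand-ins for mstype.int32 and mstype.int64 are hypothetical; MindSpore performs these checks with its own internal validators, so its messages and check order will differ.

```python
# Hypothetical stand-ins for mstype.int32 / mstype.int64.
VALID_OUTPUT_DTYPES = ("int32", "int64")


def check_multinomial_args(seed=0, seed2=0, dtype="int32", num_samples=1):
    """Hypothetical validator mirroring the documented Raises list."""
    for name, value in (("seed", seed), ("seed2", seed2)):
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")
    if not isinstance(num_samples, int):
        raise TypeError(f"num_samples must be an int, got {type(num_samples).__name__}")
    if dtype not in VALID_OUTPUT_DTYPES:
        raise TypeError(f"dtype must be one of {VALID_OUTPUT_DTYPES}, got {dtype!r}")


check_multinomial_args(seed=10)  # valid arguments: no exception
```

The Raises list does not name an error for a negative num_samples, so the sketch leaves that case unchecked rather than guessing which exception MindSpore raises.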

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
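>>> # Only indices 1 and 2 have non-zero weight, so only they can be drawn.
>>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
>>> multinomial = ops.Multinomial(seed=10)
>>> output = multinomial(x, 2)  # draw 2 samples from the single row
>>> print(output)
[[1 1]]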