mindspore.ops.gumbel_softmax

mindspore.ops.gumbel_softmax(logits, tau=1, hard=False, dim=-1)[source]

Returns samples from the Gumbel-Softmax distribution and optionally discretizes them. If hard = True, the returned samples are one-hot; otherwise, they are probability distributions that sum to 1 across dim.
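
For intuition, the following is a minimal NumPy sketch of the computation this operator performs (illustrative only; the internal MindSpore implementation may differ in details such as noise generation and numerical stabilization):

import numpy as np

def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
    # Sample Gumbel(0, 1) noise: -log(-log(U)), with U ~ Uniform(0, 1).
    u = np.random.uniform(low=np.finfo(np.float32).tiny, high=1.0,
                          size=logits.shape)
    gumbels = -np.log(-np.log(u))
    # Temperature-scaled softmax over the perturbed logits.
    y = (logits + gumbels) / tau
    y = np.exp(y - y.max(axis=dim, keepdims=True))
    y_soft = y / y.sum(axis=dim, keepdims=True)
    if not hard:
        return y_soft
    # hard=True: discretize to one-hot along `dim`. In autograd this acts as
    # a straight-through estimator: the forward pass uses the one-hot sample,
    # while the backward pass uses the gradient of the soft sample.
    index = np.argmax(y_soft, axis=dim, keepdims=True)
    y_hard = np.zeros_like(y_soft)
    np.put_along_axis(y_hard, index, 1.0, axis=dim)
    return y_hard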

Parameters
  • logits (Tensor) – Unnormalized log probabilities. The data type must be float16 or float32.

  • tau (float) – The scalar temperature; must be a positive number. Default: 1.0.

  • hard (bool) – If True, the returned samples are discretized as one-hot vectors, but are differentiated as if they were the soft samples in autograd. Default: False.

  • dim (int) – The dimension along which softmax is computed. Default: -1.

Returns

Tensor, has the same dtype and shape as logits.

Raises
  • TypeError – If tau is not a float.

  • TypeError – If hard is not a bool.

  • TypeError – If the dtype of logits is not float16 or float32.

  • ValueError – If tau is not a positive number.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
>>> print(output.shape)
(2, 3)
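
As a further illustration, a soft sample (hard=False) returns probability distributions across dim. The printed sums below are representative rather than exact, since the samples themselves are random:

>>> soft_output = ops.gumbel_softmax(input_x, 1.0, False, -1)
>>> print(soft_output.sum(axis=-1))
[1. 1.]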