mindspore.ops.ctc_greedy_decoder

mindspore.ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True)

Performs greedy decoding on the logits given in inputs.

Note

On Ascend, merge_repeated cannot be set to False.

Parameters
  • inputs (Tensor) – A 3-D tensor of shape \((max\_time, batch\_size, num\_classes)\). num_classes must equal num_labels + 1, where num_labels is the number of actual labels; the extra class is reserved for the blank label. The default blank label is num_classes - 1. Data type must be float32 or float64.

  • sequence_length (Tensor) – A 1-D tensor of shape \((batch\_size, )\) containing the sequence lengths. Data type must be int32. Each value must be less than or equal to max_time.

  • merge_repeated (bool) – If True, merge repeated classes in the output. Default: True.

Returns

decoded_indices (Tensor), A tensor of shape \((total\_decoded\_outputs, 2)\). Each row holds the batch index and the position of one decoded class. Data type is int64.

decoded_values (Tensor), A tensor of shape \((total\_decoded\_outputs, )\) that stores the decoded classes. Data type is int64.

decoded_shape (Tensor), A 1-D tensor with two elements whose values are \((batch\_size, max\_decoded\_length)\), the dense shape of the decoded output. Data type is int64.

log_probability (Tensor), A tensor of shape \((batch\_size, 1)\) containing the log-probability of each sequence. It has the same data type as inputs.
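The relationship between the four outputs can be illustrated with a small pure-Python sketch. It is not the MindSpore implementation: the helper name `greedy_decode_reference` is hypothetical, inputs are plain nested lists rather than Tensors, and the log-probability convention (the negative sum of the per-step maximum) is an assumption inferred from the values printed in the Examples section.

```python
def greedy_decode_reference(inputs, sequence_length, merge_repeated=True):
    """Hypothetical reference sketch of greedy CTC decoding (not the MindSpore kernel).

    inputs: nested list of shape (max_time, batch_size, num_classes).
    sequence_length: list of batch_size ints, each <= max_time.
    """
    batch_size = len(inputs[0])
    num_classes = len(inputs[0][0])
    blank = num_classes - 1  # default blank label

    indices, values, log_probability = [], [], []
    max_decoded_length = 0
    for b in range(batch_size):
        prev = None
        decoded = []
        score = 0.0
        for t in range(sequence_length[b]):
            row = inputs[t][b]
            # Pick the highest-scoring class; ties resolve to the first index.
            best = max(range(num_classes), key=row.__getitem__)
            score -= row[best]  # assumed convention: negative sum of the maxima
            # Drop blanks, and drop repeats when merge_repeated is set.
            if best != blank and not (merge_repeated and best == prev):
                decoded.append(best)
            prev = best
        for pos, label in enumerate(decoded):
            indices.append([b, pos])
            values.append(label)
        max_decoded_length = max(max_decoded_length, len(decoded))
        log_probability.append([score])
    return indices, values, [batch_size, max_decoded_length], log_probability


# Same data as the Examples section: max_time=2, batch_size=2, num_classes=3.
example_inputs = [[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
                  [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]
ref_indices, ref_values, ref_shape, ref_log_prob = greedy_decode_reference(
    example_inputs, [2, 2])
print(ref_indices)  # [[0, 0], [0, 1], [1, 0]]
print(ref_values)   # [0, 1, 0]
print(ref_shape)    # [2, 2]
```

In this sketch the second batch element selects class 0 at both time steps, so with merge_repeated=True it contributes a single decoded value.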

Raises
  • TypeError – If merge_repeated is not a bool.

  • ValueError – If inputs is not a 3-D tensor.

  • ValueError – If sequence_length is not a 1-D tensor.

  • ValueError – If any value in sequence_length is larger than max_time.
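The conditions above can also be checked before calling the operator, for example to fail early in a data pipeline. The following is a minimal sketch under stated assumptions: the helper `check_ctc_decoder_args` is hypothetical and not part of MindSpore, it works on plain shape tuples and lists rather than Tensors, and its error messages are illustrative only.

```python
def check_ctc_decoder_args(inputs_shape, sequence_length_shape, sequence_length,
                           merge_repeated=True):
    """Hypothetical pre-check mirroring the documented error conditions."""
    if not isinstance(merge_repeated, bool):
        raise TypeError("merge_repeated must be a bool")
    if len(inputs_shape) != 3:
        raise ValueError("inputs must be 3-D: (max_time, batch_size, num_classes)")
    if len(sequence_length_shape) != 1:
        raise ValueError("sequence_length must be 1-D: (batch_size,)")
    max_time = inputs_shape[0]
    if any(length > max_time for length in sequence_length):
        raise ValueError("every value in sequence_length must be <= max_time")


# Shapes taken from the Examples section: inputs (2, 2, 3), sequence_length (2,).
check_ctc_decoder_args((2, 2, 3), (2,), [2, 2])
```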

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
...                           [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ops.ctc_greedy_decoder(inputs,
...                                                                                          sequence_length)
>>> print(decoded_indices)
[[0 0]
 [0 1]
 [1 0]]
>>> print(decoded_values)
[0 1 0]
>>> print(decoded_shape)
[2 2]
>>> print(log_probability)
[[-1.2]
 [-1.3]]
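The sparse outputs can be regrouped into one label list per batch element. This is a minimal sketch, assuming the three outputs have first been converted to plain Python lists (for instance with `Tensor.asnumpy().tolist()`) and that the rows of decoded_indices are ordered by batch index and then by position, as in the output above. The helper name `sparse_to_sequences` is hypothetical.

```python
def sparse_to_sequences(decoded_indices, decoded_values, decoded_shape):
    """Group decoded values by batch index (hypothetical helper)."""
    batch_size = decoded_shape[0]
    sequences = [[] for _ in range(batch_size)]
    for (batch_index, _position), label in zip(decoded_indices, decoded_values):
        sequences[batch_index].append(label)
    return sequences


# Values copied from the printed outputs above.
sequences = sparse_to_sequences([[0, 0], [0, 1], [1, 0]], [0, 1, 0], [2, 2])
print(sequences)  # [[0, 1], [0]]
```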