mindspore.ops.moe_token_unpermute
- mindspore.ops.moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None)
Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Warning
It is only supported on Atlas A2 Training Series Products.
sorted_indices must not have duplicate values, otherwise the result is undefined.
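Because duplicate values in sorted_indices leave the result undefined, it may be worth checking the indices on the host before calling the operator. The following is a minimal NumPy sketch of such a check. The helper name has_duplicate_indices is hypothetical and is not part of MindSpore.

```python
import numpy as np


def has_duplicate_indices(sorted_indices):
    """Return True if any index value occurs more than once.

    Hypothetical host-side helper, not a MindSpore API.
    """
    idx = np.asarray(sorted_indices).ravel()
    # np.unique drops repeated values, so a smaller size means duplicates exist.
    return np.unique(idx).size != idx.size


print(has_duplicate_indices([0, 6, 7, 5, 3, 1, 2, 4]))  # False
print(has_duplicate_indices([0, 1, 1, 3]))              # True
```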
- Parameters
permuted_tokens (Tensor) – The tensor of permuted tokens to be unpermuted. The shape is \([num\_tokens * topk, hidden\_size]\), where num_tokens, topk and hidden_size are positive integers.
sorted_indices (Tensor) – The tensor of sorted indices used to unpermute the tokens. The shape is \([num\_tokens * topk,]\), where num_tokens and topk are positive integers. It only supports the int32 data type.
probs (Tensor, optional) – The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. The shape is \([num\_tokens, topk]\), where num_tokens and topk are positive integers. Default: None.
padded_mode (bool, optional) – If True, the indices are padded to denote the selected tokens per expert. Default: False.
restore_shape (Union[tuple[int], list[int]], optional) – The input shape before permutation, only used in padding mode. Default: None.
- Returns
Tensor, with the same dtype as permuted_tokens. If padded_mode is False, the shape will be [num_tokens, hidden_size]. If padded_mode is True, the shape will be specified by restore_shape.
- Raises
TypeError – If permuted_tokens is not a Tensor.
ValueError – If padded_mode is True. Currently only padded_mode=False is supported.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> permuted_token = Tensor([
...     [1, 1, 1],
...     [0, 0, 0],
...     [0, 0, 0],
...     [3, 3, 3],
...     [2, 2, 2],
...     [1, 1, 1],
...     [2, 2, 2],
...     [3, 3, 3]], dtype=mindspore.bfloat16)
>>> sorted_indices = Tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=mindspore.int32)
>>> out = ops.moe_token_unpermute(permuted_token, sorted_indices)
>>> out.shape
(8, 3)
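For readers without Ascend hardware, the shape behaviour described above can be approximated with a small NumPy sketch. This is not the operator's implementation. It assumes that row i of permuted_tokens is written to row sorted_indices[i] of a flat [num_tokens * topk, hidden_size] buffer, and that merging with probs means a probability-weighted sum over the topk copies of each token. Both assumptions follow the common MoE unpermute convention and are not confirmed by this page. The function name unpermute_reference and the probs values are illustrative only.

```python
import numpy as np


def unpermute_reference(permuted_tokens, sorted_indices, probs=None):
    """NumPy sketch of unpermute with padded_mode=False (assumed semantics)."""
    tokens = np.asarray(permuted_tokens, dtype=np.float32)
    idx = np.asarray(sorted_indices, dtype=np.int64)
    # Without probs, each token is treated as having a single copy (topk = 1).
    topk = 1 if probs is None else np.asarray(probs).shape[1]

    # Assumption: scatter row i of the permuted tensor to row idx[i].
    flat = np.zeros_like(tokens)
    flat[idx] = tokens

    # Group the topk copies of each token: [num_tokens, topk, hidden_size].
    grouped = flat.reshape(-1, topk, tokens.shape[-1])
    if probs is not None:
        # Assumption: weight each copy by its probability before summing.
        grouped = grouped * np.asarray(probs, dtype=np.float32)[..., None]
    return grouped.sum(axis=1)


permuted = [[1, 1, 1], [0, 0, 0], [0, 0, 0], [3, 3, 3],
            [2, 2, 2], [1, 1, 1], [2, 2, 2], [3, 3, 3]]
indices = [0, 6, 7, 5, 3, 1, 2, 4]

# Without probs, topk is 1, so the output keeps all 8 rows.
print(unpermute_reference(permuted, indices).shape)  # (8, 3)

# With probs of shape [4, 2], pairs of rows are merged into 4 tokens.
probs = np.full((4, 2), 0.5, dtype=np.float32)
print(unpermute_reference(permuted, indices, probs).shape)  # (4, 3)
```

Under these assumptions the second call merges the eight permuted rows into four tokens, matching the documented [num_tokens, hidden_size] output shape when probs has shape [num_tokens, topk].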