# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
from __future__ import absolute_import
from __future__ import division
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.nn.transformer.op_parallel_config import default_moeparallel_config
__all__ = [
    "MoEConfig"]
[docs]class MoEConfig:
    r"""
        The configuration of MoE (Mixture of Expert).
        Args:
            expert_num (int): The number of experts employed. Default: 1
            capacity_factor (float): The factor is used to indicate how much to expand expert capacity,
                which is >=1.0. Default: 1.1.
            aux_loss_factor (float): The factor is used to indicate how much the load balance loss (produced by the
                router) to be added to the entire model loss, which is < 1.0. Default: 0.05.
            num_experts_chosen (int): The number of experts is chosen by each token and it should not be larger
                than expert_num. Default: 1.
            expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
                effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
        Supported Platforms:
            ``Ascend``
        Examples:
            >>> from mindspore.nn.transformer import MoEConfig
            >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
            ...                        expert_group_size=64)
    """
    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
                 num_experts_chosen=1, expert_group_size=None):
        Validator.check_positive_int(expert_num, "expert_num")
        Validator.check_positive_float(capacity_factor, "capacity_factor")
        Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
        Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
        if expert_group_size is not None:
            Validator.check_positive_int(expert_group_size, "expert_group_size")
        if capacity_factor < 1.0:
            raise ValueError(f"'capacity_factor' must be equal to or greater than 1.0, "
                             f"but got {capacity_factor}.")
        if aux_loss_factor >= 1.0:
            raise ValueError(f"'aux_loss_factor' must be less than 1.0, "
                             f"but got {aux_loss_factor}.")
        if num_experts_chosen > expert_num:
            raise ValueError(f"'num_experts_chosen' must not be larger than 'expert_num', "
                             f"but got {num_experts_chosen}.")
        self.expert_num = expert_num
        self.capacity_factor = capacity_factor
        self.aux_loss_factor = aux_loss_factor
        self.num_experts_chosen = num_experts_chosen
        self.expert_group_size = expert_group_size 
default_moe_config = MoEConfig()
def _check_moe_config(moe_config=None, parallel_config=None):
    """
        check if MoE with right configuration.
    """
    if not isinstance(moe_config, MoEConfig):
        raise TypeError(f"'moe_config' must be an instance of MoEConfig, but got {type(moe_config).__name__}.")
    use_moe = (moe_config.expert_num > 1)
    if use_moe is False:
        return
    if moe_config.expert_num % parallel_config.expert_parallel != 0:
        raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
                         f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
                         f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for "
                         f"'expert_parallel'.")
    device_num = D.get_group_size()
    if device_num % parallel_config.expert_parallel != 0:
        raise ValueError(f"device_num: {device_num} must be a multiple of expert_parallel: "
                         f"{parallel_config.expert_parallel}.")
    if parallel_config.data_parallel % parallel_config.expert_parallel != 0:
        raise ValueError(f"data parallel: {parallel_config.data_parallel} must be a multiple of "
                         f"expert_parallel: {parallel_config.expert_parallel} when using MoE.")
    if parallel_config.data_parallel * parallel_config.model_parallel > device_num:
        raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and "
                         f"model parallel: {parallel_config.model_parallel} "
                         f"should be less than device_num: {device_num}.")
@constexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
    return math.ceil(k * tokens_per_group * capacity_factor / expert_dim)
class MoE(Cell):
    """
    The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
    The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
    obtained by multiplying FeedForward's output and router's combine weight.
    Args:
        hidden_size (int): The dimension of the inputs.
        ffn_hidden_size (int): The intermediate hidden size.
        dropout_rate (float): The dropout rate for the second linear's output.
        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
            default values. Please see `MoEConfig`.
        parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
            Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.
    Inputs:
        - **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
    Outputs:
        Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
    """
    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 param_init_type=mstype.float32,
                 moe_config=default_moe_config,
                 parallel_config=default_moeparallel_config):
        super(MoE, self).__init__()
        self.dp = parallel_config.data_parallel
        self.ep = parallel_config.expert_parallel
        self.hidden_size = hidden_size
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.aux_loss_factor = moe_config.aux_loss_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.dp_group = parallel_config.data_parallel
        from mindspore.nn.transformer import FeedForward
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
        self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
        self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
        self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                             training=True, parallel_config=parallel_config)
        self.cast = P.Cast()
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
            self.expert_group_size = moe_config.expert_group_size
            self.ffn = FeedForward(hidden_size=hidden_size,
                                   ffn_hidden_size=ffn_hidden_size,
                                   dropout_rate=dropout_rate,
                                   hidden_act=hidden_act,
                                   expert_group_size=self.expert_group_size,
                                   expert_num=self.expert_dim,
                                   param_init_type=param_init_type,
                                   parallel_config=parallel_config)
            self.mul = P.Mul()
        else:
            self.ffn = FeedForward(hidden_size=hidden_size,
                                   ffn_hidden_size=ffn_hidden_size,
                                   dropout_rate=dropout_rate,
                                   hidden_act=hidden_act,
                                   expert_num=self.expert_dim,
                                   param_init_type=param_init_type,
                                   parallel_config=parallel_config)
            self.mul = P.Mul().shard(((), ()))
    def construct(self, input_tensor):
        input_shape = F.shape(input_tensor)
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
        tokens_per_group = bs_and_dmodel[0] // self.dp_group
        input_tensor = self.reshape(input_tensor, (self.dp_group, tokens_per_group, self.hidden_size))
        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group,
                                                    self.capacity_factor, self.expert_dim)
        # dispatch_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
        # combine_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
        dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)
        # after transpose, input_tensor's shape: (self.dp_group, self.hidden_size, tokens_per_group)
        input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
        dispatch_tensor = self.reshape(dispatch_tensor, (self.dp_group, tokens_per_group,
                                                         self.expert_dim * expert_capacity))
        dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
        # expert_input's shape: (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity)
        expert_input = self.batch_mm(input_tensor, dispatch_tensor)
        expert_input = self.reshape(expert_input, (self.dp_group, self.hidden_size, self.expert_dim,
                                                   expert_capacity))
        # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose
        # has bad performance
        expert_input = self.reshape(expert_input, (self.dp_group * self.hidden_size,
                                                   self.expert_dim * expert_capacity))
        expert_input = self.transpose_2dim(expert_input, (1, 0))
        expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                                   self.hidden_size))
        # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
        expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3))
        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity,
                                                   self.hidden_size))
        # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
        expert_output = self.ffn(expert_input)
        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
                                                     expert_capacity, self.hidden_size))
        # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
        # has bad performance
        expert_output = self.reshape(expert_output, (self.expert_dim,
                                                     self.dp_group * expert_capacity * self.hidden_size))
        expert_output = self.transpose_2dim_ep(expert_output, (1, 0))
        expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity,
                                                     self.hidden_size * self.expert_dim))
        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
        # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim,
                                                     expert_capacity))
        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                                     self.expert_dim * expert_capacity))
        combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
                                                       self.expert_dim * expert_capacity))
        # combine_tensor's shape: (self.dp_group, self.expert_dim*expert_capacity, tokens_per_group)
        combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
        combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))
        # combined_output's shape: (self.dp_group, self.hidden_size, tokens_per_group)
        combined_output = self.batch_mm2(expert_output, combine_tensor)
        # combined_output's shape: (self.dp_group, tokens_per_group, self.hidden_size)
        combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
        combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
        combined_output = self.reshape(combined_output, input_shape)
        aux_loss = self.mul(self.aux_loss_factor, aux_loss)
        return combined_output, aux_loss
class Router(Cell):
    r"""
        A router backbone used to calculate logits of each token, which should be cascaded by router implementations
        mapping tokens to experts.
        when moe_config.num_experts_chosen = 1, use top1 routing;
        when moe_config.num_experts_chosen > 1, use topk routing
        Args:
            d_model (int): The hidden size of each token.
            moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
            routing_policy: The policy of mapping tokens to experts. Default: topkRouter
            training (bool): The value indicating whether is in training phase.
            parallel_config: The parallel-related configuration.
        Inputs:
            - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
            hidden\_size)`.
        Outputs:
            Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
    """
    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.training = training
        self.routing_policy = routing_policy
        self.noisy_policy = None  # candidate: ["jitter", "rsample", "None"]
        self.noisy_epsilon = 1e-2
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))
        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        self.router = routing_policy
        if self.routing_policy is None:
            self.router = TopkRouter(d_model=d_model, moe_config=moe_config, training=training,
                                     parallel_config=parallel_config)
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
            self.dense.matmul.shard(((dp, 1), (1, 1)))
            self.mul = P.Mul()
            self.cast = P.Cast()
        else:
            self.dense.matmul.shard(((dp, 1), (1, 1)))
            self.mul = P.Mul().shard(((dp, 1, 1), (dp,)))
            self.cast = P.Cast()
    def construct(self, input_tensor):
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of UniforReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)
        router_logits = self.dense(input_tensor)
        return self.router(router_logits)
class TopkRouter(Cell):
    r"""
        A router implementation which maps each tokens to the topk expert.
        Args:
            d_model (int): The hidden size of each token.
            moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
            training (bool): The value indicating whether is in training phase.
            config: The parallel-related configuration.
        Inputs:
            - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
            hidden\_size)`.
        Outputs:
            Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
            Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
            Tensor of shape :math:`(1)`.
    """
    def __init__(self,
                 d_model,
                 moe_config,
                 training=True,
                 parallel_config=None):
        super(TopkRouter, self).__init__()
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.dp_group = dp
        self.noisy_policy = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.num_experts_chosen = moe_config.num_experts_chosen
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
            self.softmax = P.Softmax(axis=-1)
            self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False)
            self.onehot = P.OneHot()
            self.onehot2 = P.OneHot()
            self.onehot3 = P.OneHot()
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.reduce_mean2 = P.ReduceMean(keep_dims=False)
            self.reduce_mean3 = P.ReduceMean(keep_dims=False)
            self.mul = P.Mul()
            self.mul2 = P.Mul()
            self.mul3 = P.Mul()
            self.mul4 = P.Mul()
            self.mul5 = P.Mul()
            self.mul6 = P.Mul()
            self.mul7 = P.Mul()
            self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
            self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
            self.not_equal = P.NotEqual()
            self.div1 = P.RealDiv()
            self.div2 = P.RealDiv()
            self.add = P.Add()
            self.add1 = P.Add()
            self.add2 = P.Add()
            self.add3 = P.Add()
            self.add4 = P.Add()
            self.sub = P.Sub()
            self.cumsum = P.CumSum(exclusive=True)
            self.less = P.Less()
            self.reduce_sum = P.ReduceSum(keep_dims=False)
            self.reduce_sum_keep = P.ReduceSum(keep_dims=True)
            self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True)
            self.expand = P.ExpandDims()
            self.expand2 = P.ExpandDims()
            self.add_scala = P.Add()
            self.init_loss = Tensor(0.0, mstype.float32)
        else:
            self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
            self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))
            self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
            self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
            self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
            self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
            self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
            self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
            self.mul2 = P.Mul().shard(((), ()))
            self.mul3 = P.Mul().shard(((), ()))
            self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
            self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
            self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
            self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
            self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
            self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
            self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))
            self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
            self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
            self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))
            self.add1 = P.Add().shard(((dp, 1, 1), ()))
            self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
            self.add3 = P.Add().shard(((dp, 1), (dp, 1)))
            self.add4 = P.Add().shard(((dp, 1, 1, 1), ()))
            self.sub = P.Sub().shard(((), (dp, 1, 1)))
            self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),))
            self.less = P.Less().shard(((dp, 1, 1), ()))
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
            self.reduce_sum_keep = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1),))
            self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1, 1),))
            self.expand = P.ExpandDims().shard(((dp, 1),))
            self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))
            self.add_scala = P.Add().shard(((), ()))
            self.init_loss = Tensor(0.0, mstype.float32)
    def construct(self, router_logits):
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
        tokens_per_group = logits_shape[0] // self.dp_group
        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor,
                                                    self.expert_dim)
        router_logits = self.reshape(router_logits, (self.dp_group, tokens_per_group, self.expert_dim))
        accum_expert_mask = 0
        accum_expert_gate = 0
        loss = self.init_loss
        mask_count = 0
        accum_combine_tensor = 0
        # Probabilities for each token of what expert is should be sent to
        router_prob = self.softmax(router_logits)
        for expert_chosen_index in range(self.num_experts_chosen):
            # for each token, set the router_prob of the selected experts to zero
            router_prob = self.mul4(router_prob, self.sub(self.on_value, accum_expert_mask))
            # shape is : (dp_group, tokens_per_group)
            expert_index, expert_gate = self.argmax(router_prob)
            # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
            expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
            # renormalize the rest prob to be of sum 1
            router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))
            # the balance loss is computed at each routing step
            loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal))
            output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
                                                     mask_count, expert_chosen_index)
            expert_mask, expert_gate, expert_mask_flat, position_in_expert = output[0], output[1], output[2], output[3]
            accum_expert_mask = self.add(accum_expert_mask, expert_mask)
            accum_expert_gate = self.add3(accum_expert_gate, expert_gate)
            mask_count = self.add(mask_count, self.reduce_sum_keep(expert_mask, 1))
            # combine_tensor's shape: (dp_group, tokens_per_group)
            combine_tensor = self.mul7(expert_gate, expert_mask_flat)
            # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim)
            combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                       self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
            # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim, self.expert_capacity)
            combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                       self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                    self.on_value, self.off_value))
            accum_combine_tensor = self.add2(accum_combine_tensor, combine_tensor)
        # expert weights normalization
        combine_tensor_sum = self.reduce_sum_keep2(self.reduce_sum_keep2(accum_combine_tensor, -1), -2)
        accum_combine_tensor = self.div2(accum_combine_tensor, self.add4(combine_tensor_sum, 1e-9))
        # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has
        # bad performance
        dispatch_tensor = self.not_equal(accum_combine_tensor, 0.0)
        return dispatch_tensor, accum_combine_tensor, loss
    def _auxiliary_loss(self, expert_mask, router_prob):
        """
        Computing the load balance loss.
        """
        # density_1's shape: (dp_group, self.expert_dim)
        density_1 = self.reduce_mean(expert_mask, 1)
        # density_1_proxy's shape: (dp_group, self.expert_dim)
        density_1_proxy = self.reduce_mean2(router_prob, 1)
        loss = self.mul(density_1, density_1_proxy)
        loss = self.reduce_mean3(loss)
        loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
        return loss
    def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate, last_num, expert_chosen_index):
        """
        Keeping only the tokens that fit within expert_capacity.
        """
        cumsum = self.cumsum(expert_mask, 1)
        if expert_chosen_index > 0:
            cumsum = self.add(cumsum, last_num)
        # position_in_expert's shape: (dp_group, tokens_per_group, self.expert_dim)
        position_in_expert = self.mul4(cumsum, expert_mask)
        less_result = self.less(position_in_expert, expert_capacity)
        # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
        expert_mask = self.mul5(less_result, expert_mask)
        # expert_mask_flat's shape: (dp_group, tokens_per_group)
        expert_mask_flat = self.reduce_sum(expert_mask, -1)
        # Mask out the experts that have overflowed the expert_capacity.
        # expert_gate's shape: (dp_group, tokens_per_group)
        expert_gate = self.mul6(expert_gate, expert_mask_flat)
        output = (expert_mask, expert_gate, expert_mask_flat, position_in_expert)
        return output