# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Loss for the Parallel Training
This is an experimental interface that is subject to change or deletion.
"""
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from .layers import _check_input_dtype, _check_input_shape
from .op_parallel_config import default_dpmp_config, OpParallelConfig
__all__ = ["CrossEntropyLoss"]
[docs]class CrossEntropyLoss(Cell):
    """
    Calculate the cross entropy loss.
    Args:
        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
                                           an instance of `OpParallelConfig` with default args.
    Inputs:
        - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
          the backbone.
        - **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
        - **input_mask** (Tensor) - Tensor of shape (N, ). input_mask indicates whether there are padded inputs and for
          padded inputs it will not be counted into loss.
    Outputs:
        Tensor. The corresponding cross entropy loss
    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.nn.transformer import CrossEntropyLoss
        >>> from mindspore import Tensor
        >>> loss = CrossEntropyLoss()
        >>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mstype.float32)
        >>> labels_np = np.array([1]).astype(np.int32)
        >>> input_mask = Tensor(np.ones(1).astype(np.float32))
        >>> labels = Tensor(labels_np)
        >>> output = loss(logits, labels, input_mask)
        >>> print(output.shape)
        (1,)
    """
    def __init__(self, parallel_config=default_dpmp_config):
        super(CrossEntropyLoss, self).__init__()
        if not isinstance(parallel_config, OpParallelConfig):
            raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig, "
                            "but got the type: {}.".format(type(parallel_config)))
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        self.sum = P.ReduceSum().shard(((dp, mp),))
        self.onehot = P.OneHot().shard(((dp, mp), (), ()))
        # on/off value for onehot, for smooth labeling, modify the off_value
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
            ((dp, mp),))
        self.eps_const = Tensor(1e-24, mstype.float32)
        self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
        self.exp = P.Exp().shard(((dp, mp),))
        self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
        self.log = P.Log().shard(((dp, mp),))
        self.add = P.Add().shard(((dp, mp), ()))
        self.mul = P.Mul().shard(
            ((dp, mp), (dp, mp)))
        self.neg = P.Neg().shard(((dp, mp),))
        self.sum2 = P.ReduceSum().shard(((1,),))
        self.mul2 = P.Mul().shard(((1,), (1,)))
        self.add2 = P.Add()
        self.div2 = P.RealDiv()
    def construct(self, logits, label, input_mask):
        self._check_input(logits, label, input_mask)
        # the shape is [bs*seq_length, vocab_size]
        logits = F.cast(logits, mstype.float32)
        # LogSoftmax for logits over last dimension
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)
        exp_sum = self.sum(logit_exp, -1)
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
        # Flatten label to [bs*seq_length]
        label = P.Reshape()(label, (-1,))
        # Get onehot label [bs*seq_length, vocab_size]
        one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value,
                                    self.off_value)
        # Cross-Entropy loss
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
        # input_mask indicates whether there is padded inputs and for padded inputs it will not be counted into loss
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))
        denominator = self.add2(
            self.sum2(input_mask),
            P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
        loss = self.div2(numerator, denominator)
        return loss
    def _check_input(self, logits, label, input_mask):
        r"""Check the input tensor shape and type"""
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('label', label, self.cls_name)
        _check_is_tensor('input_mask', input_mask, self.cls_name)
        _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
        _check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
        _check_input_shape(F.shape(label), "label", self.cls_name, 1)
        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
        return True