# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ms function for mixed precision."""
from __future__ import absolute_import
from abc import ABC, abstractmethod
from ._checkparam import Validator as validator
from .common import dtype as mstype
from . import context
from . import ops
from .ops import constexpr
from .common.api import ms_class
from .common.parameter import Parameter
from .common.tensor import Tensor
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
from .train.amp import build_train_network, auto_mixed_precision
# Shared primitive instances reused by the helpers and loss scalers below.
# HyperMap applies a function to a Tensor or to every Tensor of a (nested) sequence.
_hypermap = ops.HyperMap()
# Partial pre-binds leading arguments (e.g. the scale value) to a function.
_partial = ops.Partial()


@constexpr
def _ascend_target():
    """Return True if the configured ``device_target`` is ``"Ascend"``."""
    device = context.get_context("device_target")
    return device == "Ascend"


@constexpr
def _gpu_target():
    """Return True if the configured ``device_target`` is ``"GPU"``."""
    device = context.get_context("device_target")
    return device == "GPU"


def _grad_unscale(scale, grad):
    """Divide ``grad`` by ``scale``, multiplying by the reciprocal cast to ``grad``'s dtype."""
    inv_scale = ops.Reciprocal()(scale)
    return grad * inv_scale.astype(grad.dtype)


def _grad_scale(scale, grad):
    return grad * scale.astype(grad.dtype)
def _is_finite(inputs):
    """Return a bool Tensor that is True when every element of ``inputs`` is finite."""
    if _gpu_target():
        # On GPU, FloatStatus reports 0 in its first element for normal (finite) inputs.
        overflow_flag = ops.FloatStatus()(inputs)[0]
        return overflow_flag == 0
    return ops.isfinite(inputs).all()


def init_status():
    r"""
    Returns a Tensor indicating initialized status for overflow detection.

    Note:
        Only Ascend needs the status to capture overflow status. You can also call
        this function on GPU or CPU, but the return value is useless there.

    Returns:
        Tensor, has the shape of `(8,)`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> status = amp.init_status()
    """
    if _ascend_target():
        # Allocate the NPU float-status buffer and clear it; `depend` makes the
        # returned status ordered after the clear has been executed.
        status = ops.NPUAllocFloatStatus()()
        clear_status = ops.NPUClearFloatStatus()(status)
        status = ops.depend(status, clear_status)
    else:
        # Placeholder: an all-zero float32 Tensor of shape (8,), unused outside Ascend.
        status = Tensor([0, 0, 0, 0, 0, 0, 0, 0], mstype.float32)
    return status


def all_finite(inputs, status=None):
    r"""
    Returns a scalar Tensor indicating whether the inputs are finite.

    Note:
        This is an experimental interface that is subject to change or deletion.
        The interface must be used in whole network training scenario to detect
        whether grads are finite, and the results may be different on different
        device targets.

    Args:
        inputs (Union(tuple(Tensor), list(Tensor))): an iterable of Tensors to check.
        status (Tensor): the status Tensor for overflow detection, only required on
            Ascend. Default: None.

    Returns:
        Tensor, a scalar Tensor and the dtype is bool.

    Raises:
        ValueError: If the device target is Ascend and `status` is None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])))
        >>> output = amp.all_finite(x)
    """
    if _ascend_target():
        if status is None:
            raise ValueError("The status must be initialized on Ascend, but get 'None'.")
        # Make the status read happen only after `inputs` have been computed.
        status = ops.depend(status, inputs)
        get_status = ops.NPUGetFloatStatus()(status)
        status = ops.depend(status, get_status)
        # Any non-zero entry in the status buffer signals an overflow.
        status_finite = status.sum() == 0
        # Reset the status buffer so the next detection starts clean.
        _ = ops.NPUClearFloatStatus()(status)
        return status_finite
    # Other targets: check each Tensor individually, then reduce the results.
    outputs = _hypermap(_partial(_is_finite), inputs)
    return ops.stack(outputs).all()


@ms_class
class LossScaler(ABC):
    r"""
    Loss scaler abstract class when using mixed precision.

    Derived classes need to implement all of its methods. During training, `scale` and `unscale` are used
    to scale and unscale the loss value and gradients to avoid overflow, and `adjust` is used to update the
    loss scale value.

    Note:
        This is an experimental interface that is subject to change or deletion.
    """

    @abstractmethod
    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` dependent on whether grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
        """
        raise NotImplementedError


class StaticLossScaler(LossScaler):
    r"""
    Static Loss scale class.

    Scales and unscales loss or gradients by a fixed constant.

    Note:
        This is an experimental interface that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value. Must be >= 1.

    Raises:
        ValueError: If `scale_value` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> loss_scaler = amp.StaticLossScaler(scale_value=2**10)
        >>> loss_value = Tensor([1.], mindspore.float32)
        >>> scaled_loss_value = loss_scaler.scale(loss_value)
        >>> print(scaled_loss_value)
        [1024.]
        >>> grads = (Tensor(np.array([1.5, 1.0]), mindspore.float16),
        ...      Tensor(np.array([1.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> print(unscaled_grads)
        (Tensor(shape=[2], dtype=Float16, value= [ 1.4648e-03,  9.7656e-04]),
        Tensor(shape=[1], dtype=Float16, value= [ 1.1721e-03]))
    """

    def __init__(self, scale_value):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        # Values below 1 are rejected; exactly 1 is accepted, so the message says ">= 1".
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be >= 1, but got {}".format(scale_value))
        # Kept as a float32 Parameter so it can be used inside graph computations.
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.
        """
        return _hypermap(_partial(_grad_scale, self.scale_value), inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.
        """
        return _hypermap(_partial(_grad_unscale, self.scale_value), inputs)

    def adjust(self, grads_finite):
        """
        `scale_value` is fixed, so this does not modify it.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
                Ignored by this class.

        Returns:
            bool, always False, since `scale_value` is never updated.
        """
        return False


class DynamicLossScaler(LossScaler):
    r"""
    Dynamic Loss scale class.

    Dynamic loss scaling tries to determine the largest loss scale value that
    will keep gradients finite. After `scale_window` consecutive steps with finite
    grads, it multiplies the loss scale by `scale_factor`, provided the result is
    still finite. When grads are not finite, it divides the loss scale by
    `scale_factor`, never going below 1, and resets the step counter.

    Note:
        This is an experimental interface that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value. Must be >= 1.
        scale_factor (int): The factor by which the loss scale is multiplied or divided.
        scale_window (int): Maximum continuous training steps that do not have
            overflow to increase the loss scale.

    Raises:
        ValueError: If `scale_value` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=1)
        >>> grads = (Tensor(np.array([np.log(-1), 1.0]), mindspore.float16),
        ...             Tensor(np.array([0.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> grads_finite = amp.all_finite(unscaled_grads)
        >>> loss_scaler.adjust(grads_finite)
        True
        >>> print(loss_scaler.scale_value.asnumpy())
        512.0
    """

    def __init__(self, scale_value, scale_factor, scale_window):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        # Values below 1 are rejected; exactly 1 is accepted, so the message says ">= 1".
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be >= 1, but got {}".format(scale_value))
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        self.scale_window = validator.check_positive_int(scale_window, "scale_window")
        self.scale_factor = validator.check_positive_int(scale_factor, "scale_factor")
        # Consecutive finite steps since the last growth or overflow; stays in [0, scale_window).
        self.counter = Parameter(Tensor(0, dtype=mstype.int32), name="counter")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.
        """
        return _hypermap(_partial(_grad_scale, self.scale_value), inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.
        """
        return _hypermap(_partial(_grad_unscale, self.scale_value), inputs)

    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` dependent on whether grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.

        Returns:
            bool, always True, since `scale_value` and the step counter are updated in place.
        """
        one = ops.ones((), self.scale_value.dtype)
        scale_mul_factor = self.scale_value * self.scale_factor
        # Grow only if the grown value itself is still finite.
        grown_scale = ops.select(_is_finite(scale_mul_factor), scale_mul_factor, self.scale_value)
        # Grads finite: grow at the last step of the window, otherwise keep the current value.
        finite_scale = ops.select(self.counter == (self.scale_window - 1), grown_scale, self.scale_value)
        # Grads not finite: shrink by `scale_factor`, but never below 1.
        overflow_scale = ops.maximum(one, self.scale_value / self.scale_factor)
        scale_value = ops.select(grads_finite, finite_scale, overflow_scale)
        ops.assign(self.scale_value, scale_value)
        # Advance the counter modulo the window; multiplying by `grads_finite` resets it on overflow.
        counter = ((self.counter + 1) % self.scale_window) * grads_finite
        ops.assign(self.counter, counter)
        return True


# Public API of this module, in its original order.
__all__ = [
    "DynamicLossScaleManager",
    "LossScaleManager",
    "FixedLossScaleManager",
    "build_train_network",
    "DynamicLossScaler",
    "StaticLossScaler",
    "LossScaler",
    "auto_mixed_precision",
    "init_status",
    "all_finite",
]