# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad accumulation"""
from mindspore.nn.cell import Cell
from mindspore.common import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
# Public names of this module: the accumulation cell and the two multitype
# function graphs it applies to each gradient buffer.
__all__ = ["GradientAccumulation", "gradient_accumulation_op", "gradient_clear_op"]
gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")


@gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def cumulative_grad_process(accumulation_step, cumulative_grad, grad):
    """
    Add one step's share of a gradient to its accumulation buffer.

    Args:
        accumulation_step: Divisor applied to `grad` before it is added.
        cumulative_grad: Buffer that receives the scaled gradient.
        grad: Gradient produced by the current step.

    Returns:
        The result of the AssignAdd operation on `cumulative_grad`.
    """
    scaled_grad = grad / accumulation_step
    assign_add = P.AssignAdd()
    return assign_add(cumulative_grad, scaled_grad)
gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")


@gradient_clear_op.register("Tensor")
def clear_grad(cumulative_grad):
    """
    Overwrite an accumulation buffer with zeros.

    Args:
        cumulative_grad: Buffer to be cleared.

    Returns:
        The result of assigning a zeros-like tensor to `cumulative_grad`.
    """
    return F.assign(cumulative_grad, P.ZerosLike()(cumulative_grad))
class GradientAccumulation(Cell):
    """
    After accumulating the gradients of multiple steps, call the optimizer to apply the update.

    Each call adds the incoming gradients, divided by `max_accumulation_step`, to a set of
    accumulation buffers cloned from the optimizer's parameters. Once the internal step
    counter reaches `max_accumulation_step`, the optimizer is invoked with the accumulated
    gradients, the counter is reset to zero and the buffers are cleared.

    Args:
       max_accumulation_step (int): Steps to accumulate gradients.
       optimizer (Cell): Optimizer used.
    """

    def __init__(self, max_accumulation_step, optimizer):
        super(GradientAccumulation, self).__init__()
        self._max_accumulation_step = max_accumulation_step
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.hyper_map = C.HyperMap()
        # One zero-initialized buffer per optimizer parameter, used to hold the running sum.
        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
        # Number of gradient sets accumulated since the last optimizer call.
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        """
        Accumulate `grads` and run the optimizer once enough steps have been collected.

        Args:
            loss: Loss value of the current step; returned after the accumulation,
                optimizer and clearing operations have been attached to it with `F.depend`.
            grads: Gradients of the current step, mapped element-wise against the
                accumulation buffers.

        Returns:
            `loss`, made dependent on the operations performed in this call.
        """
        # Add grads / max_accumulation_step to every accumulation buffer.
        loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

        # Enough steps collected: apply the accumulated gradients and restart the count.
        if self._accumulation_step >= self._max_accumulation_step:
            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
            self._accumulation_step = 0

        # The counter is zero right after a reset, so the buffers are zeroed
        # before the next accumulation cycle begins.
        if self._accumulation_step == 0:
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))

        return loss