mindformers.wrapper.wrapper 源代码

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Self-Define Wrapper."""
import hashlib
import os
import shutil
from copy import deepcopy

from mindformers.core.clip_grad import ClipGradNorm
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.tools.utils import get_real_rank
from mindformers.version_control import get_identity, is_dump_supported

from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._utils import _get_enable_parallel_optimizer
import mindspore.common.dtype as mstype
from mindspore import nn, Parameter, ParallelMode, get_auto_parallel_context
from mindspore.common import RowTensor
from mindspore.common.tensor import Tensor
from mindspore.communication.management import (create_group, get_rank, get_group_size)
from mindspore.nn.wrap.cell_wrapper import _MicroBatch
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.auto_generate import SumExt
if is_dump_supported():
    from mindspore.ops._grad_experimental.grad_comm_ops import get_squared_device_local_norm_param





__all__ = ['MFTrainOneStepCell', 'MFPipelineWithLossScaleCell', 'PipelineCellWithTwoOutput',
           'GradAccumulationCellWithTwoOutput']

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return F.cast(grad, mstype.float32) * reciprocal(scale)


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
get_size = C.MultitypeFuncGraph("get_size")


@get_square_sum.register("Tensor", "Number")
def _get_square_sum(grad, value):
    norm = SumExt()(F.square(F.cast(grad, mstype.float32)), ()) * value
    norm = F.expand_dims(norm, 0)
    return norm


# pylint: disable=E0102
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = SumExt()(F.square(F.cast(grad, mstype.float32)), ())
    norm = F.expand_dims(norm, 0)
    return norm


@get_size.register("Tensor")
def _get_size(grad):
    size = P.Size()(grad)
    return size


class LocalNorm(nn.Cell):
    def __init__(self):
        super(LocalNorm, self).__init__()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        size = self.hyper_map(get_size, grads)
        return square_sum, size


# pylint: disable=W1401
[文档]@MindFormerRegister.register(MindFormerModuleType.WRAPPER) class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell): r"""TrainOneStep For MindFormer. Network training with loss scaling, grad clip, gradient accumulation, exponential moving average and so on. This is a training step with loss scaling. It takes a network, an optimizer and a scale update Cell(or a Tensor) as args. The loss scale value can be updated in both host side or device side. If you want to update it on host side, using a value of Tensor type as `scale_sense`, otherwise, using a Cell instance for updating loss scale as `scale_sense`. Args: network (Cell): The training network. The network only supports single output. optimizer (Cell): Optimizer for updating the network parameters. use_clip_grad (bool, optional): Whether to use the gradient clipping function. Default: ``False``. max_grad_norm (float, optional): Maximum gradient value. Default: ``1.0``. scale_sense (Union[int, float, Tensor, Cell], optional): The scaling number to be filled as the input of backpropagation. If this value is a Cell, it will be called by `MFTrainOneStepCell` to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`, the shape should be :math:`()` or :math:`(1,)`. Default: ``1.0``. local_norm (bool, optional): Whether to calculate the local norm. Default: ``False``. calculate_per_token_loss (bool, optional): Whether to calculate the loss of each token. Default: `False`. global_norm_spike_threshold (float, optional): The threshold of global norm when the skip data function is enabled. Default: ``1.0``. use_skip_data_by_global_norm (bool, optional): The switch of the skip data function. Default: ``False``. **kwargs (Any): Additional parameters. Inputs: - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. Outputs: Tuple of 5 or 7 Tensor, the loss, overflow flag, current loss scale value, learning rate, global grads norm, local grads norm and size of local norm grads. - **loss** (Tensor) - A scalar, the loss value. - **overflow** (Tensor) - A scalar, whether overflow occur or not, the type is bool. - **loss scale** (Tensor) - The loss scale value, the shape is :math:`()` or :math:`(1,)`. - **learning rate** (Tensor) - A scalar, the learning rate of the optimizer. - **global norm** (Tensor) - A scalar, the global norm of all grads, only be calculated when `use_clip_grad=True`, otherwise None. - **local_norm** (Tensor) - The local norm of the grads by group, only be returned when `local_norm=True`. - **size** (Tensor) - The sizes of each grads group, only be returned when `local_norm=True`. Raises: TypeError: If `scale_sense` is neither Cell nor Tensor. ValueError: If shape of `scale_sense` is neither (1,) nor (). Examples: >>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM >>> from mindformers.wrapper import MFTrainOneStepCell >>> import mindspore as ms >>> from mindformers.core.optim import AdamW >>> import numpy as np >>> >>> ms.set_context(mode=ms.GRAPH_MODE) >>> >>> config = LlamaConfig(num_layers=2) >>> net = LlamaForCausalLM(config=config) >>> net.set_train(True) >>> optimizer = AdamW(net.trainable_params()) >>> >>> mft = MFTrainOneStepCell(net, optimizer) >>> inputs = ms.Tensor(np.ones([1, 2049]), ms.int32) >>> out = mft(inputs) >>> >>> loss, overflow, loss_scale, lr, global_norm = out >>> print(loss.shape, overflow, loss_scale, lr, global_norm) (1,) False 1.0 0.001, None """ @args_type_check(global_norm_spike_threshold=float, use_skip_data_by_global_norm=bool) def __init__(self, network, optimizer, use_clip_grad=False, max_grad_norm=1.0, scale_sense=1.0, local_norm=False, calculate_per_token_loss=False, global_norm_spike_threshold=1.0, use_skip_data_by_global_norm=False, **kwargs): if isinstance(scale_sense, (int, float)): scale_sense = Tensor(scale_sense) super(MFTrainOneStepCell, self).__init__(network, optimizer, scale_sense) self.use_clip_grad = use_clip_grad self.clip_grad_norm = ClipGradNorm(max_norm=max_grad_norm) self.parallel_config = kwargs.pop("parallel_config", None) self.learning_rate = deepcopy(self.optimizer.learning_rate) self.localnorm = LocalNorm() self.concat = P.Concat() self.cast = P.Cast() self.local_norm = local_norm self.calculate_per_token_loss = calculate_per_token_loss self.zero_t = Tensor([0], dtype=mstype.float32) self.grad_scale_factor = Tensor([1], dtype=mstype.float32) if is_dump_supported(): self.dump_device_local_norm = get_auto_parallel_context("dump_device_local_norm") self.if_dump = bool(get_auto_parallel_context("dump_local_norm_path")) else: self.dump_device_local_norm = False self.if_dump = False if self.if_dump: self.dump = P.TensorDump() self.dump_path = os.path.join(get_auto_parallel_context("dump_local_norm_path"), f"rank_{get_real_rank()}") if os.path.exists(self.dump_path): shutil.rmtree(self.dump_path) os.makedirs(self.dump_path, exist_ok=True) self.device_local_norm_filename = os.path.join(self.dump_path, "device_local_norm") self.finish_step_filename = os.path.join(self.dump_path, "finish_step") if self.dump_device_local_norm: self.squared_device_local_norm = get_squared_device_local_norm_param() self.global_norm_spike_threshold = global_norm_spike_threshold self.use_skip_data_by_global_norm = use_skip_data_by_global_norm def construct(self, *inputs): """forward and backward.""" weights = self.weights scaling_sens = self.scale_sense if self.calculate_per_token_loss: numerator, denominator = self.network(*inputs) loss = numerator / denominator if self.dump_device_local_norm: _ = F.assign(self.squared_device_local_norm, Tensor(0.0, mstype.float32)) scaling_sens = F.depend(self.scale_sense, _) scaling_sens_filled = C.ones_like(numerator) * F.cast(scaling_sens, F.dtype(numerator)) scaling_sens_filled2 = self.zero_t * F.cast(scaling_sens, F.dtype(denominator)) grads = self.grad(self.network, weights)(*inputs, (self.cast(scaling_sens_filled, mstype.float32), self.cast(scaling_sens_filled2, mstype.float32))) grad_scale_factor = denominator else: loss = self.network(*inputs) if self.dump_device_local_norm: _ = F.assign(self.squared_device_local_norm, Tensor(0.0, mstype.float32)) scaling_sens = F.depend(self.scale_sense, _) scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) grad_scale_factor = self.grad_scale_factor status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.hyper_map(F.partial(_grad_scale, scaling_sens * grad_scale_factor), grads) if self.local_norm: local_norm, size = self.localnorm(grads) local_norm = self.concat(local_norm) # apply grad reducer on grads grads = self.grad_reducer(grads) # get the overflow buffer cond = self.get_overflow_status(status, grads) overflow = self.process_loss_scale(cond) global_norm = None if self.use_clip_grad: grads, global_norm = self.clip_grad_norm(grads) learning_rate = self.learning_rate if self.optimizer.dynamic_lr: if self.optimizer.is_group_lr: learning_rate = self.learning_rate[-1](self.optimizer.global_step).reshape(()) else: learning_rate = self.learning_rate(self.optimizer.global_step).reshape(()) if self.use_graceful_exit: grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param) # if there is no overflow, do optimize if not overflow: if self.use_skip_data_by_global_norm: if global_norm < self.global_norm_spike_threshold: loss = F.depend(loss, self.optimizer(grads)) else: loss = F.depend(loss, self.optimizer(grads)) if self.if_dump: zero = Tensor(0.0, mstype.float32) if self.dump_device_local_norm: zero = F.depend(zero, self.dump(self.device_local_norm_filename, F.sqrt(self.squared_device_local_norm))) filename = self.finish_step_filename self.dump(filename, zero) if self.local_norm: return loss, overflow, scaling_sens, learning_rate, global_norm, local_norm, size return loss, overflow, scaling_sens, learning_rate, global_norm
class DataOrderWrapperCell(nn.Cell): """For passing parameters in lexicographical order.""" def __init__(self, construct_args_key, network): super(DataOrderWrapperCell, self).__init__(auto_prefix=False) self.construct_args_key = construct_args_key self.network = network def construct(self, *inputs): """The construct processes of inputs in lexicographical order.""" key_inputs = {key: val for key, val in zip(self.construct_args_key, inputs)} return self.network(**key_inputs) grad_scale = C.MultitypeFuncGraph("grad_scale") shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale") @grad_scale.register("Tensor", "Tensor", "Tensor", "Tensor") def tensor_grad_scale_pipeline(scale, grad_scale_factor, grad, accu_grad): accu_grad = F.depend(accu_grad, grad) new_grad = accu_grad * reciprocal(scale) accu_grad = F.depend(accu_grad, new_grad) grad_val = F.cast(F.equal(accu_grad, accu_grad), F.dtype(accu_grad)) zeros = F.mul(grad_val, 0) new_grad = F.depend(new_grad, F.assign(accu_grad, zeros)) new_grad = new_grad * F.cast(reciprocal(grad_scale_factor), F.dtype(new_grad)) return new_grad @shard_grad_scale.register("Tensor", "Tensor", "Tensor", "Tensor") def tensor_shard_grad_scale_pipeline(scale, grad_scale_factor, grad, accu_grad): new_grad = grad * F.cast(reciprocal(scale), F.dtype(grad)) accu_grad = F.depend(accu_grad, new_grad) grad_val = F.cast(F.equal(accu_grad, accu_grad), F.dtype(accu_grad)) zeros = F.mul(grad_val, 0) new_grad = F.depend(new_grad, F.assign(accu_grad, zeros)) new_grad = new_grad * F.cast(reciprocal(grad_scale_factor), F.dtype(new_grad)) return new_grad class GradAccumulationCellWithTwoOutput(nn.Cell): """ Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode. Note: micro_size must be greater or equal to pipeline stages. Args: network (Cell): The target network to wrap. micro_size (int): MicroBatch size. Supported Platforms: ``Ascend`` ``GPU`` """ def __init__(self, network, micro_size): super().__init__(auto_prefix=False) self.network = network self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] if not isinstance(network, nn.Cell): raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'network' must cell type, " "but got the type : {}.".format(type(network))) if not isinstance(micro_size, int): raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be integer, " "but got the type : {}.".format(type(micro_size))) if micro_size <= 0: raise ValueError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be large than 0, " "but got {}.".format(micro_size)) for i in range(micro_size): micro_input = _MicroBatch(micro_size) micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size) self.micro_inputs.append(micro_input) self.add = P.Add().add_prim_attr("forward_end", i) self.add_list.append(self.add) self._get_attr_from_cell(network) def construct(self, *inputs): """Construct function for pipeline with multiple outputs.""" ret = None ret2 = None for i in range(self.micro_size): micro_input = self.micro_inputs[i](i, *inputs) output1, output2 = self.network(*micro_input) if ret is not None: ret = self.add_list[i](ret, output1) else: ret = output1 if ret2 is not None: ret2 = self.add_list[i](ret2, output2) else: ret2 = output2 loss = ret, ret2 return loss def _get_pipeline_group(): """ Calculate the communication group between all pipeline stages """ rank = get_rank() stage_nums = auto_parallel_context().get_pipeline_stages() device_nums = get_group_size() per_stage_device_nums = device_nums // stage_nums local_stage_rank_id = rank % per_stage_device_nums group = range(0, stage_nums) rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group] rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group] rank_list_str = "-".join(rank_str_list) return rank_list, rank_list_str class PipelineCellWithTwoOutput(nn.Cell): """ Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training. Note: micro_size must be greater or equal to pipeline stages. Args: network (Cell): The target network to wrap. micro_size (int): MicroBatch size. Supported Platforms: ``Ascend`` ``GPU`` """ def __init__(self, network, micro_size): super().__init__(auto_prefix=False) self.network = network self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] if not isinstance(network, nn.Cell): raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'network' must cell type, " "but got the type : {}.".format(type(network))) if not isinstance(micro_size, int): raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be integer, " "but got the type : {}.".format(type(micro_size))) if micro_size <= 0: raise ValueError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be large than 0, " "but got {}.".format(micro_size)) for i in range(micro_size): micro_input = _MicroBatch(micro_size) self.micro_inputs.append(micro_input) self.add = P.Add().add_prim_attr("pipeline_end", i) self.add_list.append(self.add) self._get_attr_from_cell(network) def construct(self, *inputs): """ Construct function for pipeline with multiple input Args: *inputs: Returns: """ ret = None ret2 = None for i in range(self.micro_size): micro_input = self.micro_inputs[i](i, *inputs) output1, output2 = self.network(*micro_input) if ret is not None: ret = self.add_list[i](ret, output1) else: ret = output1 if ret2 is not None: ret2 = self.add_list[i](ret2, output2) else: ret2 = output2 loss = ret, ret2 return loss # pylint: disable=W1401
[文档]@MindFormerRegister.register(MindFormerModuleType.WRAPPER) class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): r""" Append a train-one-step cell with loss scale of pipeline parallel for MindFormers. Args: network (Cell): The training network. Note that loss function should have been added. optimizer (Optimizer): Optimizer for updating the weights. use_clip_grad (bool, optional): Whether to use gradient clipping. Default: ``True``. max_grad_norm (float, optional): Maximum gradient constraint value. Default: ``1.0``. scale_sense (Union[Tensor, Cell], optional): Cell to do the loss scale. Default: ``1.0``. micro_batch_num (int, optional): Micro batch number of pipeline parallel. Default: ``1``. local_norm (bool, optional): Whether to calculate the local norm. Default: ``False``. calculate_per_token_loss (bool, optional): Whether to calculate the loss of each token. Default: ``False``. global_norm_spike_threshold (float, optional): The threshold of global norm when the skip data function is enabled. Default: ``1.0``. use_skip_data_by_global_norm (bool, optional): The switch of the skip data function. Default: ``False``. **kwargs (Any): Additional parameters. Inputs: - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. Outputs: Tuple of 5 or 7 Tensor, the loss, overflow flag, current loss scale value, learning rate, global grads norm, local grads norm and size of local norm grads. - **loss** (Tensor) - A scalar, the loss value. - **overflow** (Tensor) - A scalar, whether overflow occur or not, the type is bool. - **loss scale** (Tensor) - The loss scale value, the shape is :math:`()` or :math:`(1,)`. - **learning rate** (Tensor) - A scalar, the learning rate of the optimizer. - **global norm** (Tensor) - A scalar, the global norm of all grads, only be calculated when `use_clip_grad=True`, otherwise None. - **local_norm** (Tensor) - The local norm of the grads by group, only be returned when `local_norm=True`. - **size** (Tensor) - The sizes of each grads group, only be returned when `local_norm=True`. Raises: TypeError: If `scale_sense` is neither Cell nor Tensor. ValueError: If shape of `scale_sense` is neither (1,) nor (). ValueError: If the parallel mode is not one of [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]. """ @args_type_check(global_norm_spike_threshold=float, use_skip_data_by_global_norm=bool) def __init__(self, network, optimizer, use_clip_grad=True, max_grad_norm=1.0, scale_sense=1.0, micro_batch_num=1, local_norm=False, calculate_per_token_loss=False, global_norm_spike_threshold=1.0, use_skip_data_by_global_norm=False, **kwargs): if isinstance(scale_sense, (int, float)): scale_sense = Tensor(scale_sense) super(MFPipelineWithLossScaleCell, self).__init__(network, optimizer, scale_sense) self.network = network self.network.add_flags(defer_inline=True) self.weights = optimizer.parameters self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") self.optimizer = optimizer self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad_reducer = get_identity() self.degree = 1 self.cast = P.Cast() self.status = Tensor([0] * 8, mstype.int32) self.reduce_sum = SumExt() if self.parallel_mode not in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]: raise ValueError(f"ParallelMode must be one of " f"[ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL], but found " f"{self.parallel_mode}.") self.allreduce = P.AllReduce() self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() self.reshape = P.Reshape() self.loss_scaling_manager = None if isinstance(scale_sense, nn.Cell): self.loss_scaling_manager = scale_sense self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), name="scale_sense") elif isinstance(scale_sense, Tensor): if scale_sense.shape == (1,) or scale_sense.shape == (): self.scale_sense = Parameter(scale_sense, name='scale_sense') else: raise ValueError("The shape of 'scale_sense' must be (1,) or (), but got {}" .format(scale_sense.shape)) else: raise TypeError("The 'scale_sense' must be Cell or Tensor, but got {}".format(type(scale_sense))) self.opt_shard = _get_enable_parallel_optimizer() self.use_clip_grad = use_clip_grad self.clip_grad_norm = ClipGradNorm(max_norm=max_grad_norm) self.micro_size = micro_batch_num self.parallel_config = kwargs.pop("parallel_config", None) self.learning_rate = deepcopy(self.optimizer.learning_rate) self.localnorm = LocalNorm() self.concat = P.Concat() self.local_norm = local_norm # create allreduce for synchronize denominator pipeline_group_list, pipeline_group_name = _get_pipeline_group() hashed = hashlib.sha256(pipeline_group_name.encode()).hexdigest()[:48] pipeline_group_name = str(hashed) create_group(pipeline_group_name, pipeline_group_list) self.allreduce2 = P.AllReduce(group=pipeline_group_name) self.calculate_per_token_loss = calculate_per_token_loss self.grad_scale_factor = Tensor([1], dtype=mstype.float32) self.zero_t = Tensor([0], dtype=mstype.float32) if is_dump_supported(): self.dump_device_local_norm = get_auto_parallel_context("dump_device_local_norm") self.if_dump = bool(get_auto_parallel_context("dump_local_norm_path")) else: self.dump_device_local_norm = False self.if_dump = False if self.if_dump: self.dump = P.TensorDump() self.dump_path = os.path.join(get_auto_parallel_context("dump_local_norm_path"), f"rank_{get_real_rank()}") if os.path.exists(self.dump_path): shutil.rmtree(self.dump_path) self.device_local_norm_filename = os.path.join(self.dump_path, "device_local_norm") self.finish_step_filename = os.path.join(self.dump_path, "finish_step") if self.dump_device_local_norm: self.squared_device_local_norm = get_squared_device_local_norm_param() self.global_norm_spike_threshold = global_norm_spike_threshold self.use_skip_data_by_global_norm = use_skip_data_by_global_norm @C.add_flags(has_effect=True) def construct(self, *inputs): """The construct processes of pipeline wrapper cell.""" scaling_sens = self.scale_sense if self.dump_device_local_norm: _ = F.assign(self.squared_device_local_norm, Tensor(0.0, mstype.float32)) scaling_sens = F.depend(self.scale_sense, _) if self.calculate_per_token_loss: numerator, denominator = self.network(*inputs) denominator = self.allreduce2(denominator) loss = numerator / denominator scaling_sens_filled = C.ones_like(numerator) * F.cast(scaling_sens, F.dtype(numerator)) scaling_sens_filled2 = self.zero_t * F.cast(scaling_sens, F.dtype(denominator)) grads = self.grad(self.network, self.weights)(*inputs, (self.cast(scaling_sens_filled, mstype.float32), self.cast(scaling_sens_filled2, mstype.float32))) grad_scale_factor = denominator else: loss = self.network(*inputs) scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) grads = self.grad(self.network, self.weights)(*inputs, self.cast(scaling_sens_filled / self.micro_size, mstype.float32)) grad_scale_factor = self.grad_scale_factor if self.local_norm: local_norm, size = self.localnorm(grads) local_norm = self.concat(local_norm) if self.opt_shard: grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree, grad_scale_factor), grads, self.accu_grads) else: accu_grads = self.grad_reducer(self.accu_grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree, grad_scale_factor), grads, accu_grads) global_norm = None if self.use_clip_grad: grads, global_norm = self.clip_grad_norm(grads) learning_rate = self.learning_rate if self.optimizer.dynamic_lr: if self.optimizer.is_group_lr: learning_rate = self.learning_rate[-1](self.optimizer.global_step).reshape(()) else: learning_rate = self.learning_rate(self.optimizer.global_step).reshape(()) # sum overflow flag over devices cond = self.get_overflow_status(self.status, grads) cond = F.depend(cond, grads) overflow = self.process_loss_scale(cond) if self.use_graceful_exit: grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param) if not overflow: if self.use_skip_data_by_global_norm: if global_norm < self.global_norm_spike_threshold: loss = F.depend(loss, self.optimizer(grads)) else: loss = F.depend(loss, self.optimizer(grads)) if self.if_dump: zero = Tensor(0.0, mstype.float32) if self.dump_device_local_norm: zero = F.depend(zero, self.dump(self.device_local_norm_filename, F.sqrt(self.squared_device_local_norm))) filename = self.finish_step_filename self.dump(filename, zero) if self.local_norm: return loss, overflow, scaling_sens.value(), learning_rate, global_norm, local_norm, size return loss, overflow, scaling_sens.value(), learning_rate, global_norm