# Source code for mindspore.nn.optim.thor

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context


# Layer-type tags recorded by the layer-type helpers below (find_net_layertype_recur,
# get_layer_counter). `Other` marks trainable layers that THOR does not precondition.
Other, Conv, FC, Embedding, LayerNorm, BatchNorm = -1, 1, 2, 3, 4, 5

# Shared AddN op (used for the weight-decay sum) and the MultitypeFuncGraph dispatchers
# whose overloads are registered right after this point.
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Add the L2 term ``weight * weight_decay`` to ``gradient`` when ``if_apply`` is set.

    When ``if_apply`` is false the gradient is returned unchanged.
    """
    if not if_apply:
        return gradient
    decay_term = weight * weight_decay
    return op_add((decay_term, gradient))


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Run the momentum op ``opt`` on one weight and return True once it has executed.

    ``opt`` is called as ``opt(weight, moment, learning_rate, gradient, momentum)``; F.depend ties the
    returned constant to that call so the update is not pruned from the graph.
    """
    update = opt(weight, moment, learning_rate, gradient, momentum)
    return F.depend(True, update)

# Gradient-clipping configuration read by `clip_gradient`: global-norm clipping is off, and the
# per-tensor clip uses type 1 ('norm', see `_clip_grad`) with a bound of 1.0.
IS_ENABLE_GLOBAL_NORM, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE = False, 1, 1.0
hyper_map_op = C.HyperMap()
clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. Any other value returns `grad` as is.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): One gradient tensor (this overload is registered for "Tensor"; tuples of gradients
            are handled by mapping it over each element, e.g. with HyperMap).

    Outputs:
        Tensor, the clipped gradient. The clip bounds are cast to the dtype of `grad`.
    """
    if clip_type not in [0, 1]:
        return grad
    grad_dtype = F.dtype(grad)
    upper = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
    if clip_type == 1:
        # Norm clipping: rescale `grad` so its norm does not exceed `clip_value`.
        return nn.ClipByNorm()(grad, upper)
    # Value clipping: clamp every element into [-clip_value, clip_value].
    lower = F.cast(F.tuple_to_array((-clip_value,)), grad_dtype)
    return C.clip_by_value(grad, lower, upper)


def clip_gradient(enable_clip_grad, gradients):
    """Return ``gradients`` clipped according to the module-level clip settings.

    When ``enable_clip_grad`` is falsy the input is returned untouched. Otherwise either a global-norm
    clip (IS_ENABLE_GLOBAL_NORM) or a per-tensor `clip_grad` mapped over every gradient is applied.
    """
    if not enable_clip_grad:
        return gradients
    if IS_ENABLE_GLOBAL_NORM:
        return C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
    per_tensor_clip = F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    return hyper_map_op(per_tensor_clip, gradients)

# Edge length of the square C0 x C0 tiles used by the blocked matrix shapes in this module,
# e.g. ``(dim // C0, dim // C0, C0, C0)`` as built by caculate_device_shape.
C0 = 16


def _check_param(momentum, frequency, lr, cls_name):
    """Validate the THOR-specific hyper-parameters.

    Types are checked with ``Validator.check_value_type`` (``momentum`` float, ``frequency`` int,
    ``lr`` Tensor). A ``momentum`` below 0.0 or a ``frequency`` below 2 raises ValueError.
    """
    # (argument name, value, required type, inclusive lower bound, error template)
    bounded_args = (
        ("momentum", momentum, float, 0.0,
         "For 'thor', the argument 'momentum' should be at least 0.0, but got 'momentum' {}."),
        ("frequency", frequency, int, 2,
         "For 'thor', the argument 'frequency' should be at least 2, but got 'frequency' {}."),
    )
    for arg_name, arg_value, arg_type, lower_bound, template in bounded_args:
        Validator.check_value_type(arg_name, arg_value, [arg_type], cls_name)
        if isinstance(arg_value, arg_type) and arg_value < lower_bound:
            raise ValueError(template.format(arg_value))
    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)


def caculate_device_shape(matrix_dim, channel, is_a):
    """Return ``((blocks, blocks, C0, C0), dim)`` describing a ``dim x dim`` matrix tiled into C0 x C0 blocks.

    For an A matrix (``is_a``) whose ``channel`` count is below C0, the dimension is first rescaled as
    ``matrix_dim / channel * C0``, i.e. the channel factor is replaced by C0. Note that this uses true
    division, so the intermediate is a float and is truncated by ``int`` afterwards; ``channel == 0`` with
    ``is_a`` set raises ZeroDivisionError.
    """
    if is_a and channel // C0 == 0:
        matrix_dim = (matrix_dim / channel) * C0
    blocks = int(matrix_dim // C0)
    return (blocks, blocks, C0, C0), int(matrix_dim)


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """Return True if the (G, A) pair of blocked shapes is in the supported conv-layer matmul table.

    Every supported shape has the form ``(n, n, 16, 16)``; the table stores the supported
    ``(n_g, n_a)`` block-count pairs. Anything else (including list-typed shapes) yields False.
    """
    tile = (16, 16)
    supported_block_counts = (
        (4, 49), (4, 4), (4, 36), (16, 4), (4, 16),
        (8, 16), (8, 72), (32, 8), (32, 16), (8, 32),
        (16, 32), (16, 144), (64, 16), (64, 32), (16, 64),
        (32, 64), (32, 288), (128, 32), (128, 64), (32, 128),
    )
    for g_blocks, a_blocks in supported_block_counts:
        if (matrix_g_shape == (g_blocks, g_blocks) + tile
                and matrix_a_shape == (a_blocks, a_blocks) + tile):
            return True
    return False


def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """Compute the blocked shapes used to tile the A and G matrices into ``split_dim``-sized pieces.

    Each dimension is cut into ceil(dim / split_dim) blocks of edge ``split_dim``. A dimension smaller
    than ``split_dim`` (and not a multiple of it) becomes a single block whose edge is the dimension itself.

    Returns:
        tuple: ``(matrix_a_shape, matrix_g_shape)`` with ``matrix_a_shape == (g_blocks, a_blocks, a_edge, a_edge)``
        and ``matrix_g_shape == (g_blocks, g_edge, g_edge)``. Note the A shape leads with the G block count.
    """
    def _tile(dim):
        # (number of blocks, block edge length) for one matrix dimension.
        if dim % split_dim == 0:
            return dim // split_dim, split_dim
        if dim < split_dim:
            return 1, dim
        return dim // split_dim + 1, split_dim

    a_blocks, a_edge = _tile(matrix_a_dim)
    g_blocks, g_edge = _tile(matrix_g_dim)
    return (g_blocks, a_blocks, a_edge, a_edge), (g_blocks, g_edge, g_edge)


def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """Append the `Other` tag for a plain (non-THOR) Dense/Conv2d layer with a trainable weight.

    Layers whose (case-insensitive) prefix lies under ``rpn_with_loss.rpn_convs_list.`` are skipped,
    except the entry ``rpn_with_loss.rpn_convs_list.0.``, which is tagged like any other layer.
    Frozen weights (``requires_grad`` false) add nothing.
    """
    if not subcell.weight.requires_grad:
        return
    lowered = prefix.lower()
    in_rpn_convs = "rpn_with_loss.rpn_convs_list." in lowered
    is_first_rpn_conv = "rpn_with_loss.rpn_convs_list.0." in lowered
    if in_rpn_convs and not is_first_rpn_conv:
        return
    layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """Walk ``net`` depth-first and append one layer-type tag per recognised layer to ``layertype_map``.

    Direct children come from ``net.name_cells()`` in their stored order:
      * Conv2dThor -> Conv, DenseThor -> FC, EmbeddingThor/EmbeddingLookupThor -> Embedding,
        nn.LayerNorm -> LayerNorm;
      * nn.BatchNorm2d -> BatchNorm only when its gamma is trainable;
      * plain nn.Dense / nn.Conv2d are delegated to get_layer_type_for_dense_and_conv;
      * the other listed plain layers (Embedding, transposed/1-D convs, BatchNorm1d, GroupNorm,
        GlobalBatchNorm) -> Other;
      * any other cell is recursed into.
    """
    children = net.name_cells()
    for child in children.values():
        child_prefix = child.param_prefix
        if child == net:
            continue
        if isinstance(child, Conv2dThor):
            layertype_map.append(Conv)
            continue
        if isinstance(child, DenseThor):
            layertype_map.append(FC)
            continue
        if isinstance(child, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
            continue
        if isinstance(child, nn.LayerNorm):
            layertype_map.append(LayerNorm)
            continue
        if isinstance(child, nn.BatchNorm2d):
            # A BatchNorm2d with a frozen gamma contributes no tag at all.
            if child.gamma.requires_grad:
                layertype_map.append(BatchNorm)
            continue
        if isinstance(child, (nn.Dense, nn.Conv2d)):
            get_layer_type_for_dense_and_conv(child, child_prefix, layertype_map)
            continue
        if isinstance(child, (nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                              nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            layertype_map.append(Other)
            continue
        find_net_layertype_recur(child, layertype_map)


def get_net_layertype_mask(net):
    """Return a new list with the layer-type tags of ``net``, as collected by find_net_layertype_recur."""
    mask = []
    find_net_layertype_recur(net, mask)
    return mask


def get_layer_counter(layer_type, layer_counter, params, idx):
    """Return the layer counter after visiting ``params[idx]``.

    The counter advances by one when ``params[idx]`` closes its layer:
      * Conv/FC: the parameter is a bias, or it is followed by a parameter that is not a bias;
      * LayerNorm/BatchNorm: the parameter name contains "beta";
      * otherwise: the parameter is a bias, a "weight" followed by a non-bias parameter,
        or any other (non-weight, non-bias) parameter.
    Names are matched case-insensitively. A weight-like parameter at the very end of ``params``
    never advances the counter, since there is no following parameter to inspect.
    """
    current = params[idx].name.lower()

    def _next_is_not_bias():
        # A following parameter exists and does not belong to a bias.
        return idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower()

    if layer_type in [Conv, FC]:
        closes_layer = "bias" in current or _next_is_not_bias()
    elif layer_type in [LayerNorm, BatchNorm]:
        closes_layer = "beta" in current
    elif "bias" in current:
        closes_layer = True
    elif "weight" in current:
        closes_layer = _next_is_not_bias()
    else:
        closes_layer = True
    return layer_counter + 1 if closes_layer else layer_counter


def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by second-order algorithm--THOR.

    Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:

    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
          A_i = a_i{a_i}^T \\
          G_i = D_{s_i}{ D_{s_i}}^T \\
          m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
          w_i = w_i - \alpha * m_i \\
        \end{array}

    :math:`D_{s_i}` represents the derivative of the loss function with respect to the output of the i-th layer,
    :math:`a_{i-1}` represents the input of the i-th layer, i.e. the activations of the previous layer,
    :math:`\beta` represents momentum, :math:`I` represents the identity matrix,
    :math:`\overline A` represents the transpose of matrix A,
    :math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
    :math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'.

    Args:
        net (Cell): The training network.
        learning_rate (Tensor): A value for the learning rate.
        damping (Tensor): A value for the damping.
        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at
            least 0.0.
        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the default
            value. Default: 1.0.
        batch_size (int): The size of a batch. Default: 32.
        use_nesterov (bool): Enable Nesterov momentum. It is only forwarded to the GPU implementation; on Ascend
            it is ignored and a warning is logged. Default: False.
        decay_filter (function): A function to determine which layers the weight decay is applied to. It only
            takes effect when `weight_decay` > 0. Default: lambda x: x.name not in [].
        split_indices (list): Set the allreduce fusion strategy by A/G layer indices. Only works in distributed
            computing. Taking ResNet50 as an example, there are 54 layers of A/G respectively; when
            `split_indices` is set to [26, 53], A/G is divided into two groups for allreduce, one covering layers
            0~26 and the other covering layers 27~53. Default: None.
        enable_clip_grad (bool): Whether to clip the gradients. Default: False.
        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When `frequency` equals N
            (N must be at least 2), A/G and :math:`A^{-1}/G^{-1}` are updated every N steps, and the other steps
            use the stale A/G and :math:`A^{-1}/G^{-1}` to update the weights. Default: 100.

    Returns:
        Optimizer, a ``ThorAscend`` instance when the context's `device_target` is "Ascend", otherwise a
        ``ThorGpu`` instance.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not a Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `frequency` is not an int.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn import thor
        >>> from mindspore import Model
        >>> from mindspore import FixedLossScaleManager
        >>> from mindspore.train.callback import LossMonitor
        >>> from mindspore.train.train_thor import ConvertModelUtils
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>>
        >>> net = Net()
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
        >>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)
        >>> loss_cb = LossMonitor()
        >>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)
    """
    # Global context change: raise the call-depth limit (presumably needed to compile THOR's deep
    # per-layer graphs — verify the exact requirement).
    context.set_context(max_call_depth=10000)
    # Replaces supported layers of `net` with their THOR variants before the optimizer collects parameters.
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        # `use_nesterov` is not forwarded to ThorAscend, so make the silent drop visible to the user.
        if use_nesterov:
            logger.warning("For 'thor', 'use_nesterov' is not supported on Ascend and will be ignored.")
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, use_nesterov,
                   decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
class ThorGpu(Optimizer):
    """
    ThorGpu: GPU implementation of the THOR optimizer.

    For every trainable parameter the optimizer looks up the layer type recorded in
    ``weight_layertype_idx_map`` and transforms its gradient accordingly:

    - Conv / FC weights: the gradient is reshaped to 2-D and passed through
      ``P.UpdateThorGradient`` together with the matrix A and matrix G inverses
      (Conv gradients are reshaped back to their original shape afterwards).
    - Embedding weights: the gradient is scaled row-wise by the (vector) A inverse
      and right-multiplied by the G inverse.
    - LayerNorm parameters: the gradient is divided element-wise by
      ``gradient**2 / batch_size + sqrt(damping)``.
    - Anything else (biases, other layer types): the gradient is left unchanged.

    While ``self.thor`` is True, ``construct`` recomputes the inverses from the covariance
    parameters of ``net`` (names containing 'matrix_a' / 'matrix_g'), all-reduces them in
    distributed mode and caches them in ``self.matrix_a`` / ``self.matrix_g``. When it is
    False, the cached inverses are reused. Afterwards, optional weight decay and gradient
    clipping are applied and the parameters are updated with ``P.ApplyMomentum``.
    """
    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
                 frequency=100):
        """
        Build the optimizer state for ``net``.

        Args:
            net (Cell): Network being trained. Its parameters with ``requires_grad`` are optimized.
                Its parameters whose names contain 'matrix_a', 'matrix_g', 'a_normalizer' or
                'g_normalizer' are read as THOR statistics.
            learning_rate: Learning rate forwarded to ``Optimizer``; also validated by ``_check_param``.
            damping (Tensor): Damping values. ``construct`` gathers the entry at index ``self.cov_step``,
                so the value changes every step (presumably a 1-D per-step schedule -- confirm with callers).
            momentum (float): Momentum for ``P.ApplyMomentum``.
            weight_decay (float): Weight decay; ``construct`` applies it only when it is > 0.
            loss_scale (float): Static loss scale forwarded to ``Optimizer``. ``1 / loss_scale**2`` is
                also used here to rescale matrix G.
            batch_size (int): Batch size used to normalize the THOR statistics.
            use_nesterov (bool): Whether ``ApplyMomentum`` uses Nesterov momentum.
            decay_filter (callable): Predicate over parameters; True means weight decay is applied.
            split_indices (list): All-reduce fusion split indices for the A/G reducers (distributed
                mode only). Defaults to ``[len(self.matrix_a_cov) - 1]``.
            enable_clip_grad (bool): Forwarded to ``clip_gradient`` in ``construct``.
            frequency (int): THOR update interval; stored and returned by ``get_frequency``.
        """
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        # _check_param is a module-level helper (defined elsewhere in this file).
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        # One zero-initialized momentum buffer per optimized parameter.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        # THOR statistics are picked out of the network's parameters by name substring.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        # NOTE: despite its name, this holds 1 / loss_scale**2. It is multiplied into matrix G in
        # _get_ainv_ginv_list. Presumably G is accumulated from loss-scaled output gradients, hence
        # the square -- confirm. It may also shadow a same-named attribute set by Optimizer.__init__.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        # batch_size**2, also multiplied into matrix G (see _get_ainv_ginv_list).
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        # Must run before _process_matrix_init_and_weight_idx_map and UpdateThorGradient below,
        # which read self.shape and self.split_dim.
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        # construct recomputes the inverses only while this is True. Nothing in this class sets it
        # to False (presumably toggled externally between THOR update steps -- confirm).
        self.thor = True
        # Per-layer buffers and per-parameter index maps, filled by
        # _process_matrix_init_and_weight_idx_map.
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        # One bool per parameter: whether apply_decay adds weight decay to its gradient.
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        # Uses len(self.matrix_a_cov), so it must run after matrix_a_cov is populated.
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """define gpu operator

        Create the primitive operators and constants used by this optimizer.
        """
        # NOTE: self.transpose and self.reduce_sum are not used anywhere in this class.
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.GatherV2()
        # Increment for self.cov_step at the end of every construct call.
        self.one = Tensor(1, mstype.int32)
        # Default (non-Conv) feature-map factor applied to the A/G inverses.
        self.feature_map = Tensor(1.0, mstype.float32)
        # Axis used to gather the current damping value from self.damping.
        self.axis = 0
        # Step counter used to index self.damping; incremented in construct.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        # Block size shared by the blocked CholeskyTrsm, UpdateThorGradient and the
        # buffer shapes computed in _process_matrix_init_and_weight_idx_map.
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """define gpu reducer

        In any parallel mode other than STAND_ALONE, register all-reduce fusion split indices
        for the fusion groups 6 (matrix A inverses) and 8 (matrix G inverses), and create the
        matching ``DistributedGradReducer`` instances used in ``construct``.
        """
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            if not split_indices:
                # Single split at the last A/G layer, i.e. presumably one fusion group.
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            # NOTE(review): the G reducer is also built from self.matrix_a_cov. It presumably only
            # needs a parameter tuple of the same length as the G list. Confirm this is intentional
            # and not a typo for self.matrix_g_cov.
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map

        Walk ``self.params`` once and, using the per-layer types from
        ``get_net_layertype_mask(net)``, build:

        - ``self.matrix_a`` / ``self.matrix_g``: zero-initialized buffers that ``construct``
          uses to cache the A/G inverses of each Conv/FC/Embedding weight;
        - ``self.matrix_a_shape``: target shape of each A inverse (used by ``BroadcastTo``);
        - ``self.weight_fim_idx_map``: per parameter, its THOR layer index, or -1;
        - ``self.weight_conv_idx_map``: per parameter, its conv index (into the a/g
          normalizers), or -1;
        - ``self.weight_layertype_idx_map``: per parameter, Conv/FC/Embedding/LayerNorm or Other.

        Parameters whose name contains "bias" never get a THOR layer index.
        """
        layer_type_map = get_net_layertype_mask(net)
        # Index into layer_type_map. It is advanced by get_layer_counter (defined elsewhere),
        # presumably once all parameters of a layer have been visited.
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    # A is over the unfolded input patch: in_channels * kernel_h * kernel_w.
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                # caculate_matmul_shape is defined elsewhere; presumably it returns block-split
                # shapes for self.split_dim.
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            # NOTE(review): unlike the index-map branch below, this branch does not check for
            # "bias". An Embedding parameter named "*bias*" would get buffers without a THOR
            # index -- presumably never happens.
            elif layer_type == Embedding:
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                # Embedding A inverse is a vector (one entry per vocabulary row); G is dense.
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                # Biases and non-THOR layers: no THOR / conv index. Only LayerNorm keeps its type.
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """get matrixA inverse list and matrix G inverse list

        For every Conv/FC/Embedding weight (in parameter order), damp the covariance matrices
        read from ``self.matrix_a_cov`` / ``self.matrix_g_cov`` and invert them:

        - Conv: the damping is divided by the layer's a/g normalizer, and both results are
          multiplied by ``sqrt(1 / a_normalizer)``.
        - Embedding: matrix A is handled element-wise (scaled by 1/batch_size, damped via
          ``OnesLike`` and inverted with ``Reciprocal``). Matrix G goes through
          ``CholeskyTrsm`` followed by ``MatMul``.
        - Conv/FC: both matrices go through the blocked ``CholeskyTrsm`` followed by
          ``BatchMatMul(transpose_a=True)``. A is then broadcast to the shape recorded in
          ``self.matrix_a_shape``.

        Before damping, matrix G is always multiplied by ``self.loss_scale`` (1/loss_scale**2)
        and ``self.batch_size_scale`` (batch_size**2). The damping term added is the square
        root of the (normalized) damping times an identity.

        Args:
            gradients (tuple[Tensor]): Gradients of ``self.params``. They are only used as
                ``F.depend`` dependencies of the statistics reads.
            damping_step (Tensor): Damping value for the current step.
            matrix_a_allreduce (tuple): Accumulator the A inverses are appended to
                (``construct`` passes ``()``).
            matrix_g_allreduce (tuple): Accumulator the G inverses are appended to
                (``construct`` passes ``()``).

        Returns:
            tuple: ``(matrix_a_allreduce, matrix_g_allreduce)``, with one entry per THOR layer
            in the order of the indices stored in ``self.weight_fim_idx_map``.
        """
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                # NOTE(review): for Embedding layers this identity is a_shape[0] x a_shape[0]
                # (vocab x vocab, if matrix_a_cov is a vector there) and is replaced by OnesLike
                # below, so it is unused on that path.
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    # A is treated as a diagonal: damp every entry and invert element-wise.
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    # NOTE(review): plain MatMul (no transpose), unlike the Conv/FC path, which uses
                    # BatchMatMul(transpose_a=True). Verify against the CholeskyTrsm GPU kernel's
                    # output convention.
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                # feature_map is 1.0 except for Conv layers.
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm

        Return ``gradient / (gradient**2 / batch_size + sqrt(damping_step))``, computed element-wise.
        """
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """reshape gradient

        Restore ``g`` to ``g_shape`` for Conv weights (``conv_layer_count != -1``).
        Other weights are returned as-is.
        """
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        """
        Apply one THOR update step.

        Args:
            gradients (tuple[Tensor]): Gradients of ``self.params``, in the same order.

        Returns:
            The result of mapping ``_momentum_opt`` over (gradients, params, moments).
        """
        params = self.params
        moments = self.moments
        # Optimizer.scale_grad: presumably divides out the static loss scale.
        gradients = self.scale_grad(gradients)
        # Damping for the current step: entry self.cov_step of self.damping.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            # Recompute the A/G inverses, all-reduce them if distributed, cache them in
            # self.matrix_a / self.matrix_g, and precondition the gradients with them.
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    # Row-wise scaling by the vector A inverse, then right-multiply by G inverse.
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            # Reuse the inverses cached by the last step that ran with self.thor == True.
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        # Advance the step counter used to index self.damping.
        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        # clip_gradient is a module-level helper (defined elsewhere in this file).
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
class ThorAscend(Optimizer): """ThorAscend""" def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100): params = filter(lambda x: x.requires_grad, net.get_parameters()) super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale) _check_param(momentum, frequency, learning_rate, self.__class__.__name__) self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = self.parameters self.moments = self.params.clone(prefix="moments", init='zeros') self.hyper_map = C.HyperMap() self.opt = P.ApplyMomentum() self.net = net self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters())) self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters())) self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters())) self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters())) logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov))) self._define_ascend_operator() self.C0 = 16 self.device_shape_pad_flag = () self.diag_block_dim = 128 self.matrix_a = () self.matrix_g = () self.thor_layer_count = 0 self.conv_layer_count = 0 self.weight_conv_idx_map = () self.weight_fim_idx_map = () self.weight_layertype_idx_map = () self.a_split_pad_dim_map = () self.g_split_pad_dim_map = () self.conv_matmul_support_map = () self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36] self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32] self._process_matrix_init_and_weight_idx_map(self.net) self.matrix_a = ParameterTuple(self.matrix_a) self.matrix_g = ParameterTuple(self.matrix_g) self.matrix_max_inv = () for i in
range(len(self.matrix_a)): self.matrix_max_inv = self.matrix_max_inv + ( Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) self.thor = True self.weight_decay = weight_decay self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.damping = damping self.batch_size = Tensor(batch_size, mstype.float32) self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32) self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32) self.enable_clip_grad = enable_clip_grad self.frequency = frequency self._define_ascend_reducer(split_indices) def get_frequency(self): """get thor frequency""" return self.frequency def _get_pad_dim(self, matrix_dim): """get diag split pad dim """ split_pad_dim = 0 if matrix_dim == 64: return split_pad_dim res = matrix_dim % self.diag_block_dim if res != 0: split_pad_dim = self.diag_block_dim - res return split_pad_dim def _define_ascend_operator(self): """define ascend operator""" self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() self.transpose = P.Transpose() self.shape = P.Shape() self.reshape = P.Reshape() self.mul = P.Mul() self.log = P.Log() self.exp = P.Exp() self.sqrt = P.Sqrt() self.gather = P.GatherV2() self.assign = P.Assign() self.cast = P.Cast() self.eye = P.Eye() self.concat = P.Concat(0) self.cholesky = P.CusCholeskyTrsm() self.vector_matmul = P.CusBatchMatMul() self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True) self.fused_abs_max2 = P.CusFusedAbsMax1() self.matrix_combine = P.CusMatrixCombine() self.slice = P.Slice() self.expand = P.ExpandDims() self.reduce_sum = P.ReduceSum(keep_dims=False) self.square = P.Square() self.inv = P.Inv() self.matmul = P.MatMul() self.axis = 0 self.one = Tensor(1, mstype.int32) 
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) def _define_ascend_reducer(self, split_indices): """define ascend reducer""" self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) if self.is_distributed: mean = _get_gradients_mean() degree = _get_device_num() if not split_indices: self.split_indices = [len(self.matrix_a_cov) - 1] else: self.split_indices = split_indices if self.conv_layer_count > 0: auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2") auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4") self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2) self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4) auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6") auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8") self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6) self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8) def _get_weight_idx_map(self, layer_type, idx, weight_shape): """for Ascend, get weight idx map""" if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower(): self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,) self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,) if layer_type == Embedding: a_pad_dim = 0 g_pad_dim = 0 self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,) self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,) else: out_channels = weight_shape[0] g_pad_dim = self._get_pad_dim(out_channels) self.g_split_pad_dim_map = 
self.g_split_pad_dim_map + (g_pad_dim,) matrix_a_dim = weight_shape[1] if layer_type == Conv: matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3] a_pad_dim = self._get_pad_dim(matrix_a_dim) self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,) self.thor_layer_count = self.thor_layer_count + 1 if layer_type == Conv: self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,) self.conv_layer_count = self.conv_layer_count + 1 else: self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,) else: self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,) self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,) if layer_type == LayerNorm: self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,) else: self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,) def _get_fc_matrix(self, weight_shape): """for Ascend, get fc matrix_a and matrix_g""" out_channels = weight_shape[0] in_channels = weight_shape[1] if self.conv_layer_count > 0: if out_channels == 1001: fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) else: fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.matrix_a = self.matrix_a + (fc_matrix_a,) self.matrix_g = self.matrix_g + (fc_matrix_g,) def _process_matrix_init_and_weight_idx_map(self, net): """for Ascend, process matrix init shape, and get weight idx map""" layer_counter = 0 layer_type_map = get_net_layertype_mask(net) for idx in range(len(self.params)): layer_type = 
layer_type_map[layer_counter] weight = self.params[idx] weight_shape = self.shape(weight) if layer_type == Conv and "bias" not in self.params[idx].name.lower(): in_channels = weight_shape[1] out_channels = weight_shape[0] matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3] matrix_g_dim = out_channels matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True) matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False) ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape) if ret: matrix_a_inv = Parameter( Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) matrix_g_inv = Parameter( Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.conv_matmul_support_map = self.conv_matmul_support_map + (1,) else: matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)), name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False) matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)), name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False) self.conv_matmul_support_map = self.conv_matmul_support_map + (0,) self.matrix_a = self.matrix_a + (matrix_a_inv,) self.matrix_g = self.matrix_g + (matrix_g_inv,) device_shape_pad_flag = False if matrix_a_dim != matrix_a_device_dim: device_shape_pad_flag = True self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,) elif layer_type == FC and "bias" not in self.params[idx].name.lower(): self._get_fc_matrix(weight_shape) self._get_weight_idx_map(layer_type, idx, weight_shape) # bert.cls1.output_bias: not a network layer, only a trainable param if "output_bias" not in self.params[idx].name.lower(): layer_counter 
= get_layer_counter(layer_type, layer_counter, self.params, idx) def _process_batch_matmul(self, input_matrix): """process batch matmul""" input_matrix_shape = self.shape(input_matrix) if input_matrix_shape[0] in self.batch_matmul_support_list: input_matrix = self.vector_matmul(input_matrix, input_matrix) else: input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix) return input_matrix def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0): """process cholesky pad""" if pad_dim > 0: matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32) matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup) input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix) input_matrix = self.concat((input_matrix, matrix_sup)) return input_matrix def _get_abs_max(self, matrix_inv, origin_dim): """get matrix abs max""" cholesky_shape = self.shape(matrix_inv) if cholesky_shape[0] in self.abs_max_support_list: matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv) matrix_max = self.fused_abs_max2(matrix_inv_max) matrix_inv = self.matrix_combine(matrix_inv) else: matrix_inv = self.matrix_combine(matrix_inv) matrix_abs = P.Abs()(matrix_inv) matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs) return matrix_max, matrix_inv def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce): """get fc layer ainv and ginv""" thor_layer_count = self.weight_fim_idx_map[index] g = gradients[index] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) a_shape = self.shape(matrix_a) a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) damping = self.sqrt(damping_step) matrix_a = matrix_a + damping * a_eye a_pad_dim = self.a_split_pad_dim_map[thor_layer_count] matrix_a = 
self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0]) matrix_a_inv = self.cholesky(matrix_a) matrix_a_inv = self._process_batch_matmul(matrix_a_inv) weight_shape = self.shape(self.params[index]) out_channels = weight_shape[0] in_channels = weight_shape[1] if out_channels == 2: matrix_a_inv = self.matrix_combine(matrix_a_inv) matrix_g_inv = g_eye else: matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping * g_eye g_pad_dim = self.g_split_pad_dim_map[thor_layer_count] matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0]) matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) if self.conv_layer_count > 0: a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels) g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels) a_max = F.depend(a_max, g) g_max = F.depend(g_max, g) matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,) matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,) else: matrix_a_inv = self.matrix_combine(matrix_a_inv) matrix_g_inv = self.matrix_combine(matrix_g_inv) if a_pad_dim > 0: matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels)) if g_pad_dim > 0: matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels)) matrix_a_inv_shape = self.shape(matrix_a_inv) matrix_g_combine_shape = self.shape(matrix_g_inv) if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001: matrix_a_inv = self.reshape(matrix_a_inv, (matrix_a_inv_shape[0] / 16, 16, matrix_a_inv_shape[0] / 16, 16)) matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3)) matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv) matrix_g_inv_shape = self.shape(matrix_g_inv) matrix_g_inv = self.reshape(matrix_g_inv, (matrix_g_inv_shape[0] / 16, 16, matrix_g_inv_shape[0] / 16, 16)) matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3)) matrix_a_allreduce = matrix_a_allreduce + 
(matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv): """process conv matmul device pad""" if self.device_shape_pad_flag[conv_layer_count]: kernel_hw = weight_shape[2] * weight_shape[3] in_channels = weight_shape[1] matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels)) matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0), (0, self.C0 - in_channels)))(matrix_a_inv) return matrix_a_inv def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce): """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list""" for i in range(len(self.params)): thor_layer_count = self.weight_fim_idx_map[i] conv_layer_count = self.weight_conv_idx_map[i] layer_type = self.weight_layertype_idx_map[i] weight_shape = self.shape(self.params[i]) out_channels = weight_shape[0] if layer_type == Conv: g = gradients[i] matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3] matmul_support_flag = self.conv_matmul_support_map[conv_layer_count] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) a_shape = self.shape(matrix_a) a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) a_normalizer = self.a_normalizer[conv_layer_count] g_normalizer = self.g_normalizer[conv_layer_count] a_normalizer = F.depend(a_normalizer, g) g_normalizer = F.depend(g_normalizer, g) damping_a = self.mul(damping_step, self.batch_size / a_normalizer) damping_g = self.mul(damping_step, self.batch_size / g_normalizer) damping_a = self.sqrt(damping_a) 
matrix_a = matrix_a + damping_a * a_eye a_pad_dim = self.a_split_pad_dim_map[thor_layer_count] matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0]) matrix_a_inv = self.cholesky(matrix_a) matrix_a_inv = self._process_batch_matmul(matrix_a_inv) a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim) damping_g = self.sqrt(damping_g) matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping_g * g_eye g_pad_dim = self.g_split_pad_dim_map[thor_layer_count] matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0]) matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels) if a_pad_dim > 0: matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim)) if g_pad_dim > 0: matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels)) if matmul_support_flag == 1: matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv) matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count]) matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2], matrix_a_inv_shape[1], matrix_a_inv_shape[3]) matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape) matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3)) matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count]) matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2], matrix_g_inv_shape[1], matrix_g_inv_shape[3]) matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape) matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3)) a_max = F.depend(a_max, g) g_max = F.depend(g_max, g) matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,) matrix_g_max_allreduce = 
matrix_g_max_allreduce + (g_max,) elif layer_type == FC: matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \ self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce) elif layer_type == Embedding: g = gradients[i] matrix_a = self.matrix_a_cov[thor_layer_count] matrix_g = self.matrix_g_cov[thor_layer_count] matrix_a = F.depend(matrix_a, g) matrix_g = F.depend(matrix_g, g) g_shape = self.shape(matrix_g) g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32) damping = self.sqrt(damping_step) a_eye = P.OnesLike()(matrix_a) matrix_a = self.mul(matrix_a, 1.0 / self.batch_size) matrix_a = matrix_a + damping * a_eye matrix_a_inv = self.inv(matrix_a) matrix_g = self.mul(matrix_g, self.loss_scale) matrix_g = self.mul(matrix_g, self.batch_size_scale) matrix_g = matrix_g + damping * g_eye matrix_g_inv = self.cholesky(matrix_g) matrix_g_inv = self._process_batch_matmul(matrix_g_inv) matrix_g_inv = self.matrix_combine(matrix_g_inv) matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,) matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,) return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce def _process_layernorm(self, damping_step, gradient): """process layernorm layer for thor""" damping = self.sqrt(damping_step) normalizer = self.cast(self.batch_size, mstype.float32) fim_cov = self.square(gradient) fim_cov = self.mul(fim_cov, 1.0 / normalizer) fim_cov = fim_cov + damping fim_inv = self.inv(fim_cov) gradient = self.mul(fim_inv, gradient) return gradient def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g): """process thor graph fc layer""" temp_a = matrix_a_allreduce[thor_layer_count] temp_g = matrix_g_allreduce[thor_layer_count] self.assign(self.matrix_a_cov[thor_layer_count], temp_a) self.assign(self.matrix_g_cov[thor_layer_count], temp_g) temp_a = self.cast(temp_a, 
mstype.float16)
        temp_g = self.cast(temp_g, mstype.float16)
        g = self.cast(g, mstype.float16)
        g = self.matmul(temp_g, g)
        g = self.matmul(g, temp_a)
        g = self.cast(g, mstype.float32)
        return g

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """Transform gradients with the cached THOR matrices (network has conv layers).

        ``construct`` reaches this through ``_get_second_gradients`` when
        ``self.thor`` is False and ``self.conv_layer_count > 0``. Instead of
        recomputing the per-layer matrices it reuses those cached in
        ``self.matrix_a``, ``self.matrix_g`` and ``self.matrix_max_inv``
        (written by the conv branch of ``construct``).

        For FC and Conv layers the fallback path computes ``G @ g @ A`` in
        float16, casts back to float32 and multiplies by the cached max scale;
        the cube-kernel path is presumed to compute the same product. Gradients
        of all other layer types are appended unchanged, so the output stays
        aligned one-to-one with ``self.params``.

        Args:
            params_len: Number of parameters to process (``len(self.params)``).
            gradients: Tuple of gradients, one per parameter.
            new_grads: Tuple the results are appended to (callers pass ``()``).

        Returns:
            ``new_grads`` extended with one gradient per parameter.
        """
        for i in range(params_len):
            g = gradients[i]
            # Per-parameter index maps: THOR matrix slot, conv-layer slot, layer type.
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
            if layer_type == FC:
                # NOTE(review): the literal 1001 presumably matches the only FC output
                # size the cube FC kernels were built for (e.g. a 1001-class head) — confirm.
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    # Fallback: float16 G @ g @ A, back to float32, then scale by max.
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                # A flag of 1 selects the cube-kernel path for this conv layer.
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    # Flatten the conv gradient to 2-D (dim 0 kept, dims 1-3 merged) so it
                    # can be multiplied by G and A; its original shape is restored below.
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            new_grads = new_grads + (g,)
        return new_grads

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """Transform gradients with cached THOR matrices; used when ``self.thor`` is False.

        If the network has conv layers (``self.conv_layer_count > 0``) the work
        is delegated to ``_get_second_gradients_one``. Otherwise the matrices
        cached in ``self.matrix_a_cov`` / ``self.matrix_g_cov`` are applied:

        * Embedding: ``g`` is scaled elementwise by the A matrix expanded on
          axis 1, then right-multiplied by G in float16.
        * FC: ``G @ g @ A`` in float16, cast back to float32.
        * LayerNorm: handled by ``_process_layernorm``.
        * Any other layer type: ``g`` is passed through unchanged.

        Args:
            new_grads: Tuple the results are appended to (callers pass ``()``).
            damping_step: Damping value for the current step; only forwarded to
                ``_process_layernorm``.
            gradients: Tuple of gradients, one per parameter.

        Returns:
            Tuple with one gradient per parameter, aligned with ``self.params``.
        """
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type == Embedding:
                    # NOTE(review): expand(..., 1) suggests the cached A entry is a vector
                    # that scales g row-wise — confirm against where matrix_a_cov is filled.
                    temp_a_ori = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.expand(temp_a_ori, 1)
                    g = self.mul(temp_a, g)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(g, temp_g)
                    g = self.cast(g, mstype.float32)
                elif layer_type == FC:
                    temp_a = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                # Every parameter contributes exactly one entry, processed or not.
                new_grads = new_grads + (g,)
        return new_grads

    def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
        """Apply normalized THOR matrices to one FC or Conv gradient.

        The fallback path computes ``temp_g @ g @ temp_a`` in float16, casts it
        back to float32 and multiplies by ``temp_max``; the cube-kernel path is
        presumed to compute the same product. For Conv layers ``temp_max`` is
        first multiplied by ``self.batch_size / a_normalizer``. Gradients of any
        other layer type are returned unchanged.

        Args:
            index: Parameter index, used to look up the layer type and conv slot.
            temp_a: Normalized A-side matrix for this layer.
            temp_g: Normalized G-side matrix for this layer.
            g: Gradient of the parameter.
            temp_max: Scale applied to the matrix product.

        Returns:
            Tuple ``(g, temp_max)``: the transformed gradient and the (possibly
            conv-adjusted) scale, which ``construct`` stores in
            ``self.matrix_max_inv``.
        """
        conv_layer_count = self.weight_conv_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        grad_shape = self.shape(g)
        if layer_type == FC:
            # NOTE(review): same hard-coded 1001 cube-kernel shape check as in
            # _get_second_gradients_one — confirm it is the only supported size.
            if grad_shape[0] == 1001:
                g = self.cube_matmul_left_fc(temp_g, g)
                g = self.cube_matmul_right_fc(g, temp_a, temp_max)
            else:
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
        elif layer_type == Conv:
            a_normalizer = self.a_normalizer[conv_layer_count]
            # F.depend attaches an execution dependency on g to the normalizer read.
            a_normalizer = F.depend(a_normalizer, g)
            temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
            matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
            if matmul_support_flag == 1:
                g = self.cube_matmul_left(temp_g, g)
                g = self.cube_matmul_right_mul(g, temp_a, temp_max)
            else:
                # Flatten to 2-D (dim 0 kept, dims 1-3 merged), multiply, restore shape.
                g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.matmul(temp_g, g)
                g = self.matmul(g, temp_a)
                g = self.cast(g, mstype.float32)
                g = self.mul(g, temp_max)
                g = self.reshape(g, grad_shape)
        return g, temp_max

    def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
        """Transform one gradient with freshly computed THOR matrices (no conv layers).

        Used by ``construct`` when ``self.thor`` is True and
        ``self.conv_layer_count == 0``.

        * Embedding: stores the new A/G entries into ``self.matrix_a_cov`` /
          ``self.matrix_g_cov`` (later reused by ``_get_second_gradients``),
          scales ``g`` by A expanded on axis 1, then right-multiplies by G in
          float16.
        * FC: delegated to ``_process_thor_fc``.
        * LayerNorm: delegated to ``_process_layernorm``.
        * Any other layer type: ``g`` is returned unchanged.

        Args:
            index: Parameter index.
            matrix_a_allreduce: Tuple of A-side matrices indexed by THOR layer slot.
            matrix_g_allreduce: Tuple of G-side matrices indexed by THOR layer slot.
            g: Gradient of the parameter.
            damping_step: Damping value for the current step (LayerNorm only).

        Returns:
            The transformed gradient.
        """
        thor_layer_count = self.weight_fim_idx_map[index]
        layer_type = self.weight_layertype_idx_map[index]
        if layer_type == Embedding:
            temp_a_ori = matrix_a_allreduce[thor_layer_count]
            temp_g = matrix_g_allreduce[thor_layer_count]
            # Cache the matrices so later calls with self.thor False can reuse them.
            self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
            self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
            temp_a = self.expand(temp_a_ori, 1)
            g = self.mul(temp_a, g)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(g, temp_g)
            g = self.cast(g, mstype.float32)
        elif layer_type == FC:
            g = self._process_thor_fc(thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g)
        elif layer_type == LayerNorm:
            g = self._process_layernorm(damping_step, g)
        return g

    def construct(self, gradients):
        # One optimizer step:
        #   1. scale the incoming gradients and pick this step's damping value;
        #   2. if self.thor: compute new THOR matrices (reduced across devices when
        #      distributed), transform the gradients and cache the matrices;
        #      else: transform the gradients with the previously cached matrices;
        #   3. advance cov_step, apply weight decay and clipping, run the momentum update.
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        # Damping for this step: self.damping gathered at index self.cov_step.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            matrix_a_allreduce = ()
            matrix_g_allreduce = ()
            matrix_a_max_allreduce = ()
            matrix_g_max_allreduce = ()
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                                   matrix_a_max_allreduce, matrix_g_max_allreduce)
            if self.is_distributed:
                # Reduce the new matrices across devices; the max tuples are only
                # reduced (and only produced for use below) when conv layers exist.
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
                if self.conv_layer_count > 0:
                    matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                    matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

            new_grads = ()
            if self.conv_layer_count > 0:
                for i in range(len(self.params)):
                    g = gradients[i]
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
                    # exp(-log(x)) == 1 / x for positive x: normalize A by its max value.
                    matrix_a_inv_max = self.log(matrix_a_max_allreduce[thor_layer_count])
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    # Same normalization for G.
                    matrix_g_inv_max = self.log(matrix_g_max_allreduce[thor_layer_count])
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    # NOTE(review): this scale is g_max * g_max. Undoing both normalizations
                    # above would need a_max * g_max. Left unchanged — confirm against
                    # _get_ainv_ginv_amax_gmax_list whether the squared G max is intended.
                    temp_max = self.mul(matrix_g_max_allreduce[thor_layer_count],
                                        matrix_g_max_allreduce[thor_layer_count])
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    # Cache the normalized float16 matrices and the scale; they are read
                    # back by _get_second_gradients_one when self.thor is False.
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                    new_grads = new_grads + (g,)
                gradients = new_grads
            else:
                for i in range(len(self.params)):
                    g = gradients[i]
                    g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                    new_grads = new_grads + (g,)
                gradients = new_grads
        else:
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        # Advance the step counter used to index self.damping on the next call.
        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            # apply_decay only touches parameters whose entry in self.decay_flags is True.
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success