mindspore.nn.Lamb

class mindspore.nn.Lamb(*args, **kwargs)

An optimizer that implements the LAMB (Layer-wise Adaptive Moments optimizer for Batch training) algorithm.

LAMB is an optimization algorithm that uses a layer-wise adaptive large-batch optimization technique. Refer to the paper Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes.

The LAMB optimizer aims to increase the training batch size without reducing accuracy. It supports adaptive element-wise updates and accurate layer-wise correction.

The updating of parameters follows:

\[\begin{split}\begin{gather*} m_t = \beta_1 m_{t - 1} + (1 - \beta_1) g_t\\ v_t = \beta_2 v_{t - 1} + (1 - \beta_2) g_t^2\\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}\\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t}\\ r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}\\ w_t = w_{t-1} - \eta_t \frac{\| w_{t-1} \|}{\| r_t + \lambda w_{t-1} \|} (r_t + \lambda w_{t-1}) \end{gather*}\end{split}\]

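where \(m\) is the 1st moment, \(v\) the 2nd moment, \(g_t\) the gradient at step \(t\), \(\beta_1\) and \(\beta_2\) the exponential decay rates of the moments, \(\epsilon\) the term added for numerical stability, \(\eta_t\) the learning rate at step \(t\), \(\lambda\) the LAMB weight decay rate, and \(\| \cdot \|\) the L2 norm taken over all elements of one layer's parameter.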
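To make the update rule concrete, here is a minimal pure-Python sketch of one LAMB step for a single layer, written against plain lists rather than MindSpore Tensors. It is not the MindSpore implementation. It uses the standard bias correction \(1 - \beta^t\) from the LAMB paper, and the fallback of a trust ratio of 1 when either norm is zero is an assumption of this sketch. The function name `lamb_step` is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def lamb_step(w: Sequence[float], g: Sequence[float],
              m: Sequence[float], v: Sequence[float], t: int, lr: float,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-6,
              weight_decay: float = 0.0) -> Tuple[List[float], List[float], List[float]]:
    """One LAMB update for a single layer (hypothetical helper, flat list of weights).

    t is the 1-based step count used for bias correction.
    Returns the updated (w, m, v).
    """
    # Exponential moving averages of the gradient and the squared gradient
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    # Bias-corrected moment estimates
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # Adam-style direction r_t, then add the weight decay term lambda * w
    r = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    u = [ri + weight_decay * wi for ri, wi in zip(r, w)]
    # Layer-wise trust ratio ||w|| / ||r + lambda * w||
    # (assumption: use 1.0 when either norm is zero)
    w_norm = math.sqrt(sum(wi * wi for wi in w))
    u_norm = math.sqrt(sum(ui * ui for ui in u))
    ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    new_w = [wi - lr * ratio * ui for wi, ui in zip(w, u)]
    return new_w, m, v


w, m, v = lamb_step([3.0, 4.0], [0.5, 0.5], [0.0, 0.0], [0.0, 0.0], t=1, lr=0.1)
```

Because the trust ratio rescales the direction to the layer's weight norm, the size of each layer's step is \(\eta_t \| w_{t-1} \|\) whenever both norms are nonzero, independent of the raw gradient magnitude. This is the layer-wise correction the description refers to.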

Note

There is usually no connection between an optimizer and mixed precision. However, when FixedLossScaleManager is used and drop_overflow_update in FixedLossScaleManager is set to False, the optimizer needs to set ‘loss_scale’. Because this optimizer has no loss_scale argument, loss_scale must be handled by other means. Refer to the LossScale document to process loss_scale correctly.
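As an illustration only, assume the gradients reaching the optimizer are the true gradients multiplied by a fixed loss scale. One way to compensate is then to divide them by that scale before they are passed to Lamb. The sketch below shows that arithmetic on plain Python lists. `unscale_grads` is a hypothetical helper, not a MindSpore API; see the LossScale document for the approaches MindSpore actually supports.

```python
from typing import List, Sequence


def unscale_grads(grads: Sequence[Sequence[float]], loss_scale: float) -> List[List[float]]:
    """Divide every gradient element by a fixed loss scale (hypothetical helper)."""
    if loss_scale <= 0:
        raise ValueError("loss_scale must be positive")
    inv = 1.0 / loss_scale
    # One inner list per parameter, mirroring the tuple of gradients the optimizer receives
    return [[gi * inv for gi in g] for g in grads]


grads = unscale_grads([[1024.0, 2048.0], [512.0]], loss_scale=1024.0)
```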

If parameters are not grouped, the weight_decay in the optimizer is applied to the network parameters whose names do not contain ‘beta’ or ‘gamma’. Users can group parameters to change the weight decay strategy. When parameters are grouped, each group can set its own weight_decay; if a group does not, the weight_decay in the optimizer is applied.
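The two weight decay strategies can be sketched in plain Python, with parameter names standing in for Parameter objects. The helpers `ungrouped_decay` and `grouped_decay` are hypothetical and model only the rules stated above: a substring check on ‘beta’/‘gamma’ in the ungrouped case, and a per-group value falling back to the optimizer-level weight_decay in the grouped case. The sketch assumes that an ‘order_params’ entry carries no weight decay setting of its own.

```python
from typing import Dict, List, Sequence


def ungrouped_decay(names: Sequence[str], weight_decay: float) -> Dict[str, float]:
    """Weight decay per parameter name when params are not grouped (hypothetical helper)."""
    # Parameters whose names contain 'beta' or 'gamma' (e.g. normalization layers) get no decay
    return {n: (weight_decay if 'beta' not in n and 'gamma' not in n else 0.0) for n in names}


def grouped_decay(groups: List[dict], weight_decay: float) -> Dict[str, float]:
    """Weight decay per parameter name when params are grouped (hypothetical helper)."""
    result: Dict[str, float] = {}
    for group in groups:
        if 'order_params' in group:
            continue  # order-only entry; no per-parameter settings
        # A group's own weight_decay wins; otherwise use the optimizer-level value
        wd = group.get('weight_decay', weight_decay)
        for name in group['params']:
            result[name] = wd
    return result


print(ungrouped_decay(['conv1.weight', 'bn1.gamma', 'bn1.beta', 'fc.bias'], 0.01))
```

Note that under the ungrouped rule, biases such as `fc.bias` are still decayed, because only the names are inspected.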

Parameters
  • params (Union[list[Parameter], list[dict]]) –

    Must be a list of Parameter or a list of dict. When params is a list of dict, the keys “params”, “lr”, “weight_decay”, “grad_centralization” and “order_params” can be parsed.

    • params: Required. Parameters in current group. The value must be a list of Parameter.

    • lr: Optional. If “lr” is in the keys, the corresponding learning rate is used. If not, the learning_rate in the optimizer is used. Fixed and dynamic learning rates are supported.

    • weight_decay: Optional. If “weight_decay” is in the keys, the corresponding weight decay value is used. If not, the weight_decay in the optimizer is used.

    • grad_centralization: Optional. Must be Boolean. If “grad_centralization” is in the keys, the given value is used. If not, grad_centralization is False by default. This configuration only takes effect on convolution layers.

    • order_params: Optional. When parameters are grouped, this is usually used to maintain the order in which parameters appear in the network, to improve performance. The value should be the parameters whose order the optimizer will follow. If “order_params” is in the keys, the other keys are ignored, and each element of ‘order_params’ must belong to one of the parameter groups in params.

  • learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]) –

    • float: The fixed learning rate value. Must be equal to or greater than 0.

    • int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

    • Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate is applied. For a vector, the learning rate is dynamic, and the i-th step takes the i-th value as the learning rate.

    • Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

    • LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the LearningRateSchedule instance with the step as input to get the learning rate of the current step.

  • beta1 (float) – The exponential decay rate for the 1st moment estimations. Default: 0.9. Should be in range (0.0, 1.0).

  • beta2 (float) – The exponential decay rate for the 2nd moment estimations. Default: 0.999. Should be in range (0.0, 1.0).

  • eps (float) – Term added to the denominator to improve numerical stability. Default: 1e-6. Should be greater than 0.

  • weight_decay (Union[float, int]) – Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.
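The accepted learning_rate forms, together with the per-group “lr” override described under params, can be modelled with a small pure-Python sketch. Here a number stands for a fixed rate, a list stands for a 1-D Tensor or a sequence of per-step values, and any callable stands for a LearningRateSchedule. The helpers `lr_at_step` and `group_lr_at_step` are hypothetical. Raising IndexError past the end of a list is simply what this sketch does, not documented MindSpore behavior.

```python
from typing import Any, Callable, Dict, Sequence, Union

# Stand-ins: number -> fixed rate, sequence -> per-step values, callable -> schedule
LearningRate = Union[float, int, Sequence[float], Callable[[int], float]]


def lr_at_step(learning_rate: LearningRate, step: int) -> float:
    """Learning rate used at a 0-based step (hypothetical helper)."""
    if isinstance(learning_rate, (int, float)):
        if learning_rate < 0:
            raise ValueError("learning rate must be equal to or greater than 0")
        return float(learning_rate)          # fixed; ints are converted to float
    if callable(learning_rate):
        return float(learning_rate(step))    # schedule called with the current step
    return float(learning_rate[step])        # dynamic: the i-th step takes the i-th value


def group_lr_at_step(group: Dict[str, Any], learning_rate: LearningRate, step: int) -> float:
    """A group's own 'lr' overrides the optimizer-level learning_rate (hypothetical helper)."""
    return lr_at_step(group.get('lr', learning_rate), step)


print(lr_at_step([0.1, 0.05, 0.01], 2))
```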

Inputs:
  • gradients (tuple[Tensor]) - The gradients of params, with the same shapes as params.

Outputs:

tuple[bool], all elements are True.

Raises
  • TypeError – If learning_rate is not one of int, float, Tensor, Iterable, LearningRateSchedule.

  • TypeError – If an element of parameters is neither Parameter nor dict.

  • TypeError – If beta1, beta2 or eps is not a float.

  • TypeError – If weight_decay is neither float nor int.

  • ValueError – If eps is less than or equal to 0.

  • ValueError – If beta1 or beta2 is not in range (0.0, 1.0).

  • ValueError – If weight_decay is less than 0.
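The hyperparameter checks listed above can be mirrored in a short pure-Python sketch. This is useful, for example, to validate a configuration before constructing the optimizer. `check_lamb_args` is a hypothetical helper covering only beta1, beta2, eps and weight_decay. Its order of checks and its error messages are assumptions, not MindSpore's.

```python
def check_lamb_args(beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0) -> None:
    """Raise TypeError/ValueError following the Raises list above (hypothetical helper)."""
    # beta1, beta2 and eps must be floats (an int such as 1 is rejected)
    for name, value in (('beta1', beta1), ('beta2', beta2), ('eps', eps)):
        if not isinstance(value, float):
            raise TypeError(f"{name} must be a float, got {type(value).__name__}")
    if not isinstance(weight_decay, (float, int)):
        raise TypeError("weight_decay must be a float or an int")
    if eps <= 0:
        raise ValueError("eps must be greater than 0")
    # Both decay rates must lie strictly between 0 and 1
    for name, value in (('beta1', beta1), ('beta2', beta2)):
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} must be in range (0.0, 1.0)")
    if weight_decay < 0:
        raise ValueError("weight_decay must be equal to or greater than 0")


check_lamb_args(beta1=0.9, beta2=0.98, eps=1e-8, weight_decay=0.01)
```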

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore import nn, Model
>>>
>>> # Net is a user-defined network (a Cell subclass); it is not defined in this snippet.
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
>>>
>>> #2) Use parameter groups and set different values
>>> poly_decay_lr = nn.PolynomialDecayLR(learning_rate=0.1, end_learning_rate=0.01,
...                                      decay_steps=4, power=0.5)
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
...                 {'params': no_conv_params, 'lr': poly_decay_lr},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The parameters in conv_params use the default learning rate of 0.1, a weight decay of 0.01,
>>> # and gradient centralization.
>>> # The parameters in no_conv_params use the polynomial decay learning rate, the default
>>> # weight decay of 0.0, and no gradient centralization.
>>> # The optimizer follows the parameter order given by 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
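For reference, the learning rates that the polynomial decay schedule in the example produces for no_conv_params can be tabulated in plain Python. The sketch assumes the formula documented for PolynomialDecayLR with update_decay_steps=False: the step is clamped to decay_steps, and the rate is \((lr - lr_{end}) (1 - step / decay\_steps)^{power} + lr_{end}\). `poly_decay_lr` is a hypothetical helper, not the MindSpore class.

```python
def poly_decay_lr(step: int, learning_rate: float = 0.1, end_learning_rate: float = 0.01,
                  decay_steps: int = 4, power: float = 0.5) -> float:
    """Polynomial decay with the step clamped to decay_steps (hypothetical helper)."""
    tmp_step = min(step, decay_steps)  # after decay_steps the rate stays at end_learning_rate
    return ((learning_rate - end_learning_rate)
            * (1 - tmp_step / decay_steps) ** power + end_learning_rate)


for step in range(6):
    print(step, poly_decay_lr(step))
```

With the example's values, the rate starts at 0.1, falls to 0.055 at step 3, and reaches 0.01 at step 4, where it then stays.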