mindspore.nn.SGD

class mindspore.nn.SGD(params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, loss_scale=1.0)

Implements stochastic gradient descent. Momentum is optional.

An introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent. Nesterov momentum is based on the formula from the paper “On the importance of initialization and momentum in deep learning”.

\[v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)\]

If nesterov is True:

\[p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})\]

If nesterov is False:

\[p_{t+1} = p_{t} - lr \ast v_{t+1}\]

Note that for the first step, \(v_{t+1} = gradient\).

Here p, v and u denote the parameters, the accumulator (accum), and the momentum, respectively, and lr is the learning rate.
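
The update rule above can be traced by hand with a few lines of plain Python. This is a minimal sketch that follows only the formulas on this page for a single scalar parameter; the helper name sgd_step and its first_step flag are hypothetical, and it is not the MindSpore implementation (weight decay and loss scaling are left out).

```python
def sgd_step(p, grad, v, lr=0.1, momentum=0.0, dampening=0.0,
             nesterov=False, first_step=False):
    """Apply one update from the formulas above to a scalar parameter.

    Hypothetical helper for illustration only.
    Returns the new (parameter, accumulator) pair.
    """
    if first_step:
        # On the first step the accumulator is the raw gradient.
        v_next = grad
    else:
        # v_{t+1} = u * v_t + gradient * (1 - dampening)
        v_next = momentum * v + grad * (1.0 - dampening)
    if nesterov:
        # p_{t+1} = p_t - lr * (gradient + u * v_{t+1})
        p_next = p - lr * (grad + momentum * v_next)
    else:
        # p_{t+1} = p_t - lr * v_{t+1}
        p_next = p - lr * v_next
    return p_next, v_next


# Two steps on f(p) = p ** 2, whose gradient is 2 * p.
p0, v0 = 1.0, 0.0
p1, v1 = sgd_step(p0, 2 * p0, v0, lr=0.1, momentum=0.9, first_step=True)
p2, v2 = sgd_step(p1, 2 * p1, v1, lr=0.1, momentum=0.9)
```

With momentum left at 0.0 and dampening at 0.0, the accumulator equals the gradient on every step, so the sketch reduces to plain gradient descent.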

Parameters
  • params (Union[list[Parameter], list[dict]]) –

    Must be a list of Parameter or a list of dict. When params is a list of dict, the keys “params”, “lr”, “grad_centralization” and “order_params” can be parsed.

    • params: Required. Parameters in current group. The value must be a list of Parameter.

    • lr: Optional. If “lr” is in the keys, the corresponding learning rate value will be used. If not, the learning_rate of the optimizer will be used. Both fixed and dynamic learning rates are supported.

    • weight_decay: Using different weight_decay by grouping parameters is currently not supported.

    • grad_centralization: Optional. Must be Boolean. If “grad_centralization” is in the keys, the set value will be used. If not, grad_centralization defaults to False. This setting only takes effect on convolution layers.

    • order_params: Optional. When parameters are grouped, this is usually used to keep the order in which the parameters appear in the network, which improves performance. The value should be the parameters whose order the optimizer will follow. If “order_params” is in the keys, other keys will be ignored, and each element of “order_params” must be in one group of params.

  • learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]) –

    Default: 0.1.

    • float: The fixed learning rate value. Must be equal to or greater than 0.

    • int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

    • Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate is applied. For a vector, the learning rate is dynamic: the i-th step takes the i-th value as the learning rate.

    • Iterable: The learning rate is dynamic. The i-th step takes the i-th value as the learning rate.

    • LearningRateSchedule: The learning rate is dynamic. During training, the optimizer calls the LearningRateSchedule instance with the step as input to get the learning rate of the current step.

  • momentum (float) – A floating point value for the momentum. Must be at least 0.0. Default: 0.0.

  • dampening (float) – A floating point value of dampening for momentum. Must be at least 0.0. Default: 0.0.

  • weight_decay (float) – Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

  • nesterov (bool) – Enables Nesterov momentum. If nesterov is True, momentum must be positive and dampening must be equal to 0.0. Default: False.

  • loss_scale (float) – A floating point value for the loss scale, which must be larger than 0.0. In general, use the default value. This value needs to match the loss_scale in FixedLossScaleManager only when FixedLossScaleManager is used for training and its drop_overflow_update is set to False. Refer to class mindspore.FixedLossScaleManager for more details. Default: 1.0.
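
The learning_rate forms listed above can be read as a per-step lookup. The sketch below is a plain-Python illustration of that reading, not MindSpore code: the helper lr_at_step is hypothetical, a callable stands in for a LearningRateSchedule, and zero-based step counting is an assumption. Tensor inputs and the behavior once a sequence runs out of values are not covered.

```python
def lr_at_step(learning_rate, step):
    """Return the learning rate for a given step (hypothetical helper).

    Assumes steps are counted from zero.
    """
    if callable(learning_rate):
        # Stand-in for a LearningRateSchedule: called with the step.
        value = float(learning_rate(step))
    elif isinstance(learning_rate, (int, float)):
        # Fixed learning rate; an int is converted to float.
        value = float(learning_rate)
    else:
        # Iterable: the i-th step takes the i-th value.
        value = float(list(learning_rate)[step])
    if value < 0:
        raise ValueError("learning rate must be equal to or greater than 0")
    return value


fixed = lr_at_step(1, step=5)                         # int converted to float
from_list = lr_at_step([0.1, 0.05, 0.01], step=1)     # second entry of the list
from_schedule = lr_at_step(lambda s: 0.1 / (s + 1), step=1)
```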

Inputs:
  • gradients (tuple[Tensor]) - The gradients of params, with the same shape as params.

Outputs:

Tensor[bool], the value is True.

Raises

ValueError – If the momentum, dampening or weight_decay value is less than 0.0.
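
The constraints stated on this page can be checked up front, before an optimizer is built. The following is a minimal sketch in plain Python; check_sgd_args is a hypothetical helper, and raising ValueError for the loss_scale and nesterov constraints is an assumption of the sketch, since this page only documents ValueError for negative momentum, dampening or weight_decay.

```python
def check_sgd_args(momentum=0.0, dampening=0.0, weight_decay=0.0,
                   nesterov=False, loss_scale=1.0):
    """Validate SGD hyperparameters against the documented constraints.

    Hypothetical helper for illustration only.
    """
    for name, value in (("momentum", momentum),
                        ("dampening", dampening),
                        ("weight_decay", weight_decay)):
        # Documented: a ValueError is raised if any of these is below 0.0.
        if value < 0.0:
            raise ValueError(f"{name} must be at least 0.0, got {value}")
    # Assumption: the remaining constraints are also reported as ValueError.
    if loss_scale <= 0.0:
        raise ValueError(f"loss_scale must be larger than 0.0, got {loss_scale}")
    if nesterov and (momentum <= 0.0 or dampening != 0.0):
        raise ValueError("nesterov requires momentum > 0.0 and dampening == 0.0")


check_sgd_args(momentum=0.9, nesterov=True)  # passes silently
```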

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> from mindspore import nn
>>>
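>>> # Net is assumed to be a user-defined network (an nn.Cell subclass); it is not defined in this example.
>>> net = Net()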
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.SGD(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'grad_centralization': True},
...                 {'params': no_conv_params, 'lr': 0.01},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The parameters in conv_params use the default learning rate of 0.1, the default weight decay of 0.0,
>>> # and grad centralization of True.
>>> # The parameters in no_conv_params use a learning rate of 0.01, the default weight decay of 0.0, and grad
>>> # centralization of False.
>>> # The optimizer follows the parameter order given by the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)