mindspore.experimental.optim.Optimizer

class mindspore.experimental.optim.Optimizer(params, defaults)

Base class for all optimizers.

Warning

This is an experimental optimizer API that is subject to change. This module must be used together with the learning rate scheduler module, i.e. the LRScheduler classes.

Parameters
  • params (Union[list(Parameter), list(dict)]) – an iterable of mindspore.Parameter or of dict. Specifies which Tensors should be optimized; dict entries define parameter groups, as sketched after this list.

  • defaults (dict) – a dict containing default values of optimization options (used when a parameter group doesn’t specify them).
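When params is a list of dicts, each dict defines a parameter group whose keys override defaults for that group only. A minimal sketch, assuming the built-in optim.SGD and the standard "params"/"lr" group keys (the two-group split is illustrative):

>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>>
>>> net = nn.Dense(8, 2)
>>> # The bias group overrides the default lr of 0.01; the weight group inherits it.
>>> group_params = [{"params": [net.weight]},
...                 {"params": [net.bias], "lr": 0.1}]
>>> optimizer = optim.SGD(group_params, lr=0.01)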

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor, Parameter
>>> from mindspore import ops
>>> from mindspore.experimental import optim
>>>
>>> class MySGD(optim.Optimizer):
...     def __init__(self, params, lr):
...         defaults = dict(lr=lr)
...         super(MySGD, self).__init__(params, defaults)
...
...     def construct(self, gradients):
...         for group_id, group in enumerate(self.param_groups):
...             # Index of this group's first parameter in the flat gradients tuple.
...             start_id = self.group_start_id[group_id]
...             for i, param in enumerate(group["params"]):
...                 # Vanilla SGD update: step against the gradient direction.
...                 next_param = param - gradients[start_id + i] * group["lr"]
...                 ops.assign(param, next_param)
>>>
>>> net = nn.Dense(8, 2)
>>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
>>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
>>>
>>> optimizer = MySGD(net.trainable_params(), 0.01)
>>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
>>>
>>> criterion = nn.MAELoss(reduction="mean")
>>>
>>> def forward_fn(data, label):
...    logits = net(data)
...    loss = criterion(logits, label)
...    return loss, logits
>>>
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
>>>
>>> def train_step(data, label):
...    (loss, _), grads = grad_fn(data, label)
...    optimizer(grads)
...    print(loss)
>>>
>>> train_step(data, label)
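As the warning above notes, this API is designed to pair with the schedulers in mindspore.experimental.optim.lr_scheduler. A minimal sketch continuing the session above, assuming StepLR is available in your MindSpore version:

>>> scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
>>> for epoch in range(4):
...     train_step(data, label)
...     scheduler.step()  # halves each group's lr every 2 epochs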
add_param_group(param_group)

Add a param group to the Optimizer.param_groups.

Parameters

param_group (dict) – Specifies which Parameters should be optimized, along with group-specific optimization options.
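For example, continuing the Examples session above, a new group can carry its own options while missing keys fall back to defaults (the extra Parameter and the lr value are illustrative):

>>> extra = Parameter(Tensor([0.03, 0.04], mindspore.float32), name="extra")
>>> optimizer.add_param_group({"params": [extra], "lr": 0.1})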