mindspore.experimental.optim.lr_scheduler.PolynomialLR

class mindspore.experimental.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=1.0, last_epoch=-1)

At each epoch, the learning rate is decayed according to a polynomial schedule. Once the epoch reaches total_iters, the learning rate is 0. Note that this decay can happen simultaneously with other changes made to the learning rate from outside this scheduler.

The polynomial formula for learning rate calculation is as follows:

\[\begin{split} &factor = \left(\frac{1.0 - \frac{last\_epoch}{total\_iters}}{1.0 - \frac{last\_epoch - 1.0}{total\_iters}}\right)^{power}\\ &lr = lr \times factor \end{split}\]
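Applying this factor once per step, from last_epoch = 1 up to last_epoch = t, the intermediate terms cancel (the product telescopes), which gives a closed form for the learning rate. Here lr_0 (the initial learning rate) and t (the number of steps taken) are symbols introduced for this derivation, which assumes nothing outside the scheduler changes the learning rate:

\[\begin{split} lr_t &= lr_0 \prod_{k=1}^{t} \left(\frac{1.0 - \frac{k}{total\_iters}}{1.0 - \frac{k - 1.0}{total\_iters}}\right)^{power}\\ &= lr_0 \left(1.0 - \frac{t}{total\_iters}\right)^{power}, \quad 0 \le t \le total\_iters \end{split}\]

At t = total_iters the bracket is 0, so the learning rate reaches 0, consistent with the description above.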

Warning

This is an experimental lr scheduler module that is subject to change. This module must be used with the optimizers in Experimental Optimizer.

Parameters
  • optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.

  • total_iters (int, optional) – The number of iterations over which the learning rate is decayed polynomially. Default: 5.

  • power (float, optional) – The power of the polynomial. Default: 1.0.

  • last_epoch (int, optional) – The index of the last epoch. Default: -1.
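Because the per-step factors in the formula above multiply out (telescope), after t ≤ total_iters steps the learning rate equals base_lr × (1 - t/total_iters)^power, assuming nothing outside the scheduler changes it. The following plain-Python sketch (not MindSpore's implementation; `polynomial_lr` and `base_lr` are hypothetical names) uses that closed form to show how `power` shapes the decay for a base learning rate of 0.01 and the default total_iters=5.

```python
def polynomial_lr(base_lr: float, epoch: int, total_iters: int = 5, power: float = 1.0) -> float:
    """Closed-form learning rate after `epoch` scheduler steps (hypothetical helper).

    Assumes nothing outside the scheduler modifies the learning rate.
    Epochs past total_iters are clamped, so the rate stays at 0 there.
    """
    t = min(epoch, total_iters)
    return base_lr * (1.0 - t / total_iters) ** power


# Compare decay shapes for a few powers over epochs 0..6.
for power in (0.5, 1.0, 2.0):
    schedule = [round(polynomial_lr(0.01, e, power=power), 6) for e in range(7)]
    print(f"power={power}: {schedule}")
```

With power=1.0 the decay is linear; power > 1 drops the rate faster in the early epochs, while power < 1 keeps it higher for longer before it falls to 0 at total_iters.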

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.fc = nn.Dense(16 * 5 * 5, 120)
...     def construct(self, x):
...         return self.fc(x)
>>> net = Net()
>>> optimizer = optim.Adam(net.trainable_params(), 0.01)
>>> scheduler = optim.lr_scheduler.PolynomialLR(optimizer)
>>> for i in range(6):
...     scheduler.step()
...     current_lr = scheduler.get_last_lr()
...     print(current_lr)
[Tensor(shape=[], dtype=Float32, value= 0.008)]
[Tensor(shape=[], dtype=Float32, value= 0.006)]
[Tensor(shape=[], dtype=Float32, value= 0.004)]
[Tensor(shape=[], dtype=Float32, value= 0.002)]
[Tensor(shape=[], dtype=Float32, value= 0)]
[Tensor(shape=[], dtype=Float32, value= 0)]
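The printed values can be reproduced without MindSpore by applying the per-step factor from the formula above to the starting learning rate of 0.01. This is a plain-Python sketch, not MindSpore's implementation: `step_factor` is a hypothetical helper, and returning a factor of 1 when last_epoch is 0 or exceeds total_iters is an assumption. Some such guard is needed, because at last_epoch = total_iters + 1 the denominator of the formula would be 0.

```python
def step_factor(last_epoch: int, total_iters: int, power: float) -> float:
    """Per-step multiplicative factor from the documented formula (hypothetical helper).

    Assumption: the factor is 1 at last_epoch == 0 and once last_epoch exceeds
    total_iters, which also avoids dividing by zero past the end of the decay.
    """
    if last_epoch == 0 or last_epoch > total_iters:
        return 1.0
    num = 1.0 - last_epoch / total_iters
    den = 1.0 - (last_epoch - 1.0) / total_iters
    return (num / den) ** power


lr, total_iters, power = 0.01, 5, 1.0
history = []
for last_epoch in range(1, 7):  # mirrors the six scheduler.step() calls
    lr *= step_factor(last_epoch, total_iters, power)
    history.append(round(lr, 6))  # round away float noise for display
print(history)  # [0.008, 0.006, 0.004, 0.002, 0.0, 0.0]
```

The rounding only hides floating-point noise; MindSpore reports Float32 tensors, so its unrounded values may differ in the last digits.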