mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts
class mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1)

Set the learning rate of each parameter group using a cosine annealing warm restarts schedule:

\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)\]

where \(\eta_{max}\) is set to the initial lr, \(\eta_{min}\) is the minimum learning rate, \(\eta_{t}\) is the current learning rate, \(T_{0}\) is the number of iterations for the first restart, \(T_{i}\) is the current number of iterations between two warm restarts in SGDR, and \(T_{cur}\) is the number of epochs since the last restart in SGDR.

When \(T_{cur} = T_{i}\), set \(\eta_t = \eta_{min}\). When \(T_{cur} = 0\) after a restart, set \(\eta_t = \eta_{max}\). A numeric check of this formula is sketched after the Raises list below.

For more details, please refer to: SGDR: Stochastic Gradient Descent with Warm Restarts.

Warning

This is an experimental lr scheduler module that is subject to change. This module must be used with optimizers in Experimental Optimizer.

Parameters
- optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.
- T_0 (int) – Number of iterations for the first restart. 
- T_mult (int, optional) – A factor by which \(T_{i}\) increases after a restart. Default: 1.
- eta_min (Union(float, int), optional) – Minimum learning rate. Default: 0.
- last_epoch (int, optional) – The number of times the step() method of the current learning rate adjustment strategy has been executed. Default: -1.
 
Raises
- ValueError – If T_0 is less than or equal to 0, or T_0 is not an int.
- ValueError – If T_mult is less than 1, or T_mult is not an int.
- ValueError – If eta_min is not an int or a float.
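As a quick sanity check of the formula above, the schedule can be reproduced in plain Python. The snippet below is a minimal sketch, not part of the MindSpore API; with \(\eta_{max} = 0.1\) (the initial lr), \(\eta_{min} = 0\), and \(T_{i} = T_{0} = 2\), it reproduces (up to float32 printing) the learning rates shown in the Examples section below.

>>> import math
>>> def eta(t_cur, t_i, eta_min=0.0, eta_max=0.1):
...     # eta_t = eta_min + 1/2 * (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i))
...     return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))
>>> [round(eta(t, 2), 7) for t in (0, 1 / 3, 2 / 3, 1, 4 / 3, 5 / 3)]
[0.1, 0.0933013, 0.075, 0.05, 0.025, 0.0066987]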
 
Supported Platforms:
Ascend GPU CPU
Examples

>>> from mindspore.experimental import optim
>>> from mindspore import nn
>>> net = nn.Dense(3, 2)
>>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
>>> scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 2)
>>> iters = 3
>>> for epoch in range(2):
...     for i in range(iters):
...         scheduler.step(epoch + i / iters)
...         current_lr = scheduler.get_last_lr()
...         print(current_lr)
[Tensor(shape=[], dtype=Float32, value= 0.1)]
[Tensor(shape=[], dtype=Float32, value= 0.0933013)]
[Tensor(shape=[], dtype=Float32, value= 0.075)]
[Tensor(shape=[], dtype=Float32, value= 0.05)]
[Tensor(shape=[], dtype=Float32, value= 0.025)]
[Tensor(shape=[], dtype=Float32, value= 0.00669873)]
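When T_mult is greater than 1, each annealing cycle is longer than the previous one: after every restart, \(T_{i}\) is multiplied by T_mult, as the T_mult parameter describes. The following is a minimal sketch in plain Python, not part of the MindSpore API, using hypothetical values T_0 = 1 and T_mult = 2 to show where the restarts fall:

>>> T_0, T_mult = 1, 2
>>> t_i, epoch, restarts = T_0, 0, []
>>> for _ in range(4):
...     epoch += t_i            # run one full cycle of length T_i
...     restarts.append(epoch)  # eta_t jumps back to eta_max here
...     t_i *= T_mult           # the next cycle is T_mult times longer
>>> restarts
[1, 3, 7, 15]

With the default T_mult = 1, as in the Examples above, a restart instead occurs every T_0 epochs.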