mindspore.mint.optim.AdamW
- class mindspore.mint.optim.AdamW(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize=False)
Implements the Adam Weight Decay (AdamW) algorithm.
\[
\begin{array}{l}
\hline
\textbf{input}: \gamma \text{ (lr)},\ \beta_1, \beta_2 \text{ (betas)},\ \theta_0 \text{ (params)},\ f(\theta) \text{ (objective)},\ \epsilon \text{ (epsilon)}, \\
\hspace{13mm} \lambda \text{ (weight decay)},\ \textit{amsgrad},\ \textit{maximize} \\
\textbf{initialize}: m_0 \leftarrow 0 \text{ (first moment)},\ v_0 \leftarrow 0 \text{ (second moment)},\ \widehat{v_0}^{max} \leftarrow 0 \\
\hline
\textbf{for}\ t = 1\ \textbf{to}\ \ldots\ \textbf{do} \\
\hspace{6mm} \textbf{if}\ \textit{maximize}: \\
\hspace{11mm} g_t \leftarrow -\nabla_{\theta} f_t(\theta_{t-1}) \\
\hspace{6mm} \textbf{else} \\
\hspace{11mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1}) \\
\hspace{6mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
\hspace{6mm} m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
\hspace{6mm} v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hspace{6mm} \widehat{m_t} \leftarrow m_t / (1 - \beta_1^t) \\
\hspace{6mm} \widehat{v_t} \leftarrow v_t / (1 - \beta_2^t) \\
\hspace{6mm} \textbf{if}\ \textit{amsgrad}: \\
\hspace{11mm} \widehat{v_t}^{max} \leftarrow \max(\widehat{v_{t-1}}^{max}, \widehat{v_t}) \\
\hspace{11mm} \theta_t \leftarrow \theta_t - \gamma \widehat{m_t} / \big(\sqrt{\widehat{v_t}^{max}} + \epsilon\big) \\
\hspace{6mm} \textbf{else} \\
\hspace{11mm} \theta_t \leftarrow \theta_t - \gamma \widehat{m_t} / \big(\sqrt{\widehat{v_t}} + \epsilon\big) \\
\hline
\textbf{return}\ \theta_t \\
\hline
\end{array}
\]
More details can be found in the papers Decoupled Weight Decay Regularization (AdamW) and On the Convergence of Adam and Beyond (AMSGrad).
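One iteration of the loop above can be traced by hand with a short pure-Python sketch. It works on plain lists of floats rather than MindSpore tensors, and `init_state` and `adamw_step` are hypothetical helpers written for illustration, not part of the `mindspore.mint` API. The sketch transcribes the formula line by line, including the `maximize` and `amsgrad` branches, and then runs it on the toy objective f(θ) = θ².

```python
import math


def init_state(n):
    """Hypothetical helper: zero-initialised m_0, v_0 and running max of v_hat."""
    return {"m": [0.0] * n, "v": [0.0] * n, "v_max": [0.0] * n}


def adamw_step(theta, grad, state, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=1e-2, amsgrad=False, maximize=False):
    """One AdamW iteration (step t >= 1) on a list of scalar parameters.

    Illustrative sketch of the formula above, not MindSpore's implementation.
    `state` is updated in place; the new parameter values are returned.
    """
    beta1, beta2 = betas
    new_theta = []
    for i, (p, g) in enumerate(zip(theta, grad)):
        if maximize:
            g = -g  # ascend the objective instead of descending it
        # Decoupled weight decay: theta_t <- theta_{t-1} - lr * lambda * theta_{t-1}
        p = p - lr * weight_decay * p
        # Biased first and second moment estimates
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g
        # Bias-corrected estimates
        m_hat = state["m"][i] / (1 - beta1 ** t)
        v_hat = state["v"][i] / (1 - beta2 ** t)
        if amsgrad:
            # Keep the running maximum of v_hat and use it in the denominator
            state["v_max"][i] = max(state["v_max"][i], v_hat)
            v_hat = state["v_max"][i]
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_theta


# Usage: minimise f(theta) = theta**2 starting from theta = 1.0.
theta, state = [1.0], init_state(1)
for t in range(1, 201):
    grad = [2.0 * p for p in theta]  # gradient of theta**2
    theta = adamw_step(theta, grad, state, t, lr=0.1)
print(theta)
```

Setting `amsgrad=True` changes only the denominator: it uses the running maximum of the bias-corrected second moment, which never decreases from one step to the next.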
Warning
This is an experimental optimizer API that is subject to change. This module must be used with the learning rate scheduler module (LRScheduler class).
On Ascend, it is only supported on Atlas A2 and later platforms.
- Parameters
params (Union[list(Parameter), list(dict)]) – List of parameters to optimize, or dicts defining parameter groups.
lr (float, optional) – Learning rate. Default: 1e-3.
betas (Tuple[float, float], optional) – The exponential decay rates for the first and second moment estimates. Default: (0.9, 0.999).
eps (float, optional) – Term added to the denominator to improve numerical stability. Must be greater than 0. Default: 1e-8.
weight_decay (float, optional) – Weight decay coefficient λ, applied directly to the parameters rather than added to the gradient as an L2 penalty. Default: 1e-2.
amsgrad (bool, optional) – Whether to use the AMSGrad variant of the algorithm. Default: False.
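Because the formula above subtracts γλθ from the parameters instead of adding λθ to the gradient, `weight_decay` behaves differently from a classic L2 penalty. The pure-Python sketch below compares the first step of both variants for a parameter whose loss gradient is zero. `first_step` is a hypothetical helper, and the `decoupled=False` branch (Adam with an L2 penalty) is shown only for comparison; it is not an option of this optimizer.

```python
import math


def first_step(p, g, lr=0.1, betas=(0.9, 0.999), eps=1e-8, wd=1e-2, decoupled=True):
    """First optimizer step (t = 1, m_0 = v_0 = 0) on one scalar parameter.

    decoupled=True follows the AdamW formula; decoupled=False is classic
    Adam with an L2 penalty, included only as a point of comparison.
    """
    beta1, beta2 = betas
    if decoupled:
        p = p - lr * wd * p   # AdamW: shrink the weight directly
    else:
        g = g + wd * p        # Adam + L2: add the penalty's gradient to g
    m = (1 - beta1) * g       # m_1 = beta1 * 0 + (1 - beta1) * g
    v = (1 - beta2) * g * g   # v_1 = beta2 * 0 + (1 - beta2) * g**2
    m_hat = m / (1 - beta1)   # bias correction with t = 1
    v_hat = v / (1 - beta2)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


# A weight whose loss gradient is zero, e.g. one that does not affect the loss.
print(first_step(1.0, 0.0))                   # about 0.999
print(first_step(1.0, 0.0, decoupled=False))  # about 0.9
```

With decoupled decay the weight shrinks by exactly the factor 1 - γλ, independent of the adaptive denominator. With an L2 penalty, the penalty gradient is normalized by √v̂, so the weight moves by roughly γ even though λ is small.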
- Keyword Arguments
maximize (bool, optional) – Maximize the objective with respect to the params instead of minimizing it. Default: False.
- Inputs:
gradients (tuple[Tensor]) – The gradients of params.
- Raises
ValueError – If lr is not a float.
ValueError – If lr is less than 0.
ValueError – If eps is less than 0.
ValueError – If any element of betas is not in the range [0, 1).
ValueError – If weight_decay is less than 0.
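The conditions listed under Raises amount to a small set of argument checks. The function below is a hypothetical pure-Python restatement of that list, assuming the checks are made when the optimizer is constructed; it is not MindSpore's actual validation code, and its error messages are invented for illustration.

```python
def check_adamw_args(lr, betas, eps, weight_decay):
    """Hypothetical sketch of the checks listed under Raises (not MindSpore code)."""
    if not isinstance(lr, float):
        raise ValueError(f"lr must be a float, got {type(lr).__name__}")
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid eps value: {eps}")
    for i, beta in enumerate(betas):
        # Each decay rate must lie in the half-open interval [0, 1).
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta at index {i}: {beta}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")


# The documented defaults pass; an integer learning rate would not.
check_adamw_args(1e-3, (0.9, 0.999), 1e-8, 1e-2)
```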
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import nn
>>> from mindspore import mint
>>> from mindspore.mint import optim
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
>>> def forward_fn(data, label):
...     logits = net(data)
...     loss = loss_fn(logits, label)
...     return loss, logits
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
>>> def train_step(data, label):
...     (loss, _), grads = grad_fn(data, label)
...     optimizer(grads)
...     return loss