The updating formulas are as follows:

$\begin{split}\begin{array}{ll} m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\ v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\ var = var - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon} \end{array}\end{split}$

$$t$$ represents the update step; $$m$$ represents the 1st moment vector, and $$m_{t}$$ is the previous value of $$m_{t+1}$$; $$v$$ represents the 2nd moment vector, and $$v_{t}$$ is the previous value of $$v_{t+1}$$; $$l$$ represents the learning rate lr; $$g$$ represents grad; $$\beta_1, \beta_2$$ represent beta1 and beta2; $$\beta_1^{t+1}$$ represents beta1_power; $$var$$ represents the variable to be updated; $$\epsilon$$ represents epsilon.
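To make the update rule concrete, here is a minimal NumPy sketch of a single step. It assumes plain NumPy arrays in place of MindSpore Parameters, and `ada_max_step` is a hypothetical name, not part of the MindSpore API. Fed the inputs used in the Examples section, it reproduces the printed var, m and v values to float32 precision.

```python
import numpy as np


def ada_max_step(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad):
    """One AdaMax step following the formulas; returns the new (var, m, v).

    beta1_power is beta1 ** (t + 1) and is supplied by the caller.
    """
    m_new = beta1 * m + (1 - beta1) * grad          # 1st moment update
    v_new = np.maximum(beta2 * v, np.abs(grad))     # decayed max of |grad|
    var_new = var - lr / (1 - beta1_power) * m_new / (v_new + epsilon)
    return var_new, m_new, v_new


var = np.array([[0.6, 0.4], [0.1, 0.5]], dtype=np.float32)
m = np.array([[0.6, 0.5], [0.2, 0.6]], dtype=np.float32)
v = np.array([[0.9, 0.1], [0.7, 0.8]], dtype=np.float32)
grad = np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32)

var, m, v = ada_max_step(var, m, v, beta1_power=0.9, lr=0.001,
                         beta1=0.9, beta2=0.99, epsilon=1e-10, grad=grad)
# var is now roughly [[0.593603, 0.392571], [0.097258, 0.49225]]
```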

The inputs var, m, v and grad follow the implicit type conversion rules so that their data types are consistent. If their data types differ, the lower-priority data type is converted to the higher-priority data type.
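The promotion rule can be illustrated with NumPy, used here only as an analogy rather than as MindSpore's own rule table. The assumption is that MindSpore ranks float32 above float16 in the same way, so a float16 input mixed with float32 inputs ends up as float32.

```python
import numpy as np

# NumPy analogy only: float16 mixed with float32 is promoted to float32.
half = np.ones((2, 2), dtype=np.float16)
single = np.ones((2, 2), dtype=np.float32)

promoted = np.result_type(half, single)
print(promoted)  # float32

total = half + single  # the float16 operand is converted before adding
print(total.dtype)  # float32
```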

Inputs:
• var (Parameter) - Variable to be updated, with float32 or float16 data type. The shape is $$(N, *)$$, where $$*$$ means any number of additional dimensions.

• m (Parameter) - The 1st moment vector in the updating formula, with the same shape and type as var. With float32 or float16 data type.

• v (Parameter) - The 2nd moment vector in the updating formula, which tracks the exponentially decayed maximum of absolute gradients, with the same shape and type as var. With float32 or float16 data type.

• beta1_power (Union[Number, Tensor]) - $$\beta_1^{t+1}$$ in the updating formula, must be a scalar. With float32 or float16 data type.

• lr (Union[Number, Tensor]) - Learning rate, $$l$$ in the updating formula, must be a scalar. With float32 or float16 data type.

• beta1 (Union[Number, Tensor]) - The exponential decay rate for the 1st moment estimations, must be a scalar. With float32 or float16 data type.

• beta2 (Union[Number, Tensor]) - The exponential decay rate for the 2nd moment estimations, must be a scalar. With float32 or float16 data type.

• epsilon (Union[Number, Tensor]) - A small value added for numerical stability, must be a scalar. With float32 or float16 data type.

• grad (Tensor) - The gradient tensor, with the same shape and type as var. With float32 or float16 data type.
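Because beta1_power is an input and is not among the outputs, the caller has to update it every step so that it equals $$\beta_1^{t+1}$$. The loop below sketches that bookkeeping in NumPy. It assumes m and v start at zero and uses a hypothetical constant gradient; `ada_max_step` is an illustrative helper, not a MindSpore function.

```python
import numpy as np


def ada_max_step(var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad):
    # First moment, decayed max of |grad|, then the bias-corrected step.
    m = beta1 * m + (1 - beta1) * grad
    v = np.maximum(beta2 * v, np.abs(grad))
    var = var - lr / (1 - beta1_power) * m / (v + epsilon)
    return var, m, v


lr, beta1, beta2, epsilon = 0.001, 0.9, 0.99, 1e-10
var = np.array([0.5, -0.5], dtype=np.float32)
m = np.zeros_like(var)  # assumed zero-initialized moments
v = np.zeros_like(var)

beta1_power = 1.0
for step in range(3):
    grad = np.array([0.2, -0.4], dtype=np.float32)  # hypothetical gradient
    beta1_power *= beta1  # now beta1 ** (step + 1)
    var, m, v = ada_max_step(var, m, v, beta1_power, lr, beta1, beta2,
                             epsilon, grad)
# each step moves var by about lr: var ends near [0.497, -0.497]
```

With a constant gradient and zero-initialized moments, $$m_{t+1} = (1 - \beta_1^{t+1}) g$$ and $$v_{t+1} = |g|$$, so the bias correction cancels exactly and every step moves each element of var by about lr against the sign of its gradient.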

Outputs:

A tuple of 3 Tensors: the updated parameters.

• var (Tensor) - The same shape and data type as var.

• m (Tensor) - The same shape and data type as m.

• v (Tensor) - The same shape and data type as v.

Raises:
• TypeError – If dtype of var, m, v, beta1_power, lr, beta1, beta2, epsilon or grad is neither float16 nor float32.

• TypeError – If beta1_power, lr, beta1, beta2 or epsilon is neither a Number nor a Tensor.

• TypeError – If grad is not a Tensor.

• RuntimeError – If the implicit data type conversion of the Parameters var, m and v (and of grad) is not supported.
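The TypeError cases in the list above can be sketched as a plain-Python validator. This is not MindSpore's implementation: `validate_ada_max_inputs` is a hypothetical helper, NumPy arrays stand in for Tensors and Parameters, and Python numbers are assumed to be accepted as-is for the scalar hyperparameters. The RuntimeError case is not modeled.

```python
import numbers

import numpy as np

# Hypothetical helper: NumPy arrays stand in for MindSpore Tensors.
_FLOAT_TYPES = (np.float16, np.float32)


def validate_ada_max_inputs(var, m, v, grad, **scalars):
    """Raise TypeError in the situations the Raises list describes."""
    # grad must be a Tensor (here: an ndarray).
    if not isinstance(grad, np.ndarray):
        raise TypeError("grad must be a Tensor")
    # var, m, v and grad must be float16 or float32.
    for name, arr in (("var", var), ("m", m), ("v", v), ("grad", grad)):
        if arr.dtype not in _FLOAT_TYPES:
            raise TypeError(f"{name} must be float16 or float32, got {arr.dtype}")
    # beta1_power, lr, beta1, beta2, epsilon: a Number or a float16/float32 Tensor.
    for name, value in scalars.items():
        if isinstance(value, np.ndarray):
            if value.dtype not in _FLOAT_TYPES:
                raise TypeError(f"{name} must be float16 or float32, got {value.dtype}")
        elif not isinstance(value, numbers.Number):
            raise TypeError(f"{name} must be a Number or a Tensor")


ok = np.zeros((2, 2), dtype=np.float32)
validate_ada_max_inputs(ok, ok, ok, ok, beta1_power=0.9, lr=0.001,
                        beta1=0.9, beta2=0.99, epsilon=1e-10)  # passes silently
try:
    validate_ada_max_inputs(ok, ok, ok, ok, lr="0.001")
except TypeError as err:
    print(err)  # lr must be a Number or a Tensor
```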

Supported Platforms:

Ascend

Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         # primitive that applies one AdaMax update step
...         self.apply_ada_max = ops.ApplyAdaMax()
...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
...         self.m = Parameter(Tensor(np.array([[0.6, 0.5],
...                                             [0.2, 0.6]]).astype(np.float32)), name="m")
...         self.v = Parameter(Tensor(np.array([[0.9, 0.1],
...                                             [0.7, 0.8]]).astype(np.float32)), name="v")
...     def construct(self, beta1_power, lr, beta1, beta2, epsilon, grad):
...         # returns the updated (var, m, v) as a tuple
...         out = self.apply_ada_max(self.var, self.m, self.v, beta1_power, lr, beta1, beta2, epsilon, grad)
...         return out
...
>>> net = Net()
>>> beta1_power = Tensor(0.9, mindspore.float32)
>>> lr = Tensor(0.001, mindspore.float32)
>>> beta1 = Tensor(0.9, mindspore.float32)
>>> beta2 = Tensor(0.99, mindspore.float32)
>>> epsilon = Tensor(1e-10, mindspore.float32)
>>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
>>> output = net(beta1_power, lr, beta1, beta2, epsilon, grad)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 5.93602717e-01,  3.92571449e-01],
[ 9.72582996e-02,  4.92249995e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 5.69999993e-01,  5.19999981e-01],
[ 1.89999998e-01,  6.20000005e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 8.90999973e-01,  6.99999988e-01],
[ 6.93000019e-01,  8.00000012e-01]]))