mindspore.nn.AdaSumByDeltaWeightWrapCell
- class mindspore.nn.AdaSumByDeltaWeightWrapCell(optimizer)
Enables AdaSum in "auto_parallel" or "semi_auto_parallel" mode. This implementation of the Adaptive Summation (AdaSum) algorithm is computed from the difference between the weights before and after the optimizer update. See the paper AdaSum: Scaling Distributed Training with Adaptive Summation.
\[
\begin{aligned}
w_{t+1} &= w_{t} - \alpha \cdot \mathrm{Adasum}(g_{1}, g_{2}) \\
w_{t+1} &= w_{t} - \alpha \cdot \left[\left(1 - \frac{g_2^{T} \cdot g_1}{2 \cdot \lVert g_1 \rVert^2}\right) \cdot g_1 + \left(1 - \frac{g_1^{T} \cdot g_2}{2 \cdot \lVert g_2 \rVert^2}\right) \cdot g_2\right]
\end{aligned}
\]
In this implementation, \(g\) represents the weight difference before and after the optimizer update, and the subscripts denote different devices in the data-parallel dimension.
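The two-device rule can be checked numerically with a small pure-Python sketch. It assumes plain SGD as a stand-in for the wrapped optimizer; `adasum_pair`, `sgd_step` and `delta_weight_adasum_step` are hypothetical helpers, not MindSpore APIs. Each device's delta is taken as old weights minus new weights. Because the AdaSum coefficients depend only on dot products and squared norms, negating both deltas negates the combined result, so the opposite sign convention produces the same final weights.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def adasum_pair(g1: Vector, g2: Vector) -> Vector:
    """Combine two vectors with the two-device AdaSum rule."""
    n1, n2, d = dot(g1, g1), dot(g2, g2), dot(g1, g2)
    # If a vector is all zeros its own term vanishes, so any coefficient works;
    # 1.0 is used only to avoid dividing by zero.
    c1 = 1.0 - d / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - d / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


def sgd_step(w: Vector, grad: Vector, lr: float) -> Vector:
    """Hypothetical local optimizer (plain SGD) standing in for the wrapped optimizer."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def delta_weight_adasum_step(w: Vector, grad1: Vector, grad2: Vector, lr: float) -> Vector:
    """Simulate one step on two data-parallel devices that start from the same weights w."""
    # Each device runs its own optimizer update on its local gradient.
    delta1 = [a - b for a, b in zip(w, sgd_step(w, grad1, lr))]
    delta2 = [a - b for a, b in zip(w, sgd_step(w, grad2, lr))]
    # The weight differences, not the raw gradients, are combined with AdaSum.
    # The learning rate is already inside each delta, so alpha is effectively 1 here.
    combined = adasum_pair(delta1, delta2)
    return [wi - ci for wi, ci in zip(w, combined)]
```

With orthogonal local gradients, for example `delta_weight_adasum_step([1.0, 1.0], [1.0, 0.0], [0.0, 1.0], 0.1)`, the deltas are simply added and the weights move to roughly `[0.9, 0.9]`. With identical gradients the combined delta equals a single device's delta, so the step matches one device rather than doubling as plain summation would.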
Warning
This interface is deprecated and will be removed after version 2.9.0.
Note
When using AdaSum, the number of training devices must be a power of 2 and at least 16. Optimizer sharding and pipeline parallel are currently not supported with AdaSum. It is recommended to use AdaSumByDeltaWeightWrapCell in semi auto parallel or auto parallel mode. In data parallel mode, we recommend using mindspore.boost to apply AdaSum.
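One plausible way to read the power-of-2 requirement is that the two-device rule is applied recursively, pairing devices level by level in a binary tree so that the vector count halves each time. The single-process sketch below illustrates that assumption only; it is not MindSpore's communication schedule, and `adasum_tree` and `check_device_num` are hypothetical names.

```python
from typing import List

Vector = List[float]


def adasum_pair(g1: Vector, g2: Vector) -> Vector:
    """Two-device AdaSum rule, matching the formula for w_{t+1}."""
    n1 = sum(x * x for x in g1)
    n2 = sum(x * x for x in g2)
    d = sum(x * y for x, y in zip(g1, g2))
    c1 = 1.0 - d / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - d / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


def check_device_num(device_num: int, minimum: int = 16) -> None:
    """Hypothetical check mirroring the Note: a power of 2 and at least `minimum`."""
    is_power_of_two = device_num > 0 and device_num & (device_num - 1) == 0
    if not is_power_of_two or device_num < minimum:
        raise RuntimeError(
            f"device_num must be a power of 2 and >= {minimum}, got {device_num}")


def adasum_tree(per_device: List[Vector], minimum: int = 16) -> Vector:
    """Reduce one vector per device by combining neighbours pairwise, level by level."""
    check_device_num(len(per_device), minimum)
    level = per_device
    while len(level) > 1:
        # Each level halves the count, which only divides evenly for 2**k inputs.
        level = [adasum_pair(level[i], level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]
```

The `minimum` parameter exists only so the sketch can be tried on toy sizes, for example `adasum_tree(vectors, minimum=2)`; with the default it enforces the same 16-device floor as the Note.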
- Parameters:
optimizer (Union[Cell]) – Optimizer for updating the weights. The construct function of the optimizer requires only one input.
- Inputs:
grads (Tuple(Tensor)) - Tuple of gradients, the same as the input of the wrapped optimizer.
- Raises:
RuntimeError – If parallel_mode is stand_alone, since AdaSum is only supported in distributed scenarios.
RuntimeError – If optimizer parallel is enabled when using AdaSum.
RuntimeError – If pipeline parallel is enabled when using AdaSum.
RuntimeError – If device_num is not a power of 2, or is less than 16.
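The conditions in the Raises list can be summarized as a pre-flight check. The sketch below is a hypothetical helper, not a MindSpore function: its keyword names are chosen to echo the `parallel_mode`, `device_num`, `enable_parallel_optimizer` and `pipeline_stages` settings of `mindspore.set_auto_parallel_context`, and the order in which MindSpore itself performs these checks is not specified here.

```python
def validate_adasum_context(parallel_mode: str, device_num: int,
                            enable_parallel_optimizer: bool = False,
                            pipeline_stages: int = 1) -> None:
    """Hypothetical check that raises RuntimeError for each listed condition."""
    if parallel_mode == "stand_alone":
        raise RuntimeError("AdaSum is only supported in distributed scenarios.")
    if enable_parallel_optimizer:
        raise RuntimeError("Optimizer parallel is not supported with AdaSum.")
    if pipeline_stages > 1:
        raise RuntimeError("Pipeline parallel is not supported with AdaSum.")
    # Checking the lower bound first also rules out 0 and negative counts.
    if device_num < 16 or device_num & (device_num - 1) != 0:
        raise RuntimeError("device_num must be a power of 2 and at least 16.")
```

For instance, `validate_adasum_context("semi_auto_parallel", 16)` passes silently, while `validate_adasum_context("semi_auto_parallel", 24)` raises because 24 is not a power of 2.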
- Supported Platforms:
Ascend GPU
Examples
>>> import mindspore as ms
>>> from mindspore import nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://atomgit.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> optim = nn.AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
...                                                    learning_rate=0.1, momentum=0.9))
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)