mindspore.nn.AdaSumByGradWrapCell
- class mindspore.nn.AdaSumByGradWrapCell(optimizer)
Enables AdaSum in "auto_parallel" and "semi_auto_parallel" modes. This implementation of the Adaptive Summation (AdaSum) algorithm computes AdaSum on the gradients. See the paper AdaSum: Scaling Distributed Training with Adaptive Summation.
\[
\begin{array}{ll}
w_{t+1} = w_{t} - \alpha \cdot \mathrm{AdaSum}(g_{1}, g_{2}) \\
w_{t+1} = w_{t} - \alpha \cdot \left[\left(1 - \frac{g_2^{T} \cdot g_1}{2 \cdot \left\| g_1 \right\|^2}\right) \cdot g_1 + \left(1 - \frac{g_1^{T} \cdot g_2}{2 \cdot \left\| g_2 \right\|^2}\right) \cdot g_2\right]
\end{array}
\]

In this implementation, \(g\) represents the gradient of the weights, the subscripts denote different devices in the data-parallel dimension, and \(\alpha\) is the learning rate.
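To make the two-device formula concrete, here is a minimal pure-Python sketch of the pairwise AdaSum rule and the resulting weight update. Each gradient is treated as a flat list of floats. `adasum_pair` and `adasum_sgd_step` are hypothetical helper names, not MindSpore APIs, and giving an all-zero gradient a coefficient of 1 is an assumption added only to avoid dividing by zero.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product a^T . b of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def adasum_pair(g1: Vector, g2: Vector) -> Vector:
    """Combine two gradients with the AdaSum rule.

    Computes (1 - g2.g1 / (2|g1|^2)) * g1 + (1 - g1.g2 / (2|g2|^2)) * g2.
    Assumption: an all-zero gradient gets coefficient 1 (its term is zero
    either way), which avoids a division by zero.
    """
    proj = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    c1 = 1.0 - proj / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - proj / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


def adasum_sgd_step(w: Vector, g1: Vector, g2: Vector, alpha: float) -> Vector:
    """One update: w_{t+1} = w_t - alpha * AdaSum(g1, g2)."""
    update = adasum_pair(g1, g2)
    return [wi - alpha * ui for wi, ui in zip(w, update)]


# Orthogonal gradients are summed; identical gradients are averaged.
print(adasum_pair([1.0, 0.0], [0.0, 1.0]))  # [1.0, 1.0]
print(adasum_pair([2.0, 0.0], [2.0, 0.0]))  # [2.0, 0.0]
```

The two printed cases show the adaptive behavior the formula encodes: when the devices' gradients are orthogonal the projection terms vanish and the result is a plain sum, and when they are identical each coefficient becomes 1/2, so the result is their average.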
Warning
This interface is deprecated and will be removed after version 2.9.0.
Note
It is recommended to use AdaSumByGradWrapCell in semi-auto-parallel or auto-parallel mode. In data parallel mode, use mindspore.boost to apply AdaSum instead.
When using AdaSum, the number of training devices must be a power of 2 and no fewer than 16. Optimizer sharding (optimizer parallel) and pipeline parallelism are currently not supported with AdaSum.
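The AdaSum paper extends the pairwise rule to many workers by applying it recursively: the gradients are split into two halves, each half is reduced with AdaSum, and the two results are combined with the pairwise rule. A balanced binary tree of this kind fits naturally with a power-of-2 worker count, which is one plausible reading of the constraint above. The sketch below is a single-process simulation under that assumption; it does not describe how MindSpore actually schedules its communication, and `adasum_tree` and `is_power_of_two` are hypothetical names.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def adasum_pair(g1: Vector, g2: Vector) -> Vector:
    """Pairwise AdaSum rule (an all-zero gradient gets coefficient 1)."""
    proj = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    c1 = 1.0 - proj / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - proj / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of 2 iff exactly one bit is set.
    return n > 0 and (n & (n - 1)) == 0


def adasum_tree(grads: List[Vector]) -> Vector:
    """Reduce one gradient per simulated device over a balanced binary tree."""
    if not is_power_of_two(len(grads)):
        raise ValueError("number of devices must be a power of 2")
    if len(grads) == 1:
        return list(grads[0])
    half = len(grads) // 2
    return adasum_pair(adasum_tree(grads[:half]), adasum_tree(grads[half:]))


# 16 simulated devices: 8 produce [1, 0] and 8 produce [0, 1].
device_grads = [[1.0, 0.0]] * 8 + [[0.0, 1.0]] * 8
print(adasum_tree(device_grads))  # [1.0, 1.0]
```

Identical gradients collapse to their average at every level of the tree, while the two orthogonal groups are summed at the root.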
- Parameters:
optimizer (Union[Cell]) – Optimizer for updating the weights. Its construct function must take exactly one input.
- Inputs:
grads (tuple[Tensor]) – Tuple of gradients, the same as the input of the wrapped optimizer.
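The single-input contract can be illustrated without MindSpore: the wrapped optimizer is called with one argument, the tuple of gradients (one entry per trainable parameter), and the wrapper exposes that same signature while handing the optimizer an AdaSum-combined tuple of the same shape. The sketch below simulates two devices in one process and applies the pairwise rule tensor by tensor; that per-tensor granularity is an assumption for this sketch. `PlainSGD`, `TwoDeviceAdaSumWrapper`, and `fetch_peer_grads` are hypothetical stand-ins for an optimizer such as `nn.Momentum`, for `AdaSumByGradWrapCell`, and for the real cross-device communication.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Grads = Tuple[Vector, ...]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def adasum_pair(g1: Vector, g2: Vector) -> Vector:
    """Pairwise AdaSum rule (an all-zero gradient gets coefficient 1)."""
    proj = dot(g1, g2)
    n1, n2 = dot(g1, g1), dot(g2, g2)
    c1 = 1.0 - proj / (2.0 * n1) if n1 > 0 else 1.0
    c2 = 1.0 - proj / (2.0 * n2) if n2 > 0 else 1.0
    return [c1 * a + c2 * b for a, b in zip(g1, g2)]


class PlainSGD:
    """Hypothetical optimizer whose call takes exactly one input: grads."""

    def __init__(self, params: List[Vector], learning_rate: float):
        self.params = params  # updated in place
        self.learning_rate = learning_rate

    def __call__(self, grads: Grads) -> None:
        for param, grad in zip(self.params, grads):
            for i, g in enumerate(grad):
                param[i] -= self.learning_rate * g


class TwoDeviceAdaSumWrapper:
    """Hypothetical wrapper with the same single-input signature as its optimizer."""

    def __init__(self, optimizer: PlainSGD, fetch_peer_grads: Callable[[], Grads]):
        self.optimizer = optimizer
        self.fetch_peer_grads = fetch_peer_grads  # simulates communication

    def __call__(self, grads: Grads) -> None:
        peer_grads = self.fetch_peer_grads()
        # Combine local and peer gradients tensor by tensor, keeping the tuple shape.
        combined = tuple(adasum_pair(g, p) for g, p in zip(grads, peer_grads))
        self.optimizer(combined)


params = [[1.0, 1.0], [0.5]]
optim = TwoDeviceAdaSumWrapper(PlainSGD(params, learning_rate=0.1),
                               fetch_peer_grads=lambda: ([0.0, 1.0], [2.0]))
optim(([1.0, 0.0], [2.0]))  # local gradients: one entry per parameter
print([[round(x, 6) for x in p] for p in params])  # [[0.9, 0.9], [0.3]]
```

Because the wrapper and the optimizer accept the same argument, the wrapper can be passed wherever the plain optimizer was expected, which is how the example at the end of this page uses it.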
- Raises:
RuntimeError – If parallel_mode is stand_alone; AdaSum is only supported in distributed scenarios.
RuntimeError – If optimizer parallel is enabled together with AdaSum.
RuntimeError – If pipeline parallel is enabled together with AdaSum.
RuntimeError – If device_num is not a power of 2, or less than 16.
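The four error conditions above can be summarized as a single configuration check. The sketch below is a hypothetical helper, `validate_adasum_config`, written only to mirror the listed rules; its argument names are illustrative and it is not how MindSpore itself performs these checks.

```python
def validate_adasum_config(parallel_mode: str, device_num: int,
                           optimizer_parallel: bool, pipeline_stages: int) -> None:
    """Raise RuntimeError for each unsupported setting listed under Raises."""
    if parallel_mode == "stand_alone":
        raise RuntimeError("AdaSum is only supported in distributed scenarios")
    if optimizer_parallel:
        raise RuntimeError("AdaSum does not support optimizer parallel")
    if pipeline_stages > 1:
        raise RuntimeError("AdaSum does not support pipeline parallel")
    # Power of 2 <=> exactly one bit set; AdaSum also needs at least 16 devices.
    if device_num < 16 or (device_num & (device_num - 1)) != 0:
        raise RuntimeError("device_num must be a power of 2 and at least 16")


for mode, n in [("semi_auto_parallel", 16), ("auto_parallel", 12)]:
    try:
        validate_adasum_config(mode, n, optimizer_parallel=False, pipeline_stages=1)
        print(mode, n, "ok")
    except RuntimeError as err:
        print(mode, n, "rejected:", err)
```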
- Supported Platforms:
Ascend GPU
Examples
>>> import mindspore as ms
>>> from mindspore import nn
>>> # Run inside a distributed job in "auto_parallel" or "semi_auto_parallel" mode,
>>> # with a power-of-2 number of devices (at least 16).
>>> # Define the network structure of LeNet5. Refer to
>>> # https://atomgit.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> optim = nn.AdaSumByGradWrapCell(nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9))
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)