mindspore.nn.GradAccumulationCell

class mindspore.nn.GradAccumulationCell(network, micro_size)

Wraps the network with micro-batch splitting to enable gradient accumulation in semi_auto_parallel or auto_parallel mode.

Parameters
  • network (Cell) – The target network to wrap.

  • micro_size (int) – MicroBatch size, i.e. the number of micro-batches each input batch is split into.
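The arithmetic behind gradient accumulation can be sketched as follows, assuming the wrapper splits each input batch along its first dimension into `micro_size` micro-batches, runs the network on each, and sums the outputs. Because the gradient of a sum is the sum of the gradients, the accumulated gradient matches the full-batch gradient. This is a pure-Python illustration on a hypothetical toy model `y = w * x`, not MindSpore's implementation; the helpers `split_micro_batches`, `loss_and_grad`, and `accumulate` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def split_micro_batches(batch: Sequence, micro_size: int) -> List[list]:
    """Split a batch along its first dimension into `micro_size` equal slices (hypothetical helper)."""
    if micro_size <= 0:
        raise ValueError("micro_size must be positive")
    if len(batch) % micro_size != 0:
        raise ValueError("batch size must be divisible by micro_size")
    step = len(batch) // micro_size
    return [list(batch[i * step:(i + 1) * step]) for i in range(micro_size)]


def loss_and_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Sum-of-squares loss of the toy model y = w * x and its gradient with respect to w."""
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys))
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))
    return loss, grad


def accumulate(w: float, xs: Sequence[float], ys: Sequence[float], micro_size: int) -> Tuple[float, float]:
    """Process each micro-batch in turn and sum the per-micro-batch losses and gradients."""
    total_loss, total_grad = 0.0, 0.0
    for mx, my in zip(split_micro_batches(xs, micro_size), split_micro_batches(ys, micro_size)):
        loss, grad = loss_and_grad(w, mx, my)
        total_loss += loss
        total_grad += grad
    return total_loss, total_grad


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [2.0 * x + 1.0 for x in xs]
    # One pass over the full batch versus four accumulated micro-batches of size 2.
    print(loss_and_grad(0.5, xs, ys))
    print(accumulate(0.5, xs, ys, micro_size=4))
```

Both calls yield the same loss and gradient (up to floating-point rounding), which is why accumulating over micro-batches can stand in for a single larger batch when memory is limited.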

Supported Platforms:

Ascend GPU

Examples

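>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> # GradAccumulationCell is intended for semi_auto_parallel/auto_parallel mode.
>>> ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> # Split each input batch into 4 micro-batches and accumulate gradients across them.
>>> net = nn.GradAccumulationCell(net, 4)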