mindspore.nn.PipelineCell

class mindspore.nn.PipelineCell(network, micro_size, stage_config=None)[源代码]

将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。

说明

参数:
  • network (Cell) - 要修饰的目标网络。

  • micro_size (int) - MicroBatch大小。

  • stage_config (dict,可选) - 流水线并行对于每个cell的stage配置。默认值: None

支持平台:

Ascend GPU

样例:

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.7.0rc1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = nn.PipelineCell(net, 4)