mindspore.parallel.distributed.DistributedDataParallel
- class mindspore.parallel.distributed.DistributedDataParallel(module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None, find_unused_parameters=False, average_in_collective: bool = False, static_graph=False, reducer_mode='CppReducer')[源代码]
分布式数据并行封装类。该类为梯度分配连续显存,各参数的梯度将被分入多个桶,该桶是在数据并行域执行 all-reduce 通信以实现通信掩盖的基本单元。
警告
该方法当前仅支持在PyNative模式下使用。
- 参数:
module (nn.Cell) - 需要进行分布式梯度规约的网络。
init_sync (bool,可选) - 初始化时,是否进行rank0网络参数广播同步。默认值:
True
。process_group (str,可选) - 梯度规约通信组。默认行为是全局同步。默认值:
None
。bucket_cap_mb (int,可选) - 分桶梯度规约的桶大小,单位为MB。不填写时默认采用25MB。默认值:
None
。find_unused_parameters (bool,可选) - 是否搜索未使用参数。默认值:
False
。average_in_collective (bool,可选) - 是否在通信后求平均,True时先做AllReduce SUM后scale dp size,否则先做scaling后规约。默认值:
False
。static_graph (bool,可选) - 指明是否是静态网络。当是静态网络时,将忽略参数 find_unused_parameters,并在第一个step搜索未使用参数,在第二个step前按执行顺序进行桶重建,以实现更好的性能收益。默认值:
False
。reducer_mode (str,可选) - 后端梯度规约模式,
"CppReducer"
表示采用CPP后端,"PythonReducer"
表示采用Python后端。默认值:"CppReducer"
。
- 返回:
被DistributedDataParallel类封装的nn.Cell网络,网络将自动完成反向梯度规约。
- 支持平台:
Ascend
样例:
说明
当前接口不支持GPU、CPU版本的MindSpore调用。
使能重计算、梯度冻结时,必须在最外层使用 DistributedDataParallel 类进行封装。
在运行以下示例之前,您需要配置通信环境变量。针对Ascend设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动 。
>>> from mindspore.parallel.distributed import DistributedDataParallel >>> from mindspore.mint.optim import AdamW >>> from mindspore import Parameter, Tensor, ops, nn >>> import mindspore as ms >>> from mindspore.communication import init >>> from mindspore.mint.distributed.distributed import init_process_group >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> init_process_group() >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> net = DistributedDataParallel(module=net, ... bucket_cap_mb=None, ... average_in_collective=True, ... static_graph=True) >>> optimizer = AdamW(net.trainable_params(), 1e-4) >>> loss_fn = nn.CrossEntropyLoss() >>> >>> def forward_fn(data, target): ... logits = net(data) ... loss = loss_fn(logits, target) ... return loss, logits >>> >>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True) >>> >>> # Create the dataset taking MNIST as an example. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py >>> dataset = create_dataset() >>> for epoch in range(1): ... step = 0 ... for image, label in dataset: ... (loss_value, _), grads = grad_fn(image, label) ... optimizer(grads) ... net.zero_grad() ... step += 1 ... print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))