mindspore.parallel.distributed.DistributedDataParallel
- class mindspore.parallel.distributed.DistributedDataParallel(module, init_sync=True, process_group=None, bucket_cap_mb: Optional[int] = None, find_unused_parameters=False, average_in_collective: bool = False, static_graph=False, reducer_mode='CppReducer')[source]
DistributedDataParallel wrapper. DistributedDataParallel allocates a contiguous memory buffer for gradients. Parameter gradients are combined into multiple buckets, and each bucket is the unit of all-reduce communication within the data parallel group, allowing communication to overlap with computation and hide its latency.
Warning
This API is currently supported only in PyNative mode.
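The bucketing described above can be illustrated with a minimal sketch. This is an assumption, not the actual reducer code: build_buckets is a made-up helper, and the real reducer packs gradient tensors into a contiguous device buffer rather than NumPy arrays. Parameters are packed into buckets by cumulative size, and each full bucket is reduced with a single collective call, so reducing earlier buckets can overlap with computing the gradients of later ones.

import numpy as np

def build_buckets(params, bucket_cap_bytes=25 * 1024 * 1024):
    """Group parameters into buckets whose total byte size stays under the cap."""
    buckets, current, current_bytes = [], [], 0
    for p in params:
        nbytes = p.size * p.itemsize  # numpy-like parameter
        if current and current_bytes + nbytes > bucket_cap_bytes:
            buckets.append(current)  # bucket full: it becomes one all-reduce unit
            current, current_bytes = [], 0
        current.append(p)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

# 20 parameters of 4 MB each fall into buckets of [6, 6, 6, 2] under the 25 MB cap.
params = [np.zeros((1024, 1024), dtype=np.float32) for _ in range(20)]
print([len(b) for b in build_buckets(params)])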
- Parameters
module (nn.Cell) – the module to be wrapped with DDP.
init_sync (bool, optional) – whether to sync parameters from rank 0 of process_group at initialization. Default: True.
process_group (str, optional) – the communication group for data parallel. Default: None.
bucket_cap_mb (int, optional) – bucket size in MB; 25 MB is used if not set. Default: None.
find_unused_parameters (bool, optional) – whether to search for unused parameters in the bucket. Default: False.
average_in_collective (bool, optional) – if True, allreduce-sum the gradients within the data parallel group first and then scale by the group size; otherwise scale the local gradients first and then allreduce-sum them (see the sketch after this parameter list). Default: False.
static_graph (bool, optional) – whether the network is static. For a static network, find_unused_parameters is ignored; unused parameters are searched for during the first step, and buckets are rebuilt in execution order before the second step for better performance. Default: False.
reducer_mode (str, optional) – the backend to use: "CppReducer" for the C++ backend or "PythonReducer" for the Python backend. Default: "CppReducer".
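The two averaging orders controlled by average_in_collective can be written out as a minimal sketch. Here all_reduce_sum is a placeholder for the data parallel sum collective, not an actual MindSpore call; both branches produce the mean gradient over the group and differ only in when the division by the group size happens.

def average_gradient(local_grad, dp_size, all_reduce_sum, average_in_collective):
    """Both branches yield the mean gradient over the data parallel group."""
    if average_in_collective:
        # allreduce-sum first, then scale by the group size
        return all_reduce_sum(local_grad) / dp_size
    # scale the local gradient first, then allreduce-sum
    return all_reduce_sum(local_grad / dp_size)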
- Returns
Model wrapped with DistributedDataParallel.
- Supported Platforms:
Ascend
Examples
Note
The current API does not support the GPU or CPU versions of MindSpore.
When enabling recomputation or gradient freezing, the model should be wrapped by DistributedDataParallel at the outermost layer.
Before running the following examples, you need to configure the communication environment variables. For Ascend devices, it is recommended to use the msrun startup method without any third-party or configuration file dependencies. For detailed information, refer to msrun launch.
>>> from mindspore.parallel.distributed import DistributedDataParallel
>>> from mindspore.mint.optim import AdamW
>>> from mindspore import Parameter, Tensor, ops, nn
>>> import mindspore as ms
>>> from mindspore.communication import init
>>> from mindspore.mint.distributed.distributed import init_process_group
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> init_process_group()
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = DistributedDataParallel(module=net,
...                               bucket_cap_mb=None,
...                               average_in_collective=True,
...                               static_graph=True)
>>> optimizer = AdamW(net.trainable_params(), 1e-4)
>>> loss_fn = nn.CrossEntropyLoss()
>>>
>>> def forward_fn(data, target):
...     logits = net(data)
...     loss = loss_fn(logits, target)
...     return loss, logits
>>>
>>> grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
>>>
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> for epoch in range(1):
...     step = 0
...     for image, label in dataset:
...         (loss_value, _), grads = grad_fn(image, label)
...         optimizer(grads)
...         net.zero_grad()
...         step += 1
...         print("epoch: %s, step: %s, loss is %.15f" % (epoch, step, loss_value))
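For models with data-dependent control flow, some parameters may receive no gradient on a given step. Below is a hedged sketch of wrapping such a model with find_unused_parameters=True so the reducer does not wait for gradients that never arrive; it continues from the imports above, and BranchyNet is a made-up cell used only for illustration.

>>> # Hypothetical cell whose second branch is sometimes skipped, so fc2's
>>> # parameters get no gradient on those steps.
>>> class BranchyNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = nn.Dense(16, 16)
...         self.fc2 = nn.Dense(16, 16)  # unused when use_branch is False
...     def construct(self, x, use_branch):
...         out = self.fc1(x)
...         if use_branch:
...             out = self.fc2(out)
...         return out
>>> branchy = DistributedDataParallel(BranchyNet(), find_unused_parameters=True)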