mindspore.parameter_broadcast

mindspore.parameter_broadcast(net, layout, cur_rank=0, initial_rank=0)

Broadcast parameters to the other devices along the data-parallel dimension.

Warning

This is an experimental API and may be changed or removed in the future.
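
After a data-parallel broadcast, every replica of a parameter slice holds the same value. The pure-Python simulation below shows only that end state. It does not call MindSpore. The replica groups are hard-coded hypothetical values, and simulate_broadcast is an illustrative helper. Using the lowest rank in each group as the broadcast source is also an assumption, not documented MindSpore behavior.

```python
# Hypothetical illustration, not MindSpore code: after a data-parallel
# broadcast, every rank in a replica group holds the source rank's value.

def simulate_broadcast(values, replica_groups):
    """Return a copy of `values` (rank -> parameter value) in which every
    rank of each replica group holds its source rank's value.

    Assumption: the source of each group is its lowest rank.
    """
    result = dict(values)  # leave the input mapping untouched
    for group in replica_groups:
        src = min(group)
        for rank in group:
            result[rank] = values[src]
    return result


# Four ranks forming two data-parallel replica groups (hypothetical values).
before = {0: [0.1, 0.2], 1: [0.5, 0.6], 2: [0.9, 0.9], 3: [0.7, 0.8]}
after = simulate_broadcast(before, [[0, 2], [1, 3]])
print(after)  # {0: [0.1, 0.2], 1: [0.5, 0.6], 2: [0.1, 0.2], 3: [0.5, 0.6]}
```

Ranks 2 and 3 end up with the values of ranks 0 and 1. This is the consistency that broadcasting along the data-parallel dimension is meant to restore, for example after each rank has loaded its own checkpoint.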

Parameters:
  • net (Cell) - The network whose parameters will be broadcast.

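  • layout (Dict) - Parameter layout dictionary. It can be obtained from mindspore.nn.Cell.parameter_layout_dict or read from a file (for example, the "strategy.ckpt" file saved through the strategy_ckpt_config parameter of mindspore.set_auto_parallel_context()). Each key is a parameter name, and each value is that parameter's layout.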

  • cur_rank (int, optional) - Rank ID of the current device. Default: 0.

  • initial_rank (int, optional) - Rank ID of the first device in the current pipeline-parallel stage. Default: 0.
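
The following pure-Python sketch shows how layout, cur_rank and initial_rank fit together. It requires no MindSpore and computes which ranks hold identical copies of a parameter slice, which is the group a data-parallel broadcast runs over. It rests on several assumptions, all for illustration only:

- a layout entry provides a device matrix and a tensor map;
- each tensor-map index counts device-matrix dimensions from the right, and -1 means "not split";
- the ranks of a pipeline stage are numbered consecutively from initial_rank.

replica_groups is a hypothetical helper, not part of the MindSpore API.

```python
from itertools import product


def replica_groups(dev_matrix, tensor_map, initial_rank=0):
    """Group global ranks that hold the same slice of one parameter.

    Assumptions (illustration only, not MindSpore's implementation):
    - ranks of a pipeline stage are initial_rank + the row-major index
      of their coordinate in dev_matrix;
    - tensor_map entry k refers to device-matrix dimension
      len(dev_matrix) - 1 - k, and -1 means the tensor dim is not split.
    """
    ndim = len(dev_matrix)
    split_dims = sorted(ndim - 1 - k for k in tensor_map if k != -1)
    groups = {}
    # product() yields coordinates in row-major order.
    for coord in product(*(range(n) for n in dev_matrix)):
        index = 0
        for c, n in zip(coord, dev_matrix):
            index = index * n + c
        # Ranks that agree on every split dimension hold the same slice;
        # the remaining (unsplit) dimensions are data-parallel replicas.
        key = tuple(coord[d] for d in split_dims)
        groups.setdefault(key, []).append(initial_rank + index)
    return list(groups.values())


# Hypothetical layout: device matrix [2, 4], tensor dim 0 split 4 ways,
# tensor dim 1 unsplit; this stage is assumed to start at rank 8.
groups = replica_groups([2, 4], [0, -1], initial_rank=8)
print(groups)  # [[8, 12], [9, 13], [10, 14], [11, 15]]

cur_rank = 13
my_group = next(g for g in groups if cur_rank in g)
print(my_group)  # [9, 13]
```

Under these assumptions, initial_rank only shifts the rank numbers. The grouping itself depends solely on which device-matrix dimensions the tensor map leaves unused.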

Raises:
  • ValueError - cur_rank is not the rank ID of the current device.

  • ValueError - initial_rank is not the rank ID of the first device in the current pipeline stage.

  • ValueError - A parameter name in layout cannot be found in mindspore.nn.Cell.parameters_dict().
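
The three error conditions above can be restated as plain checks. The sketch below is a hypothetical helper, not MindSpore's implementation. actual_rank and stage_start_rank stand in for the current device's rank and the stage's first rank, however the real function determines them. The parameter names are illustrative.

```python
def check_broadcast_args(layout, param_names, cur_rank, actual_rank,
                         initial_rank, stage_start_rank):
    """Hypothetical restatement of the documented ValueError conditions."""
    if cur_rank != actual_rank:
        raise ValueError(f"cur_rank {cur_rank} is not the current rank {actual_rank}")
    if initial_rank != stage_start_rank:
        raise ValueError(
            f"initial_rank {initial_rank} is not the stage start rank {stage_start_rank}")
    # Every parameter named in the layout must exist in the network.
    missing = [name for name in layout if name not in param_names]
    if missing:
        raise ValueError(f"parameters not found in the network: {missing}")


# Illustrative parameter names; layout values are placeholders here.
names = {"fc1_weight", "fc2_weight", "fc3_weight"}
check_broadcast_args({"fc1_weight": None}, names, 3, 3, 0, 0)  # passes
try:
    check_broadcast_args({"fc4_weight": None}, names, 3, 3, 0, 0)
except ValueError as err:
    print(err)  # parameters not found in the network: ['fc4_weight']
```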

Examples:

>>> import os
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops
>>> from mindspore.communication import init
>>> from mindspore.common.initializer import initializer
>>> from mindspore.train import Model, Callback
>>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> ms.set_context(max_device_memory="28GB")
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
>>> init()
>>> ms.set_seed(1)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.flatten = ops.Flatten()
...         self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
...         self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
...         self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
...         self.matmul1 = ops.MatMul()
...         self.relu1 = ops.ReLU()
...         self.matmul2 = ops.MatMul()
...         self.relu2 = ops.ReLU()
...         self.matmul3 = ops.MatMul()
...     def construct(self, x):
...         x = self.flatten(x)
...         x = self.matmul1(x, self.fc1_weight)
...         x = self.relu1(x)
...         x = self.matmul2(x, self.fc2_weight)
...         x = self.relu2(x)
...         logits = self.matmul3(x, self.fc3_weight)
...         return logits
>>> net = Network()
>>> net.matmul1.shard(((2, 4), (4, 1)))
>>> net.relu1.shard(((4, 1),))
>>> net.matmul2.shard(((1, 8), (8, 1)))
>>> net.relu2.shard(((8, 1),))
>>> # Create the dataset, using MNIST as an example. For create_dataset(), refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
>>> loss = nn.CrossEntropyLoss()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset)
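>>> # integrated_save=False: each rank saves only its own parameter slices.
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
>>> # Parameter layouts of the compiled training network.
>>> layout = model.train_network.parameter_layout_dict
>>> param_dict = load_checkpoint("./simple.ckpt")
>>> load_param_into_net(net, param_dict)
>>> rank_id = os.environ["RANK_ID"]
>>> # Pipeline parallelism is not used here, so the stage starts at rank 0.
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)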
>>> class LossCallBack(Callback):
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
>>> model.train(1, dataset, callbacks=[LossCallBack()])