mindspore.parallel.sync_pipeline_shared_parameters

mindspore.parallel.sync_pipeline_shared_parameters(net)

Synchronizes shared weights across stages in pipeline parallel inference scenarios. For example, the embedding table is shared by the WordEmbedding layer and the LMHead layer, which are usually placed in different stages. Whenever the embedding table changes, the shared copies must be synchronized.
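As a conceptual illustration only, the single-process sketch below models two pipeline stages that each keep a private copy of the embedding table, and shows why a change on one stage leaves the other stale until the copies are synchronized. `Stage` and `sync_shared` are hypothetical names, not MindSpore APIs; the plain array copy stands in for the cross-device communication a real multi-device setup needs, and the sketch assumes the stage owning WordEmbedding holds the up-to-date table.

```python
import numpy as np


class Stage:
    """Hypothetical stand-in for one pipeline stage holding its own table copy."""

    def __init__(self, name, table):
        self.name = name
        # Each stage gets a private copy, as it would on a separate device.
        self.table = table.copy()


def sync_shared(src, dst):
    """Copy src's table into dst in place (stand-in for a cross-device sync)."""
    dst.table[...] = src.table


table = np.ones((4, 4), dtype=np.float32)
word_embedding_stage = Stage("stage0", table)  # assumed to own WordEmbedding
lm_head_stage = Stage("stage1", table)         # assumed to own LMHead

# The embedding table changes on stage 0 only (e.g. new weights are loaded).
word_embedding_stage.table *= 2.0
print(np.array_equal(word_embedding_stage.table, lm_head_stage.table))  # False

# After synchronization both stages see the same table again.
sync_shared(word_embedding_stage, lm_head_stage)
print(np.array_equal(word_embedding_stage.table, lm_head_stage.table))  # True
```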

Note

The network must be compiled before its shared parameters are synchronized across pipeline parallel stages.

Parameters

net (Cell) – the inference network.

Raises

TypeError – net is not of type Cell.

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For the Ascend device, users need to write a dynamic cluster startup script. For details, see Dynamic Cluster Startup.

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication.management as D
>>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Embedding(nn.Cell):
...     def __init__(self, shape):
...         super().__init__()
...         self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w')
...         self.matmul = ops.MatMul().shard(((1, 1), (1, 1)))
...     def construct(self, x):
...         return self.matmul(x, self.w), self.w
...
>>> class LMHead(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1)))
...     def construct(self, x, w):
...         return self.matmul(x, w)
...
>>> class Network(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super().__init__()
...         shape = (4, 4)
...         self.word_embedding = Embedding(shape)
...         self.lm_head = LMHead()
...         self.word_embedding.pipeline_stage = 0
...         self.lm_head.pipeline_stage = 1
...     def construct(self, x):
...         x, embed = self.word_embedding(x)
...         return self.lm_head(x, embed)
...
>>> class PipelineCellInference(nn.Cell):
...     def __init__(self, network, micro_batch_num):
...         super().__init__()
...         self.network = network
...         self.micro_batch_num = micro_batch_num
...         self.concat = ops.Concat()
...     def construct(self, x):
...         ret = ()
...         # Split the batch along axis 0 into equal-sized micro batches.
...         micro_batch_size = x.shape[0] // self.micro_batch_num
...         for i in range(self.micro_batch_num):
...             start = micro_batch_size * i
...             end = micro_batch_size * (i + 1)
...             micro_input = x[start:end]
...             y = self.network(micro_input)
...             ret = ret + (y,)
...         # Concatenate the micro-batch outputs back into a single batch.
...         ret = self.concat(ret)
...         return ret
...
>>> D.init()  # initialize the communication service
>>> net = Network()
>>> net = PipelineCellInference(net, 2)  # 2 micro batches per input batch
>>> net.set_train(False)
>>> x = Tensor(np.ones((2, 4)), ms.float32)
>>> net.compile(x)  # the network must be compiled before synchronization
>>> pp_net = AutoParallel(net, parallel_mode="semi_auto")
>>> pp_net.full_batch = True
>>> pp_net.pipeline(stages=2, scheduler="1f1b")  # 2 pipeline stages
>>> ms.parallel.sync_pipeline_shared_parameters(pp_net)
>>> print(pp_net.network.network.word_embedding.w.asnumpy())
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
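For readers without an Ascend cluster, the forward computation of the example above can be mirrored in plain NumPy to see what the shared table feeds into on each stage. This is a hypothetical single-process sketch, not MindSpore code: `network_forward` and `pipeline_inference` are illustrative names, and each stage is assumed to hold its own copy of the table `w`, so passing different copies models an unsynchronized pipeline.

```python
import numpy as np


def network_forward(x, w_stage0, w_stage1):
    # Stage 0 (Embedding): MatMul(x, w) using stage 0's copy of the table.
    h = x @ w_stage0
    # Stage 1 (LMHead): MatMul(transpose_b=True) computes h @ w.T,
    # using stage 1's copy of the same table.
    return h @ w_stage1.T


def pipeline_inference(x, w_stage0, w_stage1, micro_batch_num):
    # Mirrors PipelineCellInference: slice the batch along axis 0,
    # run each micro batch, then concatenate the outputs along axis 0.
    micro_batch_size = x.shape[0] // micro_batch_num
    outputs = []
    for i in range(micro_batch_num):
        micro_input = x[micro_batch_size * i:micro_batch_size * (i + 1)]
        outputs.append(network_forward(micro_input, w_stage0, w_stage1))
    return np.concatenate(outputs, axis=0)


x = np.ones((2, 4), dtype=np.float32)
w = np.ones((4, 4), dtype=np.float32)

# Both stages hold identical copies: every output element is 4 * 4 = 16.
synced = pipeline_inference(x, w, w, micro_batch_num=2)
print(synced.shape)  # (2, 4)
print(synced[0])     # [16. 16. 16. 16.]

# The table doubles on stage 0 only; stage 1 keeps the stale copy.
w_new = 2 * w
stale = pipeline_inference(x, w_new, w, micro_batch_num=2)
resynced = pipeline_inference(x, w_new, w_new, micro_batch_num=2)
print(stale[0, 0], resynced[0, 0])  # 32.0 64.0
```

The stale and resynchronized results differ by a factor of two, which is the kind of silent inconsistency that synchronizing shared parameters after the table changes is meant to prevent.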