mindspore.sync_pipeline_shared_parameters

mindspore.sync_pipeline_shared_parameters(net)[source]

Synchronize shared parameters across pipeline parallel stages. Parameters may be shared between different stages; for example, the embedding table is shared by the WordEmbedding layer and the LMHead layer, which are usually split into different stages. Synchronization is necessary after the embedding table changes.

Note

The network should be compiled before synchronizing the pipeline parallel stage shared parameters.

Parameters

net (nn.Cell) – The inference network.

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For the Ascend device, users need to write a dynamic cluster startup script; please see the Dynamic Cluster Startup.

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.communication.management as D
>>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Embedding(nn.Cell):
...     def __init__(self, shape):
...         super().__init__()
...         self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w')
...         self.matmul = ops.MatMul().shard(((1, 1), (1, 1)))
...     def construct(self, x):
...         return self.matmul(x, self.w), self.w
...
>>> class LMHead(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1)))
...     def construct(self, x, w):
...         return self.matmul(x, w)
...
>>> class Network(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super().__init__()
...         shape = (4, 4)
...         self.word_embedding = Embedding(shape)
...         self.lm_head = LMHead()
...         self.word_embedding.pipeline_stage = 0
...         self.lm_head.pipeline_stage = 1
...     def construct(self, x):
...         x, embed = self.word_embedding(x)
...         return self.lm_head(x, embed)
...
>>> class PipelineCellInference(nn.Cell):
...     def __init__(self, network, micro_batch_num):
...         super().__init__()
...         self.network = network
...         self.micro_batch_num = micro_batch_num
...         self.concat = ops.Concat()
...     def construct(self, x):
...         ret = ()
...         for i in range(self.micro_batch_num):
...             micro_batch_size = x.shape[0] // self.micro_batch_num
...             start = micro_batch_size * i
...             end = micro_batch_size * (i + 1)
...             micro_input = x[start:end]
...             y = self.network(micro_input)
...             ret = ret + (y,)
...         ret = self.concat(ret)
...         return ret
>>> D.init()
>>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
>>> net = Network()
>>> net = PipelineCellInference(net, 2)
>>> net.set_train(False)
>>> x = Tensor(np.ones((2, 4)), ms.float32)
>>> net.compile(x)
>>> ms.sync_pipeline_shared_parameters(net)
>>> print(net.network.word_embedding.w.asnumpy())
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
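The micro-batch slicing used by PipelineCellInference above can be checked in plain NumPy. This is a hypothetical standalone sketch (no MindSpore or Ascend device required) of the same slice-then-concatenate logic; the helper name micro_batch_slices is illustrative only.

```python
import numpy as np

def micro_batch_slices(x, micro_batch_num):
    # Split the global batch along axis 0 into equal micro-batches,
    # mirroring the slicing loop in PipelineCellInference.construct.
    micro_batch_size = x.shape[0] // micro_batch_num
    return [x[i * micro_batch_size:(i + 1) * micro_batch_size]
            for i in range(micro_batch_num)]

x = np.arange(8).reshape(4, 2)
parts = micro_batch_slices(x, 2)
# Each micro-batch covers half the global batch of 4 samples.
assert all(p.shape == (2, 2) for p in parts)
# Concatenating the per-micro-batch outputs restores the full batch order,
# which is what ops.Concat() does in the example above.
assert np.array_equal(np.concatenate(parts), x)
```

Because the slices partition the batch in order, concatenating the per-micro-batch results reproduces the output of a single full-batch forward pass.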