mindspore.load_distributed_checkpoint

mindspore.load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM')

Load checkpoint files into the network for distributed prediction. It is used for distributed inference. For details of distributed inference, please refer to: Distributed Model Loading.
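In outline, distributed inference first derives the prediction layout with Model.infer_predict_layout and then passes it as predict_strategy. The following is a minimal sketch only: net, inputs, the checkpoint paths and the rank count of 8 are placeholders, and the parallel context is assumed to be configured (semi_auto_parallel, strategy_ckpt_load_file) as in the full example further below.

    import mindspore as ms
    # Placeholders for this sketch: `net` is a distributed prediction network and
    # `inputs` is a Tensor of prediction data.
    model = ms.Model(net)
    # Derive the per-parameter sharding layout used for prediction.
    predict_layout = model.infer_predict_layout(inputs)
    # Checkpoint file names are placeholders, ordered by rank id.
    ckpt_files = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(8)]
    ms.load_distributed_checkpoint(net, ckpt_files, predict_layout)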

Parameters:
  • network (Cell) - The network for distributed prediction.

  • checkpoint_filenames (list[str]) - The names of the checkpoint files, in order of rank id.

  • predict_strategy (dict) - The parameter sharding strategy used for prediction. Default: None.

  • train_strategy_filename (str) - The filename of the training strategy proto file. Default: None.

  • strict_load (bool) - Whether to strictly load the parameters into the network. If False, a parameter in the checkpoint file is loaded into the network when the suffix of its name matches a parameter name in the network; when the data types are inconsistent, type conversion is performed for parameters of compatible types, for example from float32 to float16. Default: False.

  • dec_key (Union[None, bytes]) - Byte-type key used for decryption. If the value is None, decryption is not required. Default: None.

  • dec_mode (str) - This parameter is valid only when dec_key is not set to None. It specifies the decryption mode; currently 'AES-GCM', 'AES-CBC' and 'SM4-CBC' are supported. Default: 'AES-GCM'. A hedged sketch of the decryption-related arguments follows this list.
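As an illustration of strict_load, dec_key and dec_mode, the sketch below shows a call that loads encrypted checkpoints. The key bytes, net, predict_layout and the checkpoint paths are placeholders; the key must be the same bytes that were used when the checkpoints were saved with encryption enabled.

    import mindspore as ms
    # Placeholders for this sketch: `net` and `predict_layout` are prepared as in the
    # example further below; `key` is an illustrative value only.
    key = b'0123456789ABCDEF'
    ckpt_files = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(8)]
    ms.load_distributed_checkpoint(net, ckpt_files, predict_layout,
                                   strict_load=True, dec_key=key, dec_mode='AES-GCM')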

Raises:
  • TypeError - The types of the inputs do not meet the requirements.

  • ValueError - Failed to load the checkpoint files into the network.

Supported Platforms:

Ascend GPU

Examples:

Note

Before running the following examples, the communication environment variables need to be configured.

For Ascend devices, users need to prepare the rank table and set rank_id and device_id. For details, see rank table startup.

For GPU devices, users need to prepare the host file and mpi. For details, see mpirun startup.

For CPU devices, users need to write a dynamic cluster startup script. For details, see dynamic cluster startup.
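Once the environment has been set up by one of the launch methods above, a minimal hedged check (not part of the example below) that the communication setup is usable could be:

    import mindspore as ms
    from mindspore.communication import init, get_rank, get_group_size

    ms.set_context(mode=ms.GRAPH_MODE)
    init()  # reads the communication environment variables set by the launcher
    print("this process is rank", get_rank(), "of", get_group_size())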

>>> import os
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>>
>>> step_per_epoch = 4
>>> device_num = 8
>>>
>>> # Define the network structure.
>>> class Net(nn.Cell):
...     def __init__(self, matmul_size, strategy=None):
...         super().__init__()
...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
...         self.matmul = ops.MatMul()
...         self.neg = ops.Neg()
...         if strategy is not None:
...             self.matmul.shard(strategy)
...
...     def construct(self, inputs):
...         x = self.matmul(inputs, self.matmul_weight)
...         x = self.neg(x)
...         return x
>>>
>>> # Create dataset.
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> # Train network and save distributed checkpoint.
>>> def train_net():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     np.random.seed(1)
...     input_data = np.random.rand(16, 96).astype(np.float32)
...     label_data = np.random.rand(16, 16).astype(np.float32)
...     fake_dataset = get_dataset(input_data, label_data)
...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
...
...     # Set parallel strategy.
...     strategy = ((1, 4), (4, 1))
...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
...     network = Net(matmul_size=(96, 16), strategy=strategy)
...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
...     global_rank_id = int(os.getenv("RANK_ID"))
...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
...     ms.reset_auto_parallel_context()
>>>
>>> # Load distributed checkpoint and test.
>>> def load_model():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
...     network = Net(matmul_size=(96, 16))
...     model = ms.Model(network)
...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
...     predict_result = model.predict(predict_data)
...     print(predict_result)
>>>
>>> train_net()
>>> load_model()
[[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
[ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
...
[ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]