mindspore.load_checkpoint_async

mindspore.load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode='AES-GCM', specify_prefix=None, choice_func=None)

Asynchronously load checkpoint information from a specified file.

Warning

This is an experimental API that is subject to change or deletion.

Note

  • specify_prefix and filter_prefix do not affect each other.

  • If no parameters are loaded from the checkpoint file, a ValueError is raised.

  • specify_prefix and filter_prefix are being deprecated; choice_func is recommended instead. Using either of those two arguments overrides choice_func.

Parameters
  • ckpt_file_name (str) – Checkpoint file name.

  • net (Cell, optional) – The network into which the parameters will be loaded. Default: None.

  • strict_load (bool, optional) – Whether to strictly load the parameters into net. If False, a parameter is loaded into net when the suffix of its name in the checkpoint file matches a parameter name in the network. When the data types differ, type conversion is performed between parameters of the same kind, such as float32 to float16. Default: False.

  • filter_prefix (Union[str, list[str], tuple[str]], optional) – Deprecated (see choice_func). Parameters whose names start with filter_prefix will not be loaded. Default: None.

  • dec_key (Union[None, bytes], optional) – Key of type bytes used for decryption. If None, no decryption is performed. Default: None.

  • dec_mode (str, optional) – Decryption mode. Takes effect only when dec_key is not None. Currently supports "AES-GCM", "AES-CBC" and "SM4-CBC". Default: "AES-GCM".

  • specify_prefix (Union[str, list[str], tuple[str]], optional) – Deprecated (see choice_func). Parameters whose names start with specify_prefix will be loaded. Default: None.

  • choice_func (Union[None, function], optional) – A function that takes a Parameter name (a string) and returns a bool. If it returns True, the Parameter is loaded; if it returns False, the Parameter is skipped. Default: None.
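
As a minimal sketch of the choice_func contract described above (a string name in, a bool out), the following builds a predicate that skips parameters by name prefix, similar in spirit to the deprecated filter_prefix. The helper name skip_prefixes and the parameter names are hypothetical and chosen only for illustration; the predicate is exercised on plain strings so it runs without MindSpore.

```python
# Hypothetical helper (not part of MindSpore): returns a predicate that can be
# passed as choice_func. The predicate receives a parameter name (str) and
# returns True to load the parameter or False to skip it.
def skip_prefixes(*prefixes):
    def choice_func(param_name: str) -> bool:
        # str.startswith accepts a tuple and matches any of its entries.
        return not param_name.startswith(prefixes)
    return choice_func


# Skip optimizer state saved alongside the weights (illustrative names).
choice_func = skip_prefixes("moments.", "learning_rate")

names = ["conv1.weight", "moments.conv1.weight", "learning_rate", "fc1.bias"]
selected = [name for name in names if choice_func(name)]
print(selected)

# With MindSpore installed and a real checkpoint file, the predicate would be
# passed like this:
# pd_future = mindspore.load_checkpoint_async(ckpt_file, choice_func=choice_func)
```

Writing the filter as a plain function keeps it easy to test on a list of names before using it on an actual checkpoint.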

Returns

An object of a custom internal class. Calling its result method returns the same result as mindspore.load_checkpoint().
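
The returned object is used like a future: start the load, do other work, then call result to wait for the parameters. The sketch below illustrates that usage pattern with Python's standard concurrent.futures module. It is an analogy only and makes no claim about how MindSpore implements the class internally; slow_load and its return value are hypothetical stand-ins for reading a checkpoint.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_load(path):
    # Hypothetical stand-in for reading and parsing a checkpoint file.
    time.sleep(0.1)
    return {"conv1.weight": [0.0], "fc1.bias": [0.0]}


executor = ThreadPoolExecutor(max_workers=1)

# Start the load in the background; submit returns immediately.
future = executor.submit(slow_load, "./checkpoint/LeNet5-1_32.ckpt")

# Other setup work (for example, building the model) can overlap with the load.
setup_done = True

# result() blocks until the background load finishes, then returns its value.
param_dict = future.result()
executor.shutdown()

print("param dict len:", len(param_dict))
```

The benefit comes from the work placed between starting the load and calling result; calling result immediately after starting the load gains nothing over a synchronous load.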

Raises
  • ValueError – The checkpoint file’s format is incorrect.

  • ValueError – The parameter dict is None after loading the checkpoint file.

  • TypeError – The type of specify_prefix or filter_prefix is incorrect.

Examples

>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.train import Model
>>> from mindspore.amp import FixedLossScaleManager
>>> from mindspore import context
>>> from mindspore import load_checkpoint_async
>>> from mindspore import load_param_into_net
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
...               loss_scale_manager=loss_scale_manager)
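>>> # Start loading the checkpoint in the background.
>>> pd_future = load_checkpoint_async(ckpt_file)
>>> # Build the model while the checkpoint is loading.
>>> model.build(train_dataset=dataset, epoch=2)
>>> # Wait for the load to finish and fetch the parameter dict.
>>> param_dict = pd_future.result()
>>> load_param_into_net(net, param_dict)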
>>> model.train(2, dataset)
>>> print("param dict len: ", len(param_dict), flush=True)