mindspore.load_checkpoint

mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode='AES-GCM')

Load checkpoint information from the specified file. If net is given, the loaded parameters are also loaded into net.

Parameters
  • ckpt_file_name (str) – Checkpoint file name.

  • net (Cell) – The network into which the parameters are loaded. Default: None.

  • strict_load (bool) – Whether to load the parameters into net strictly. If False, a parameter is loaded into net when the suffix of its name in the checkpoint file matches the name of a parameter in the network. When the data types are inconsistent, type conversion is performed between parameters of the same kind, for example from float32 to float16. Default: False.

  • filter_prefix (Union[str, list[str], tuple[str]]) – Parameters whose names start with filter_prefix (or with any of its elements, if a list or tuple is given) are not loaded. Default: None.

  • dec_key (Union[None, bytes]) – Key of type bytes used for decryption. If None, no decryption is performed. Default: None.

  • dec_mode (str) – Decryption mode. Takes effect only when dec_key is not None. Currently supports ‘AES-GCM’ and ‘AES-CBC’. Default: ‘AES-GCM’.
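The name-selection rules for filter_prefix and strict_load, and the rule that dec_mode only matters when dec_key is set, can be modeled in plain Python as below. This is a simplified sketch, not MindSpore's implementation: drop_filtered, match_checkpoint_name, and effective_dec_mode are hypothetical helpers, the "exactly one suffix match" rule for non-strict loading is an assumption, and the float32-to-float16 type conversion is not modeled.

```python
from typing import Iterable, List, Optional, Sequence, Tuple, Union

# Hypothetical helpers modeling the documented rules; NOT part of the MindSpore API.

SUPPORTED_DEC_MODES = ("AES-GCM", "AES-CBC")


def normalize_filter_prefix(filter_prefix: Union[None, str, Sequence[str]]) -> Tuple[str, ...]:
    """Turn the documented filter_prefix forms (None, str, list/tuple of str) into a tuple."""
    if filter_prefix is None:
        return ()
    if isinstance(filter_prefix, str):
        return (filter_prefix,)
    if isinstance(filter_prefix, (list, tuple)) and all(isinstance(p, str) for p in filter_prefix):
        return tuple(filter_prefix)
    raise TypeError("filter_prefix must be None, a str, or a list/tuple of str")


def drop_filtered(names: Iterable[str], filter_prefix=None) -> List[str]:
    """Keep only the parameter names that do not start with any filter prefix."""
    prefixes = normalize_filter_prefix(filter_prefix)
    # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
    return [n for n in names if not n.startswith(prefixes)]


def match_checkpoint_name(net_name: str, ckpt_names: Iterable[str],
                          strict_load: bool = False) -> Optional[str]:
    """Return the checkpoint name that would feed a network parameter, or None."""
    ckpt_names = list(ckpt_names)
    if net_name in ckpt_names:
        return net_name  # exact match works in both modes
    if strict_load:
        return None
    # Assumed non-strict rule: a checkpoint name whose dotted suffix equals the
    # network name, e.g. "backbone.conv1.weight" for "conv1.weight",
    # accepted only if exactly one candidate exists.
    candidates = [c for c in ckpt_names if c.endswith("." + net_name)]
    return candidates[0] if len(candidates) == 1 else None


def effective_dec_mode(dec_key: Optional[bytes], dec_mode: str = "AES-GCM") -> Optional[str]:
    """Return the decryption mode that would apply, or None if no decryption is needed."""
    if dec_key is None:
        return None  # dec_mode is ignored when there is no key
    if not isinstance(dec_key, bytes):
        raise TypeError("dec_key must be bytes or None")
    if dec_mode not in SUPPORTED_DEC_MODES:
        raise ValueError(f"unsupported dec_mode: {dec_mode!r}")
    return dec_mode


if __name__ == "__main__":
    names = ["conv1.weight", "conv2.weight", "fc1.weight", "fc1.bias"]
    print(drop_filtered(names, "conv1"))          # ['conv2.weight', 'fc1.weight', 'fc1.bias']
    print(drop_filtered(names, ("conv", "fc1")))  # []
    print(match_checkpoint_name("conv1.weight", ["backbone.conv1.weight"]))  # backbone.conv1.weight
    print(match_checkpoint_name("conv1.weight", ["backbone.conv1.weight"], strict_load=True))  # None
```

Note that plain prefix matching also drops names such as "conv10.weight" when filter_prefix is "conv1"; including the trailing dot in the prefix (for example "conv1.") avoids that in this model.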

Returns

Dict, mapping each parameter name to its Parameter.

Raises

ValueError – The format of the checkpoint file is incorrect.

Examples

>>> from mindspore import load_checkpoint
>>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
>>> print(param_dict["conv2.weight"])
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)