mindspore.train.ModelCheckpoint

class mindspore.train.ModelCheckpoint(prefix='CKP', directory=None, config=None)

The checkpoint callback class.

It works together with the training process to save the model and network parameters during and after training.

Note

In distributed training, specify a different directory for each training process to store its checkpoint files; otherwise, training may fail. If this callback is used with Model, the checkpoint files also save the optimizer's parameters by default.
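Each process therefore needs its own checkpoint directory. A minimal sketch, assuming a hypothetical helper rank_checkpoint_dir and an illustrative rank_<id> layout (neither is part of MindSpore), shows one way to derive distinct paths from a rank id:

```python
import os


def rank_checkpoint_dir(base_dir: str, rank_id: int) -> str:
    """Return a per-process checkpoint directory such as base_dir/rank_0.

    Hypothetical helper: the "rank_<id>" layout is an illustrative choice,
    not a MindSpore convention.
    """
    if rank_id < 0:
        raise ValueError(f"rank_id must be non-negative, got {rank_id}")
    return os.path.join(base_dir, f"rank_{rank_id}")


# Four processes get four distinct directories.
for rank in range(4):
    print(rank_checkpoint_dir("./checkpoints", rank))
```

In an actual job, the rank id would typically come from mindspore.communication.get_rank() after communication has been initialized, and the returned path could then be passed as ModelCheckpoint(directory=...).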

Parameters
  • prefix (Union[str, callable object]) – The prefix name or callable object to generate name of checkpoint files. Default: 'CKP' .

  • directory (Union[str, callable object]) – The folder path where the checkpoint is stored, or the callable object used to generate the path. By default, the file is saved in the current directory. Default: None .

  • config (CheckpointConfig) – Checkpoint strategy configuration. Default: None .

Raises
  • ValueError – If prefix is not a callable object and is either not a str or contains the ‘/’ character.

  • ValueError – If directory is not None and is neither a str nor a callable object.

  • TypeError – If config is not of type CheckpointConfig.
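The argument rules listed under Parameters and Raises can be mirrored in plain Python to check values before constructing the callback. This is a sketch of the documented conditions only, not MindSpore's source; validate_checkpoint_args and CheckpointConfigStub are hypothetical names, and the stub stands in for the real CheckpointConfig so the snippet runs without MindSpore installed.

```python
from typing import Any


class CheckpointConfigStub:
    """Hypothetical stand-in for mindspore.train.CheckpointConfig."""


def validate_checkpoint_args(prefix: Any = "CKP", directory: Any = None,
                             config: Any = None) -> None:
    # prefix: any callable is accepted; otherwise it must be a str without '/'.
    if not callable(prefix) and (not isinstance(prefix, str) or "/" in prefix):
        raise ValueError(f"invalid prefix: {prefix!r}")
    # directory: None (current directory), a str path, or a callable.
    if directory is not None and not isinstance(directory, str) and not callable(directory):
        raise ValueError(f"invalid directory: {directory!r}")
    # config: None or a CheckpointConfig instance (the stub here).
    if config is not None and not isinstance(config, CheckpointConfigStub):
        raise TypeError(f"config must be CheckpointConfig, got {type(config).__name__}")


validate_checkpoint_args("myckpt")  # accepted
for bad_prefix in ("run/1", 123):
    try:
        validate_checkpoint_args(prefix=bad_prefix)
    except ValueError as err:
        print(err)
```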

Examples

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, ModelCheckpoint
>>>
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> ckpt_callback = ModelCheckpoint(prefix="myckpt")
>>> model = Model(network=net, optimizer=opt, loss_fn=crit)
>>> model.train(2, train_dataset, callbacks=[ckpt_callback])
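In the example above, 64 samples batched by 32 give 2 steps per epoch, and model.train(2, ...) runs 2 epochs. The sketch below enumerates the checkpoint file names such a run might produce, assuming the "{prefix}-{epoch}_{step_in_epoch}.ckpt" naming pattern and a save after every step (assumed here to be the default CheckpointConfig interval). expected_checkpoint_names is a hypothetical helper; actual names can differ with the MindSpore version and the CheckpointConfig settings.

```python
from typing import List


def expected_checkpoint_names(prefix: str, epochs: int, steps_per_epoch: int,
                              save_every: int = 1) -> List[str]:
    """List file names under an assumed "{prefix}-{epoch}_{step}.ckpt" pattern."""
    names = []
    global_step = 0
    for epoch in range(1, epochs + 1):
        for step_in_epoch in range(1, steps_per_epoch + 1):
            global_step += 1
            # Save whenever the global step hits the save interval.
            if global_step % save_every == 0:
                names.append(f"{prefix}-{epoch}_{step_in_epoch}.ckpt")
    return names


print(expected_checkpoint_names("myckpt", epochs=2, steps_per_epoch=2))
# ['myckpt-1_1.ckpt', 'myckpt-1_2.ckpt', 'myckpt-2_1.ckpt', 'myckpt-2_2.ckpt']
```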
end(run_context)

Save the last checkpoint after training has finished.

Parameters

run_context (RunContext) – Context of the running training process.

property latest_ckpt_file_name

Return the path and file name of the most recently saved checkpoint.
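When the callback object is no longer available (for example, in a separate evaluation script), the latest file can instead be picked from a directory listing. A minimal sketch, assuming file names follow a "{prefix}-{epoch}_{step}.ckpt" pattern; latest_checkpoint is a hypothetical helper, not part of MindSpore:

```python
import os
import re
from typing import Iterable, Optional

# Assumed naming pattern: <prefix>-<epoch>_<step>.ckpt
_CKPT_RE = re.compile(r"^(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)\.ckpt$")


def latest_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return the path with the highest (epoch, step), or None if none match."""
    best, best_key = None, None
    for path in paths:
        match = _CKPT_RE.match(os.path.basename(path))
        if match is None:
            continue  # skip files that do not follow the assumed pattern
        # Compare numerically so that epoch 10 sorts after epoch 2.
        key = (int(match["epoch"]), int(match["step"]))
        if best_key is None or key > best_key:
            best, best_key = path, key
    return best


files = ["ckpt/myckpt-2_1.ckpt", "ckpt/myckpt-10_1.ckpt", "ckpt/myckpt-graph.meta"]
print(latest_checkpoint(files))
```

Parsing the epoch and step as integers avoids the lexicographic trap where "myckpt-10_1.ckpt" would sort before "myckpt-2_1.ckpt".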

step_end(run_context)

Save a checkpoint at the end of a step, subject to the save policy set in config.

Parameters

run_context (RunContext) – Context of the running training process.
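How step_end and end interact can be illustrated with a simplified model: step_end saves once a step interval has elapsed since the last save, and end saves the final step unless that step was already saved. ToyCheckpointSchedule and save_every are hypothetical names, and the interval rule is an assumption that roughly corresponds to CheckpointConfig's step-based policy; this is not MindSpore's implementation and it ignores time-based saving and file retention.

```python
from typing import List


class ToyCheckpointSchedule:
    """Hypothetical, simplified model of when a checkpoint callback saves."""

    def __init__(self, save_every: int) -> None:
        if save_every < 1:
            raise ValueError("save_every must be >= 1")
        self.save_every = save_every
        self.last_saved_step = 0
        self.saved_steps: List[int] = []

    def step_end(self, global_step: int) -> None:
        # Save once at least `save_every` steps have passed since the last save.
        if global_step - self.last_saved_step >= self.save_every:
            self._save(global_step)

    def end(self, global_step: int) -> None:
        # Final save after training, skipped if this step was just saved.
        if global_step != self.last_saved_step:
            self._save(global_step)

    def _save(self, global_step: int) -> None:
        self.saved_steps.append(global_step)
        self.last_saved_step = global_step


schedule = ToyCheckpointSchedule(save_every=3)
for step in range(1, 8):  # 7 training steps
    schedule.step_end(step)
schedule.end(7)
print(schedule.saved_steps)  # [3, 6, 7]
```

Step 7 is not a multiple of the interval, so only the final end call saves it; with 6 steps instead, end would find step 6 already saved and do nothing.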