保存与加载

在线运行下载Notebook下载样例代码查看源文件

上一章节的内容里面主要是介绍了如何调整超参数,并进行网络模型训练。训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型部署和推理,本章节我们开始学习如何保存与加载模型。

模型训练

下面我们以MNIST数据集为例,介绍网络模型的保存与加载方式。首先,我们需要获取MNIST数据集并训练模型,示例代码如下:

[2]:
import mindspore.nn as nn
from mindspore.train import Model

from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet
from mindvision.engine.callback import LossMonitor

epochs = 10  # 训练轮次

# 1. 构建数据集
download_train = Mnist(path="./mnist", split="train", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)
dataset_train = download_train.run()

# 2. 定义神经网络
network = lenet(num_classes=10, pretrained=False)
# 3.1 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 3.2 定义优化器函数
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
# 3.3 初始化模型参数
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})

# 4. 对神经网络执行训练
model.train(epochs, dataset_train, callbacks=[LossMonitor(0.01, 1875)])
Epoch:[  0/ 10], step:[ 1875/ 1875], loss:[0.148/1.210], time:2.021 ms, lr:0.01000
Epoch time: 4251.808 ms, per step time: 2.268 ms, avg loss: 1.210
Epoch:[  1/ 10], step:[ 1875/ 1875], loss:[0.049/0.081], time:2.048 ms, lr:0.01000
Epoch time: 4301.405 ms, per step time: 2.294 ms, avg loss: 0.081
Epoch:[  2/ 10], step:[ 1875/ 1875], loss:[0.014/0.050], time:1.992 ms, lr:0.01000
Epoch time: 4278.799 ms, per step time: 2.282 ms, avg loss: 0.050
Epoch:[  3/ 10], step:[ 1875/ 1875], loss:[0.035/0.038], time:2.254 ms, lr:0.01000
Epoch time: 4380.553 ms, per step time: 2.336 ms, avg loss: 0.038
Epoch:[  4/ 10], step:[ 1875/ 1875], loss:[0.130/0.031], time:1.932 ms, lr:0.01000
Epoch time: 4287.547 ms, per step time: 2.287 ms, avg loss: 0.031
Epoch:[  5/ 10], step:[ 1875/ 1875], loss:[0.003/0.027], time:1.981 ms, lr:0.01000
Epoch time: 4377.000 ms, per step time: 2.334 ms, avg loss: 0.027
Epoch:[  6/ 10], step:[ 1875/ 1875], loss:[0.004/0.023], time:2.167 ms, lr:0.01000
Epoch time: 4687.250 ms, per step time: 2.500 ms, avg loss: 0.023
Epoch:[  7/ 10], step:[ 1875/ 1875], loss:[0.004/0.020], time:2.226 ms, lr:0.01000
Epoch time: 4685.529 ms, per step time: 2.499 ms, avg loss: 0.020
Epoch:[  8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:2.275 ms, lr:0.01000
Epoch time: 4651.129 ms, per step time: 2.481 ms, avg loss: 0.016
Epoch:[  9/ 10], step:[ 1875/ 1875], loss:[0.022/0.015], time:2.177 ms, lr:0.01000
Epoch time: 4623.760 ms, per step time: 2.466 ms, avg loss: 0.015

从上面的打印结果可以看出,随着训练轮次的增加,损失值趋于收敛。

保存模型

在训练完网络完成后,下面我们将网络模型以文件的形式保存下来。保存模型的接口有主要2种:

  1. 简单的对网络模型进行保存,可以在训练前后进行保存。这种方式的优点是接口简单易用,但是只保留执行命令时候的网络模型状态;

  2. 在网络模型训练中进行保存,MindSpore在网络模型训练的过程中,自动保存训练时候设定好的epoch数和step数的参数,也就是把模型训练过程中产生的中间权重参数也保存下来,方便进行网络微调和停止训练;

直接保存模型

使用MindSpore提供的save_checkpoint保存模型,传入网络和保存路径:

[3]:
import mindspore as ms

# 定义的网络模型为net,一般在训练前或者训练后使用
ms.save_checkpoint(network, "./MyNet.ckpt")

其中,network为训练网络,"./MyNet.ckpt"为网络模型的保存路径。

训练过程中保存模型

在模型训练的过程中,使用model.train里面的callbacks参数传入保存模型的对象 ModelCheckpoint(一般与CheckpointConfig配合使用),可以保存模型参数,生成CheckPoint(简称ckpt)文件。

用户可以根据具体需求通过设置CheckpointConfig来对CheckPoint策略进行配置。具体用法如下:

[4]:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# 设置epoch_num数量
epoch_num = 5

# 设置模型保存参数
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

# 应用模型保存参数
ckpoint = ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)
model.train(epoch_num, dataset_train, callbacks=[ckpoint])

上述代码中,首先需要初始化一个CheckpointConfig类对象,用来设置保存策略。

  • save_checkpoint_steps表示每隔多少个step保存一次。

  • keep_checkpoint_max表示最多保留CheckPoint文件的数量。

  • prefix表示生成CheckPoint文件的前缀名。

  • directory表示存放文件的目录。

创建一个ModelCheckpoint对象把它传递给model.train方法,就可以在训练过程中使用CheckPoint功能了。

生成的CheckPoint文件如下:

lenet-graph.meta # 编译后的计算图
lenet-1_1875.ckpt  # CheckPoint文件后缀名为'.ckpt'
lenet-2_1875.ckpt  # 文件的命名方式表示保存参数所在的epoch和step数,这里为第2个epoch的第1875个step的模型参数
lenet-3_1875.ckpt  # 表示保存的是第3个epoch的第1875个step的模型参数
...

如果用户使用相同的前缀名,运行多次训练脚本,可能会生成同名CheckPoint文件。MindSpore为方便用户区分每次生成的文件,会在用户定义的前缀后添加”_”和数字加以区分。如果想要删除.ckpt文件时,请同步删除.meta 文件。

例:lenet_3-2_1875.ckpt 表示运行第3次脚本生成的第2个epoch的第1875个step的CheckPoint文件。

加载模型

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

示例代码如下:

[5]:
from mindspore import load_checkpoint, load_param_into_net

from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet

# 将模型参数存入parameter的字典中,这里加载的是上面训练过程中保存的模型参数
param_dict = load_checkpoint("./lenet/lenet-5_1875.ckpt")

# 重新定义一个LeNet神经网络
net = lenet(num_classes=10, pretrained=False)

# 将参数加载到网络中
load_param_into_net(net, param_dict)

# 重新定义优化器函数
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

model = Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={"accuracy"})
  • load_checkpoint方法会把参数文件中的网络参数加载到字典param_dict中。

  • load_param_into_net方法会把字典param_dict中的参数加载到网络或者优化器中,加载后,网络中的参数就是CheckPoint保存的。

模型验证

在上述模块把参数加载到网络中之后,针对推理场景,可以调用eval函数进行推理验证。示例代码如下:

[8]:
# 调用eval()进行推理
download_eval = Mnist(path="./mnist", split="test", batch_size=32, resize=32, download=True)
dataset_eval = download_eval.run()
acc = model.eval(dataset_eval)

print("{}".format(acc))
{'accuracy': 0.9866786858974359}

用于迁移学习

针对任务中断再训练及微调(Fine-tuning)场景,可以调用train函数进行迁移学习。示例代码如下:

[9]:
# 定义训练数据集
download_train = Mnist(path="./mnist", split="train", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)
dataset_train = download_train.run()

# 网络模型调用train()继续进行训练
model.train(epoch_num, dataset_train, callbacks=[LossMonitor(0.01, 1875)])
Epoch:[  0/  5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.193 ms, lr:0.01000
Epoch time: 4106.620 ms, per step time: 2.190 ms, avg loss: 0.010
Epoch:[  1/  5], step:[ 1875/ 1875], loss:[0.000/0.009], time:2.036 ms, lr:0.01000
Epoch time: 4233.697 ms, per step time: 2.258 ms, avg loss: 0.009
Epoch:[  2/  5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.045 ms, lr:0.01000
Epoch time: 4246.248 ms, per step time: 2.265 ms, avg loss: 0.010
Epoch:[  3/  5], step:[ 1875/ 1875], loss:[0.000/0.008], time:2.001 ms, lr:0.01000
Epoch time: 4235.036 ms, per step time: 2.259 ms, avg loss: 0.008
Epoch:[  4/  5], step:[ 1875/ 1875], loss:[0.002/0.008], time:2.039 ms, lr:0.01000
Epoch time: 4354.482 ms, per step time: 2.322 ms, avg loss: 0.008