# 保存与加载

## 模型训练

[2]:

import mindspore.nn as nn
import mindspore as ms

from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet
from mindvision.engine.callback import LossMonitor

epochs = 10  # 训练轮次

# 1. 构建数据集

# 2. 定义神经网络
network = lenet(num_classes=10, pretrained=False)
# 3.1 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 3.2 定义优化器函数
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
# 3.3 初始化模型参数
model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})

# 4. 对神经网络执行训练
model.train(epochs, dataset_train, callbacks=[LossMonitor(0.01, 1875)])

Epoch:[  0/ 10], step:[ 1875/ 1875], loss:[0.148/1.210], time:2.021 ms, lr:0.01000
Epoch time: 4251.808 ms, per step time: 2.268 ms, avg loss: 1.210
Epoch:[  1/ 10], step:[ 1875/ 1875], loss:[0.049/0.081], time:2.048 ms, lr:0.01000
Epoch time: 4301.405 ms, per step time: 2.294 ms, avg loss: 0.081
Epoch:[  2/ 10], step:[ 1875/ 1875], loss:[0.014/0.050], time:1.992 ms, lr:0.01000
Epoch time: 4278.799 ms, per step time: 2.282 ms, avg loss: 0.050
Epoch:[  3/ 10], step:[ 1875/ 1875], loss:[0.035/0.038], time:2.254 ms, lr:0.01000
Epoch time: 4380.553 ms, per step time: 2.336 ms, avg loss: 0.038
Epoch:[  4/ 10], step:[ 1875/ 1875], loss:[0.130/0.031], time:1.932 ms, lr:0.01000
Epoch time: 4287.547 ms, per step time: 2.287 ms, avg loss: 0.031
Epoch:[  5/ 10], step:[ 1875/ 1875], loss:[0.003/0.027], time:1.981 ms, lr:0.01000
Epoch time: 4377.000 ms, per step time: 2.334 ms, avg loss: 0.027
Epoch:[  6/ 10], step:[ 1875/ 1875], loss:[0.004/0.023], time:2.167 ms, lr:0.01000
Epoch time: 4687.250 ms, per step time: 2.500 ms, avg loss: 0.023
Epoch:[  7/ 10], step:[ 1875/ 1875], loss:[0.004/0.020], time:2.226 ms, lr:0.01000
Epoch time: 4685.529 ms, per step time: 2.499 ms, avg loss: 0.020
Epoch:[  8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:2.275 ms, lr:0.01000
Epoch time: 4651.129 ms, per step time: 2.481 ms, avg loss: 0.016
Epoch:[  9/ 10], step:[ 1875/ 1875], loss:[0.022/0.015], time:2.177 ms, lr:0.01000
Epoch time: 4623.760 ms, per step time: 2.466 ms, avg loss: 0.015


## 保存模型

1. 简单的对网络模型进行保存，可以在训练前后进行保存。这种方式的优点是接口简单易用，但是只保留执行命令时候的网络模型状态；

2. 在网络模型训练中进行保存，MindSpore在网络模型训练的过程中，自动保存训练时候设定好的epoch数和step数的参数，也就是把模型训练过程中产生的中间权重参数也保存下来，方便进行网络微调和停止训练；

### 直接保存模型

[3]:

import mindspore as ms

# 定义的网络模型为net，一般在训练前或者训练后使用
ms.save_checkpoint(network, "./MyNet.ckpt")


### 训练过程中保存模型

[4]:

import mindspore as ms

# 设置epoch_num数量
epoch_num = 5

# 设置模型保存参数
config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

# 应用模型保存参数
ckpoint = ms.ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)
model.train(epoch_num, dataset_train, callbacks=[ckpoint])


• save_checkpoint_steps表示每隔多少个step保存一次。

• keep_checkpoint_max表示最多保留CheckPoint文件的数量。

• prefix表示生成CheckPoint文件的前缀名。

• directory表示存放文件的目录。

lenet-graph.meta # 编译后的计算图
lenet-1_1875.ckpt  # CheckPoint文件后缀名为'.ckpt'
lenet-2_1875.ckpt  # 文件的命名方式表示保存参数所在的epoch和step数，这里为第2个epoch的第1875个step的模型参数
lenet-3_1875.ckpt  # 表示保存的是第3个epoch的第1875个step的模型参数
...


## 加载模型

[5]:

import mindspore as ms

from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet

# 将模型参数存入parameter的字典中，这里加载的是上面训练过程中保存的模型参数

# 重新定义一个LeNet神经网络
net = lenet(num_classes=10, pretrained=False)

# 将参数加载到网络中

# 重新定义优化器函数
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

model = ms.Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={"accuracy"})

• load_checkpoint方法会把参数文件中的网络参数加载到字典param_dict中。

• load_param_into_net方法会把字典param_dict中的参数加载到网络或者优化器中，加载后，网络中的参数就是CheckPoint保存的。

### 模型验证

[8]:

# 调用eval()进行推理
acc = model.eval(dataset_eval)

print("{}".format(acc))

{'accuracy': 0.9866786858974359}


### 用于迁移学习

[9]:

# 定义训练数据集

# 网络模型调用train()继续进行训练
model.train(epoch_num, dataset_train, callbacks=[LossMonitor(0.01, 1875)])

Epoch:[  0/  5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.193 ms, lr:0.01000
Epoch time: 4106.620 ms, per step time: 2.190 ms, avg loss: 0.010
Epoch:[  1/  5], step:[ 1875/ 1875], loss:[0.000/0.009], time:2.036 ms, lr:0.01000
Epoch time: 4233.697 ms, per step time: 2.258 ms, avg loss: 0.009
Epoch:[  2/  5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.045 ms, lr:0.01000
Epoch time: 4246.248 ms, per step time: 2.265 ms, avg loss: 0.010
Epoch:[  3/  5], step:[ 1875/ 1875], loss:[0.000/0.008], time:2.001 ms, lr:0.01000
Epoch time: 4235.036 ms, per step time: 2.259 ms, avg loss: 0.008
Epoch:[  4/  5], step:[ 1875/ 1875], loss:[0.002/0.008], time:2.039 ms, lr:0.01000
Epoch time: 4354.482 ms, per step time: 2.322 ms, avg loss: 0.008