【MindSpore Developer Sharing】A First Experience with Breakpoint Resume in Model Training
Author: kaierlong
Source: https://bbs.huaweicloud.com/forum/forum.php?mod=viewthread&tid=187445
The first experience of exception_save for model training
Development environment for this article:
- MindSpore 1.7.0
Contents:
- Documentation example
- Several attempts
- Source code exploration
- Case study
- Summary
- References
1. Documentation Example
1.1 Official Documentation
As usual, let's start with the official documentation, which states the following:
My interpretation: the exception_save parameter is a new feature added in version 1.7.0. It is of bool type, but the official documentation does not clearly explain its usage scenarios.

1.2 Official Example
The official example is as follows (code link):
MindSpore provides a breakpoint-resume feature. When it is enabled and an exception occurs during training, MindSpore automatically saves the CheckPoint file at the moment the exception occurs (the end-of-life CheckPoint). The feature is controlled by the exception_save parameter (bool type) of CheckpointConfig: set it to True to enable the feature and False to disable it; the default is False. The end-of-life CheckPoint files saved by this feature do not interfere with the CheckPoints saved by the normal process; their naming scheme and save path are the same as those configured for the normal process, the only difference being that '_breakpoint' is appended to the end-of-life CheckPoint file name to distinguish it. Usage is as follows:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
# Enable the breakpoint-resume feature
config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, exception_save=True)
If an exception occurs during training, the end-of-life CheckPoint is saved automatically. For example, if an exception occurs at step 10 of epoch 10, the saved CheckPoint file looks like this:
# '_breakpoint' is appended to the end-of-life CheckPoint file name to distinguish it from normal CheckPoints
resnet50-10_10_breakpoint.ckpt
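Conceptually, the documented behavior amounts to wrapping the training loop in a try/except that writes one extra checkpoint on failure. The following plain-Python sketch is my own illustration of the naming scheme described above, not MindSpore's actual implementation (the checkpoint writer here is a stand-in):

```python
import os

def save_ckpt(path):
    # Stand-in for a real checkpoint writer.
    with open(path, "w") as f:
        f.write("weights")

def train(prefix, out_dir, epochs, steps, fail_at=None):
    """Toy training loop that mimics the exception_save naming scheme."""
    epoch = step = 0
    try:
        for epoch in range(1, epochs + 1):
            for step in range(1, steps + 1):
                if fail_at == (epoch, step):
                    raise RuntimeError("simulated training failure")
    except Exception:
        # On failure, save an end-of-life checkpoint whose name carries the
        # epoch/step of the failure plus a '_breakpoint' suffix, then re-raise.
        name = "{}-{}_{}_breakpoint.ckpt".format(prefix, epoch, step)
        save_ckpt(os.path.join(out_dir, name))
        raise
```

With `prefix="resnet50"` and `fail_at=(10, 10)`, this produces `resnet50-10_10_breakpoint.ckpt`, matching the file name in the official example.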
2. Several Attempts
As noted in 1.1, the official docs give no description of the usage scenario, so I will first experiment based on my own guess.
Guess: manually terminating training mid-run might be what triggers this parameter.
Let's verify this with code.
The code in this article is taken from my earlier open-source example fashion_mnist_classification_with_cnn_by_mindspore, with minor modifications.
For data preparation and how to run the example, see its readme.
2.1 Setting exception_save to False
The test code is as follows:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -------------------
# @Version : 1.0
# @Author : xingchaolong
# @For : MindSpore FashionMnist LeNet Example.
# -------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import mindspore.dataset as ds
import mindspore.nn as nn
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV

from mindspore import context
from mindspore import dtype as mstype
from mindspore import Model
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Accuracy
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint


def create_dataset(data_path, usage="train", batch_size=32, repeat_size=1, num_parallel_workers=1):
    # Define the dataset
    fashion_mnist_ds = ds.FashionMnistDataset(data_path, usage=usage)

    resize_height, resize_width = 28, 28
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the required map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset with map
    fashion_mnist_ds = fashion_mnist_ds.map(
        operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    fashion_mnist_ds = fashion_mnist_ds.map(
        operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op],
        input_columns="image", num_parallel_workers=num_parallel_workers)

    # Apply shuffle, batch and repeat
    buffer_size = 10000
    fashion_mnist_ds = fashion_mnist_ds.shuffle(buffer_size=buffer_size)
    fashion_mnist_ds = fashion_mnist_ds.batch(batch_size, drop_remainder=True)
    fashion_mnist_ds = fashion_mnist_ds.repeat(count=repeat_size)

    return fashion_mnist_ds


class LeNet5(nn.Cell):
    """
    LeNet network structure
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the required operations
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 4 * 4, 256, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(256, 128, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(128, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward network from the operations defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


def train_net(model, epoch_size, data_path, batch_size, repeat_size, ckpt_cb, sink_mode):
    """Define the training procedure"""
    # Load the training dataset
    ds_train = create_dataset(data_path, usage="train", batch_size=batch_size, repeat_size=repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpt_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(model, data_path):
    """Define the evaluation procedure"""
    ds_eval = create_dataset(data_path, usage="test")
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("acc: {}".format(acc), flush=True)


def run(data_path, model_dir, device_target="CPU", batch_size=32, train_epoch=5, dataset_size=1):
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    net = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    # Configure model saving
    config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=False)
    # Apply the model saving configuration
    ckpt_cb = ModelCheckpoint(prefix="lenet_ckpt", directory=model_dir, config=config_ck)

    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(model, train_epoch, data_path, batch_size, dataset_size, ckpt_cb, False)
    test_net(model, data_path)


def main():
    parser = argparse.ArgumentParser(description='MindSpore FashionMnist LeNet Example.')
    parser.add_argument("--data_path", type=str, required=True, help="fashion mnist data path.")
    parser.add_argument("--device_target", type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help="target device")
    parser.add_argument("--model_dir", type=str, required=True, help="directory to save model ckpt.")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size.")
    parser.add_argument("--train_epoch", type=int, default=5, help="train epoch.")
    parser.add_argument("--dataset_size", type=int, default=1, help="dataset size.")
    args = parser.parse_args()

    run(
        data_path=args.data_path,
        model_dir=args.model_dir,
        device_target=args.device_target,
        batch_size=args.batch_size,
        train_epoch=args.train_epoch,
        dataset_size=args.dataset_size
    )


if __name__ == "__main__":
    main()
Run the code in the foreground with the following command (./data is the data directory and ./ckpt is the model save directory; replace both with your own paths):
python3 main.py --data_path=./data --model_dir=./ckpt
Terminate the run manually with ctrl+c; the output is as follows:
epoch: 1 step: 125, loss is 2.2966978549957275
epoch: 1 step: 250, loss is 2.2930874824523926
epoch: 1 step: 375, loss is 2.257183074951172
epoch: 1 step: 500, loss is 1.0803303718566895
^CWARNING: Logging before InitGoogleLogging() is written to STDERR
[WARNING] RUNTIME_FRAMEWORK(18086,0x10f6e9dc0,Python):2022-05-11-10:56:54.267.943 [mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:203] IntHandler] Process 18086 receive KeyboardInterrupt signal.
Terminated: 15
Check the model save directory with the tree ckpt command; the output is as follows:
ckpt/
├── lenet_ckpt-1_100.ckpt
├── lenet_ckpt-1_200.ckpt
├── lenet_ckpt-1_300.ckpt
├── lenet_ckpt-1_400.ckpt
├── lenet_ckpt-1_500.ckpt
└── lenet_ckpt-graph.meta
0 directories, 6 files
Interpretation: the contents of the model save directory are normal; no _breakpoint ckpt appears.
2.2 Setting exception_save to True
Change the following line of the test code from 2.1:
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=False)
to:
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=True)
2.2.1 Run in the foreground and terminate manually
Run the test code with the following command:
python3 main.py --data_path=./data --model_dir=./ckpt
Terminate the run manually with ctrl+c; the output is as follows:
epoch: 1 step: 125, loss is 2.2990877628326416
epoch: 1 step: 250, loss is 2.3014278411865234
epoch: 1 step: 375, loss is 2.300143003463745
epoch: 1 step: 500, loss is 2.2685062885284424
epoch: 1 step: 625, loss is 1.2246686220169067
^CWARNING: Logging before InitGoogleLogging() is written to STDERR
[WARNING] RUNTIME_FRAMEWORK(22670,0x10c621dc0,Python):2022-05-11-10:59:14.927.645 [mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:203] IntHandler] Process 22670 receive KeyboardInterrupt signal.
Terminated: 15
Check the model save directory with the tree ckpt command; the output is as follows:
ckpt/
├── lenet_ckpt-1_100.ckpt
├── lenet_ckpt-1_200.ckpt
├── lenet_ckpt-1_300.ckpt
├── lenet_ckpt-1_400.ckpt
├── lenet_ckpt-1_500.ckpt
├── lenet_ckpt-1_600.ckpt
└── lenet_ckpt-graph.meta
0 directories, 7 files
2.2.2 Run in the background and terminate manually
Run the test code with the following command:
nohup python3 main.py --data_path=./data --model_dir=./ckpt &
Find the process id with ps aux|grep main, then terminate it manually with the kill command.
Check the run with cat nohup.out; the output is as follows:
epoch: 1 step: 125, loss is 2.308577537536621
epoch: 1 step: 250, loss is 2.303668737411499
epoch: 1 step: 375, loss is 2.3061931133270264
epoch: 1 step: 500, loss is 1.572475790977478
epoch: 1 step: 625, loss is 1.2929679155349731
epoch: 1 step: 750, loss is 0.8329849243164062
Check the model save directory with the tree ckpt command; the output is as follows:
ckpt/
├── lenet_ckpt-1_100.ckpt
├── lenet_ckpt-1_200.ckpt
├── lenet_ckpt-1_300.ckpt
├── lenet_ckpt-1_400.ckpt
├── lenet_ckpt-1_500.ckpt
├── lenet_ckpt-1_600.ckpt
├── lenet_ckpt-1_700.ckpt
├── lenet_ckpt-1_800.ckpt
└── lenet_ckpt-graph.meta
0 directories, 9 files
Interpretation: in both the 2.2.1 and 2.2.2 tests, exception_save was set to True. One run was terminated manually in the foreground; the other ran in the background and had its process killed. In neither case was a _breakpoint ckpt generated in the model save directory, which means this way of using the feature is wrong.
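A plausible explanation (my assumption, not stated in the docs): exception_save reacts to Python-level exceptions raised during training, while ctrl+c and kill are delivered as OS signals. The runtime log above even shows MindSpore's IntHandler intercepting the signal and terminating the process, so no Python exception ever propagates to the checkpoint logic. The following MindSpore-free sketch (assuming a POSIX system) shows the difference: an except branch standing in for "save the end-of-life checkpoint" runs for a raised exception but never for a signal-terminated process.

```python
import os
import subprocess
import sys
import tempfile

# Child script: wraps a fake training step in try/except and writes a
# marker file if the except branch runs. The mode argument decides how
# the "training" dies: a Python exception or an OS signal.
CHILD = r"""
import os, signal, sys
marker = sys.argv[1]
mode = sys.argv[2]
try:
    if mode == "exception":
        raise RuntimeError("training failed")   # a Python exception
    else:
        os.kill(os.getpid(), signal.SIGTERM)    # an OS signal, like `kill`
except Exception:
    open(marker, "w").close()                   # "save the breakpoint ckpt"
    sys.exit(1)
"""

def run_child(mode):
    # Run the child in a subprocess and report whether the marker was written.
    marker = os.path.join(tempfile.mkdtemp(), "breakpoint.ckpt")
    proc = subprocess.run([sys.executable, "-c", CHILD, marker, mode])
    return os.path.exists(marker), proc.returncode
```

Running `run_child("exception")` writes the marker, because the raised exception reaches the except branch; `run_child("signal")` does not, because SIGTERM's default handler kills the process before any Python-level cleanup can run — exactly the pattern observed in the tests above.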