[ "MindSpore Made Easy" ]

MindSpore Made Easy Model Training - Resumable Training at a Checkpoint (1)

June 13, 2022

Author: kaierlong

Source: https://bbs.huaweicloud.com/forum/forum.php?mod=viewthread&tid=187445

Development environment

MindSpore 1.7.0

Contents

· Examples in Documents

· Guess and Verification

· Source Code Exploration

· Cases

· Summary

· Reference

1. Examples in Documents

1.1 Official Document

The exception_save parameter (bool) of CheckpointConfig controls the resumable training function newly added in MindSpore 1.7.0, but the official documentation does not describe the scenarios in which it applies. See the following figure.

1.2 Official Example

For details about the official example, see Saving and Exporting Models.

MindSpore provides a resumable training function. When it is enabled and an exception occurs during training, MindSpore automatically saves a checkpoint file (the last checkpoint) at the moment the exception occurs.

Resumable training is controlled by the exception_save parameter (bool) in CheckpointConfig. Setting it to True enables resumable training, and setting it to False disables it. The default value is False. The last checkpoint file saved by resumable training and the checkpoint files saved in the normal process do not affect each other. They use the same naming mechanism and save path; the only difference is that _breakpoint is appended to the end of the last checkpoint file name. The parameter is used as follows:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Enable resumable training.

config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, exception_save=True)

If an exception occurs during training, the last checkpoint is saved automatically. For example, if the exception occurs at the tenth step of the tenth epoch, the saved last checkpoint file is:

# The name of the last checkpoint file is suffixed with _breakpoint to distinguish it from the checkpoint files in the normal process.

resnet50-10_10_breakpoint.ckpt
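The naming rule described above can be expressed in a few lines of plain Python. This is only an illustrative sketch, not MindSpore code: it assumes the {prefix}-{epoch}_{step}.ckpt pattern seen in the file listings of this article, with _breakpoint inserted before the .ckpt extension. The helper names ckpt_name and pick_resume_ckpt are hypothetical. pick_resume_ckpt simply picks the newest file by (epoch, step) and, on a tie, prefers the breakpoint file.

```python
import re
from typing import Iterable, Optional, Tuple

# Assumed pattern, inferred from names such as "lenet_ckpt-1_100.ckpt" and
# "resnet50-10_10_breakpoint.ckpt"; it is not taken from the MindSpore source.
CKPT_RE = re.compile(
    r"^(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)(?P<bp>_breakpoint)?\.ckpt$"
)


def ckpt_name(prefix: str, epoch: int, step: int, is_breakpoint: bool = False) -> str:
    """Build a checkpoint file name; a breakpoint file only adds the _breakpoint suffix."""
    suffix = "_breakpoint" if is_breakpoint else ""
    return f"{prefix}-{epoch}_{step}{suffix}.ckpt"


def pick_resume_ckpt(names: Iterable[str]) -> Optional[str]:
    """Return the newest checkpoint by (epoch, step); a breakpoint file wins a tie."""
    best_key: Optional[Tuple[int, int, bool]] = None
    best_name: Optional[str] = None
    for name in names:
        match = CKPT_RE.match(name)
        if match is None:
            continue  # skips non-checkpoint files such as "lenet_ckpt-graph.meta"
        key = (
            int(match.group("epoch")),
            int(match.group("step")),
            match.group("bp") is not None,
        )
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name
```

For example, ckpt_name("resnet50", 10, 10, is_breakpoint=True) produces the file name shown above, and for the six files listed by tree ckpt in section 2.1, pick_resume_ckpt returns lenet_ckpt-1_500.ckpt.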

2. Guess and Verification

As mentioned in section 1.1, the official website does not describe the application scenarios of this parameter, so I'll make a guess and verify it.

Guess: The parameter takes effect when the training is manually terminated.

Next, I'll verify it using code.

I selected the source code of my own open source case fashion_mnist_classification_with_cnn_by_mindspore and made some modifications.

For details about the data processing and execution of this case, see its README.md.

2.1 Setting exception_save to False

The test code is as follows:

#!/usr/bin/env python3

# -*- coding: utf-8 -*-

# -------------------

# @Version : 1.0

# @Author : xingchaolong

# @For : MindSpore FashionMnist LeNet Example.

# -------------------

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function



import argparse



import mindspore.dataset as ds

import mindspore.nn as nn

import mindspore.dataset.transforms.c_transforms as C

import mindspore.dataset.vision.c_transforms as CV



from mindspore import context

from mindspore import dtype as mstype

from mindspore import Model

from mindspore.common.initializer import Normal

from mindspore.dataset.vision import Inter

from mindspore.nn import Accuracy

from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint





def create_dataset(data_path, usage="train", batch_size=32, repeat_size=1, num_parallel_workers=1):

    # Define the dataset.

    fashion_mnist_ds = ds.FashionMnistDataset(data_path, usage=usage)

    resize_height, resize_width = 28, 28

    rescale = 1.0 / 255.0

    shift = 0.0

    rescale_nml = 1 / 0.3081

    shift_nml = -1 * 0.1307 / 0.3081

    # Define the mapping to be operated.

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)

    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)

    rescale_op = CV.Rescale(rescale, shift)

    hwc2chw_op = CV.HWC2CHW()

    type_cast_op = C.TypeCast(mstype.int32)


    # Use the map function to apply data operations to the dataset.

    fashion_mnist_ds = fashion_mnist_ds.map(

        operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)

    fashion_mnist_ds = fashion_mnist_ds.map(

        operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op],

        input_columns="image", num_parallel_workers=num_parallel_workers)


    # Perform shuffle, batch, and repeat operations.

    buffer_size = 10000

    fashion_mnist_ds = fashion_mnist_ds.shuffle(buffer_size=buffer_size)

    fashion_mnist_ds = fashion_mnist_ds.batch(batch_size, drop_remainder=True)

    fashion_mnist_ds = fashion_mnist_ds.repeat(count=repeat_size)


    return fashion_mnist_ds


class LeNet5(nn.Cell):

    """

    LeNet network structure

    """

    def __init__(self, num_class=10, num_channel=1):

        super(LeNet5, self).__init__()

        # Define the required operations.

        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')

        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')

        self.fc1 = nn.Dense(16 * 4 * 4, 256, weight_init=Normal(0.02))

        self.fc2 = nn.Dense(256, 128, weight_init=Normal(0.02))

        self.fc3 = nn.Dense(128, num_class, weight_init=Normal(0.02))

        self.relu = nn.ReLU()

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()



    def construct(self, x):

        # Use the defined operations to build a feedforward network.

        x = self.conv1(x)

        x = self.relu(x)

        x = self.max_pool2d(x)

        x = self.conv2(x)

        x = self.relu(x)

        x = self.max_pool2d(x)

        x = self.flatten(x)

        x = self.fc1(x)

        x = self.relu(x)

        x = self.fc2(x)

        x = self.relu(x)

        x = self.fc3(x)

        return x





def train_net(model, epoch_size, data_path, batch_size, repeat_size, ckpt_cb, sink_mode):

    """Define the training method."""

    # Load the training dataset.

    ds_train = create_dataset(data_path, usage="train", batch_size=batch_size, repeat_size=repeat_size)

    model.train(epoch_size, ds_train, callbacks=[ckpt_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)



def test_net(model, data_path):

    """Define the verification method."""

    ds_eval = create_dataset(data_path, usage="test")

    acc = model.eval(ds_eval, dataset_sink_mode=False)

    print("acc: {}".format(acc), flush=True)



def run(data_path, model_dir, device_target="CPU", batch_size=32, train_epoch=5, dataset_size=1):

    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)


    net = LeNet5()

    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)



    # Set the model saving parameter.

    config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=False)

    # Apply the model saving parameter.

    ckpt_cb = ModelCheckpoint(prefix="lenet_ckpt", directory=model_dir, config=config_ck)


    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(model, train_epoch, data_path, batch_size, dataset_size, ckpt_cb, False)

    test_net(model, data_path)


def main():

    parser = argparse.ArgumentParser(description='MindSpore FashionMnist LeNet Example.')

    parser.add_argument("--data_path", type=str, required=True, help="fashion mnist data path.")

    parser.add_argument("--device_target", type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],

                        help="target device")

    parser.add_argument("--model_dir", type=str, required=True, help="directory to save model ckpt.")

    parser.add_argument("--batch_size", type=int, default=32, help="batch size.")

    parser.add_argument("--train_epoch", type=int, default=5, help="train epoch.")

    parser.add_argument("--dataset_size", type=int, default=1, help="dataset size.")


    args = parser.parse_args()


    run(

        data_path=args.data_path,

        model_dir=args.model_dir,

        device_target=args.device_target,

        batch_size=args.batch_size,

        train_epoch=args.train_epoch,

        dataset_size=args.dataset_size

    )


if __name__ == "__main__":

    main()
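Two constants in the test code can be checked by hand. The sketch below does not import MindSpore; it recomputes (1) the 16 * 4 * 4 input size of fc1, assuming 28x28 inputs, 5x5 kernels with pad_mode='valid', and 2x2 max pooling with stride 2, and (2) the combined effect of rescale_op followed by rescale_nml_op, assuming Rescale(rescale, shift) computes x * rescale + shift. The helper names are hypothetical.

```python
def valid_conv(size: int, kernel: int) -> int:
    # Output size of a stride-1 convolution without padding ("valid").
    return size - kernel + 1


def max_pool(size: int, kernel: int = 2, stride: int = 2) -> int:
    # Output size of a max pooling layer without padding.
    return (size - kernel) // stride + 1


def lenet_flatten_size(height: int = 28, width: int = 28, channels: int = 16) -> int:
    h, w = height, width
    for _ in range(2):  # two conv + pool stages, as in LeNet5.construct
        h, w = max_pool(valid_conv(h, 5)), max_pool(valid_conv(w, 5))
    return channels * h * w


def normalize_pixel(pixel: int) -> float:
    # rescale_op: maps [0, 255] to [0, 1]
    x = pixel * (1.0 / 255.0) + 0.0
    # rescale_nml_op: x / 0.3081 - 0.1307 / 0.3081, i.e. (x - 0.1307) / 0.3081
    return x * (1 / 0.3081) + (-1 * 0.1307 / 0.3081)


print(lenet_flatten_size())  # 256, i.e. 16 * 4 * 4
```

The spatial size goes 28 → 24 → 12 → 8 → 4, which is why fc1 expects 16 * 4 * 4 inputs, and the two Rescale operations together amount to standardizing each pixel with mean 0.1307 and standard deviation 0.3081.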

Run the following command to execute the test code in the foreground:

./data is the data directory. Replace it as required.

./ckpt is the model saving directory. Replace it as required.

python3 main.py --data_path=./data --model_dir=./ckpt

Press Ctrl+C to stop the execution manually. The output is as follows:

epoch: 1 step: 125, loss is 2.2966978549957275

epoch: 1 step: 250, loss is 2.2930874824523926

epoch: 1 step: 375, loss is 2.257183074951172

epoch: 1 step: 500, loss is 1.0803303718566895

^CWARNING: Logging before InitGoogleLogging() is written to STDERR

[WARNING] RUNTIME_FRAMEWORK(18086,0x10f6e9dc0,Python):2022-05-11-10:56:54.267.943 [mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:203] IntHandler] Process 18086 receive KeyboardInterrupt signal.

Terminated: 15

Run the tree ckpt command to check the model saving directory. The output is as follows:

ckpt/

├── lenet_ckpt-1_100.ckpt

├── lenet_ckpt-1_200.ckpt

├── lenet_ckpt-1_300.ckpt

├── lenet_ckpt-1_400.ckpt

├── lenet_ckpt-1_500.ckpt

└── lenet_ckpt-graph.meta


0 directories, 6 files

Interpretation: The model saving directory contains only the regular checkpoint files, and no _breakpoint .ckpt file is generated. This is as expected, because exception_save is set to False.
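Instead of reading the tree ckpt output by eye, the check can be scripted. The sketch below uses only the Python standard library; the function name find_breakpoint_ckpts is hypothetical, and it assumes, as described in section 1.2, that breakpoint files end with _breakpoint.ckpt and are written to the same directory as the regular checkpoints.

```python
from pathlib import Path
from typing import List


def find_breakpoint_ckpts(model_dir: str) -> List[str]:
    """Return the sorted names of all *_breakpoint.ckpt files directly under model_dir."""
    # Path.glob yields nothing if model_dir does not exist, so an empty list is returned.
    return sorted(p.name for p in Path(model_dir).glob("*_breakpoint.ckpt"))
```

Calling find_breakpoint_ckpts("./ckpt") after each experiment returns an empty list when no breakpoint file was written, which is the result observed in this section.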

2.2 Setting exception_save to True

In the test code from section 2.1, change the following line

config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=False)

to

config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, exception_save=True)
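Editing the source between the two experiments works, but the switch could also be exposed as a command-line option so that both variants run from the same file. The sketch below is a hypothetical addition, not part of the original case: it assumes a new --exception_save flag that is off by default (matching the CheckpointConfig default of False), whose value would then be passed on as CheckpointConfig(..., exception_save=args.exception_save).

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MindSpore FashionMnist LeNet Example.")
    # Hypothetical flag: present -> True, absent -> False.
    parser.add_argument("--exception_save", action="store_true",
                        help="pass exception_save=True to CheckpointConfig.")
    return parser.parse_args(argv)
```

With such a flag, the second experiment would be started as python3 main.py --data_path=./data --model_dir=./ckpt --exception_save, once the flag has also been added to the existing parser in main().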

2.2.1 Executing Code in the Foreground and Manually Terminating the Execution

Run the following command to execute the test code:

python3 main.py --data_path=./data --model_dir=./ckpt

Press Ctrl+C to stop the execution manually. The output is as follows:

epoch: 1 step: 125, loss is 2.2990877628326416

epoch: 1 step: 250, loss is 2.3014278411865234

epoch: 1 step: 375, loss is 2.300143003463745

epoch: 1 step: 500, loss is 2.2685062885284424

epoch: 1 step: 625, loss is 1.2246686220169067

^CWARNING: Logging before InitGoogleLogging() is written to STDERR

[WARNING] RUNTIME_FRAMEWORK(22670,0x10c621dc0,Python):2022-05-11-10:59:14.927.645 [mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc:203] IntHandler] Process 22670 receive KeyboardInterrupt signal.

Terminated: 15

Run the tree ckpt command to check the model saving directory. The output is as follows:

ckpt/

├── lenet_ckpt-1_100.ckpt

├── lenet_ckpt-1_200.ckpt

├── lenet_ckpt-1_300.ckpt

├── lenet_ckpt-1_400.ckpt

├── lenet_ckpt-1_500.ckpt

├── lenet_ckpt-1_600.ckpt

└── lenet_ckpt-graph.meta


0 directories, 7 files

2.2.2 Executing Code in the Background and Manually Terminating the Execution

Run the following command to execute the test code:

nohup python3 main.py --data_path=./data --model_dir=./ckpt &

Run ps aux | grep main to find the process ID, and then run the kill command to terminate the process manually.
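The same background experiment can also be driven from Python. The sketch below is an illustration under stated assumptions, not the procedure used above: it starts a command with subprocess.Popen, lets it run for a fixed number of seconds, and then calls Popen.terminate(), which on POSIX systems sends SIGTERM, the same signal kill sends by default. The helper name run_then_terminate and the delay are hypothetical, and unlike nohup, the child's output goes to the current terminal rather than to nohup.out.

```python
import subprocess
from typing import List


def run_then_terminate(cmd: List[str], seconds: float) -> int:
    """Start cmd, wait up to `seconds`, then send SIGTERM (POSIX) and return the exit code."""
    proc = subprocess.Popen(cmd)
    try:
        proc.wait(timeout=seconds)  # returns early if the command finishes by itself
    except subprocess.TimeoutExpired:
        proc.terminate()  # SIGTERM on POSIX, TerminateProcess on Windows
        proc.wait()
    # On POSIX, a process killed by SIGTERM reports returncode == -15.
    return proc.returncode
```

For example, run_then_terminate(["python3", "main.py", "--data_path=./data", "--model_dir=./ckpt"], 60) would stop the training after about one minute, after which the model saving directory can be inspected as before.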

Run the cat nohup.out command to check the running status of the process. The output is as follows:

epoch: 1 step: 125, loss is 2.308577537536621

epoch: 1 step: 250, loss is 2.303668737411499

epoch: 1 step: 375, loss is 2.3061931133270264

epoch: 1 step: 500, loss is 1.572475790977478

epoch: 1 step: 625, loss is 1.2929679155349731

epoch: 1 step: 750, loss is 0.8329849243164062

Run the tree ckpt command to check the model saving directory. The output is as follows:

ckpt/

├── lenet_ckpt-1_100.ckpt

├── lenet_ckpt-1_200.ckpt

├── lenet_ckpt-1_300.ckpt

├── lenet_ckpt-1_400.ckpt

├── lenet_ckpt-1_500.ckpt

├── lenet_ckpt-1_600.ckpt

├── lenet_ckpt-1_700.ckpt

├── lenet_ckpt-1_800.ckpt

└── lenet_ckpt-graph.meta


0 directories, 9 files

Interpretation: In the tests in sections 2.2.1 and 2.2.2, exception_save is set to True. In 2.2.1, I run the code in the foreground and stop the training with Ctrl+C; in 2.2.2, I run the code in the background and kill the training process. In both cases, however, no _breakpoint .ckpt file is generated in the model saving directory. That is, these test scenarios are not ones in which the parameter takes effect, so the guess is not confirmed.