Source code for mindspore.train.callback._checkpoint

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""

import os
import shutil
import stat
import time

import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net


_cur_dir = os.getcwd()
_save_dir = _cur_dir


def _check_file_name_prefix(file_name_prefix):
    """
    Check file name valid or not.

    File name can't include '/'. This file name naming convention only apply to Linux.
    """
    if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
        return False
    return True


def _chg_ckpt_file_name_if_same_exist(directory, prefix):
    """Check if there is a file with the same name."""
    files = os.listdir(directory)
    suffix_num = 0
    pre_len = len(prefix)
    for filename in files:
        name_ext = os.path.splitext(filename)
        if name_ext[-1] != ".ckpt":
            continue
        # find same prefix file
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # add the max suffix + 1
            index = filename[pre_len:].find("-")
            if index == 0:
                suffix_num = max(suffix_num, 1)
            elif index != -1:
                num = filename[pre_len+1:pre_len+index]
                if num.isdigit():
                    suffix_num = max(suffix_num, int(num)+1)

    if suffix_num != 0:
        prefix = prefix + "_" + str(suffix_num)

    return prefix


[docs]class CheckpointConfig: """ The config for model checkpoint. Note: During the training process, if dataset is transmitted through the data channel, suggest set save_checkpoint_steps be an integer multiple of loop_size. Otherwise there may be deviation in the timing of saving checkpoint. Args: save_checkpoint_steps (int): Steps to save checkpoint. Default: 1. save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0. Can't be used with save_checkpoint_steps at the same time. keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5. keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0. Can't be used with keep_checkpoint_max at the same time. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal". Raises: ValueError: If the input_param is None or 0. Examples: >>> config = CheckpointConfig() >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config) >>> model.train(10, dataset, callbacks=ckpoint_cb) """ def __init__(self, save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, model_type="normal"): if not save_checkpoint_steps and not save_checkpoint_seconds and \ not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: raise ValueError("The input_param can't be all None or 0") if save_checkpoint_steps: save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) if save_checkpoint_seconds: save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds) if keep_checkpoint_max: keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) if keep_checkpoint_per_n_minutes: keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) if model_type: model_type = check_string(model_type, ["normal", "fusion", "quant"]) self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_seconds = save_checkpoint_seconds if self._save_checkpoint_steps and self._save_checkpoint_steps > 0: self._save_checkpoint_seconds = None self._keep_checkpoint_max = keep_checkpoint_max self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes if self._keep_checkpoint_max and self._keep_checkpoint_max > 0: self._keep_checkpoint_per_n_minutes = None else: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: self._keep_checkpoint_max = 1 self._model_type = model_type self._integrated_save = check_bool(integrated_save) @property def save_checkpoint_steps(self): """Get the value of _save_checkpoint_steps.""" return self._save_checkpoint_steps @property def save_checkpoint_seconds(self): """Get the value of _save_checkpoint_seconds.""" return self._save_checkpoint_seconds @property def keep_checkpoint_max(self): """Get the value of _keep_checkpoint_max.""" return self._keep_checkpoint_max @property def keep_checkpoint_per_n_minutes(self): """Get the value of _keep_checkpoint_per_n_minutes.""" return self._keep_checkpoint_per_n_minutes @property def integrated_save(self): """Get the value of _integrated_save.""" return self._integrated_save @property def model_type(self): """Get the value of model_type.""" return self._model_type
[docs] def get_checkpoint_policy(self): """Get the policy of checkpoint.""" checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, 'save_checkpoint_seconds': self._save_checkpoint_seconds, 'keep_checkpoint_max': self._keep_checkpoint_max, 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes, 'model_type': self._model_type} return checkpoint_policy
[docs]class ModelCheckpoint(Callback): """ The checkpoint callback class. It is called to combine with train process and save the model and network parameters after traning. Args: prefix (str): Checkpoint files names prefix. Default: "CKP". directory (str): Folder path into which checkpoint files will be saved. Default: None. config (CheckpointConfig): Checkpoint strategy config. Default: None. Raises: ValueError: If the prefix is invalid. TypeError: If the config is not CheckpointConfig type. """ def __init__(self, prefix='CKP', directory=None, config=None): super(ModelCheckpoint, self).__init__() self._latest_ckpt_file_name = "" self._init_time = time.time() self._last_time = time.time() self._last_time_for_keep = time.time() self._last_triggered_step = 0 if _check_file_name_prefix(prefix): self._prefix = prefix else: raise ValueError("Prefix {} for checkpoint file name invalid, " "please check and correct it and then continue.".format(prefix)) if directory: self._directory = _make_directory(directory) else: self._directory = _cur_dir if config is None: self._config = CheckpointConfig() else: if not isinstance(config, CheckpointConfig): raise TypeError("config should be CheckpointConfig type.") self._config = config # get existing checkpoint files self._manager = CheckpointManager() self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._graph_saved = False
[docs] def step_end(self, run_context): """ Save the checkpoint at the end of step. Args: run_context (RunContext): Context of the train running. """ cb_params = run_context.original_args() # save graph (only once) if not self._graph_saved: graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') _save_graph(cb_params.train_network, graph_file_name) self._graph_saved = True self._save_ckpt(cb_params, self._config.model_type)
[docs] def end(self, run_context): """ Save the last checkpoint after training finished. Args: run_context (RunContext): Context of the train running. """ cb_params = run_context.original_args() _to_save_last_ckpt = True self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt) from mindspore.parallel._cell_wrapper import destroy_allgather_cell destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save): """Check whether save checkpoint files or not.""" if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ or force_to_save is True: return True elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: self._cur_time = time.time() if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True: self._last_time = self._cur_time return True return False def _save_ckpt(self, cb_params, model_type, force_to_save=False): """Save checkpoint files.""" if cb_params.cur_step_num == self._last_triggered_step: return save_ckpt = self._check_save_ckpt(cb_params, force_to_save) step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 if save_ckpt: cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ + str(step_num_in_epoch) + ".ckpt" # update checkpoint file list. self._manager.update_ckpoint_filelist(self._directory, self._prefix) # keep checkpoint files number equal max number. if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: self._manager.remove_oldest_ckpoint_file() elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: self._cur_time_for_keep = time.time() if (self._cur_time_for_keep - self._last_time_for_keep) \ < self._config.keep_checkpoint_per_n_minutes * 60: self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._cur_time_for_keep) # generate the new checkpoint file and rename it. global _save_dir _save_dir = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt' gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num if context.get_context("enable_ge"): set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save) if os.path.exists(gen_file): shutil.move(gen_file, cur_file) self._latest_ckpt_file_name = cur_file @property def latest_ckpt_file_name(self): """Return the latest checkpoint path and file name.""" return self._latest_ckpt_file_name
class CheckpointManager: """Manage checkpoint files according to train_config of checkpoint.""" def __init__(self): self._ckpoint_filelist = [] @property def ckpoint_filelist(self): """Get all the related checkpoint files managed here.""" return self._ckpoint_filelist @property def ckpoint_num(self): """Get the number of the related checkpoint files managed here.""" return len(self._ckpoint_filelist) def update_ckpoint_filelist(self, directory, prefix): """Update the checkpoint file list.""" self._ckpoint_filelist = [] files = os.listdir(directory) for filename in files: if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix): mid_name = filename[len(prefix):-5] flag = True for char in mid_name: if char.isalpha(): flag = False if flag: self._ckpoint_filelist.append(directory + '/' + filename) def remove_ckpoint_file(self, file_name): """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" try: os.chmod(file_name, stat.S_IWRITE) os.remove(file_name) self._ckpoint_filelist.remove(file_name) except OSError: logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) except ValueError: logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) def remove_oldest_ckpoint_file(self): """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) self.remove_ckpoint_file(ckpoint_files[0]) def keep_one_ckpoint_per_minutes(self, minutes, cur_time): """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" movs = [] oldest_file = '' oldest_time = cur_time for ck_file in self._ckpoint_filelist: modify_time = os.path.getmtime(ck_file) if cur_time - modify_time < 60 * minutes: movs.append(ck_file) if modify_time < oldest_time: oldest_time = modify_time oldest_file = ck_file for mv_file in movs: if mv_file == oldest_file: continue self.remove_ckpoint_file(mv_file)