# Source code for mindspore.parallel.checkpoint_transform

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed checkpoint"""
from __future__ import absolute_import

import os
import glob
import copy
from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
    _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
    _merge_protobuf_strategy, _merge_json_strategy


# Public API of this module: the four checkpoint/strategy conversion entry points defined below.
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
           "transform_checkpoints"]


def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
    """
    Merge parallel strategy between all pipeline stages in pipeline parallel mode.
    For more details about converting distributed Checkpoint, please refer to
    `Distributed Resilience Training and Inference
    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Note:
        Strategy file of each pipeline stage should be included in src_strategy_dirs.

    Args:
        src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'
        dst_strategy_file (str): The file merged strategy to save. A bare file name (no directory part)
                                 is accepted and refers to the current working directory.

    Raises:
        NotADirectoryError: `src_strategy_dirs` is not a directory.
        ValueError: `src_strategy_dirs` contains both '.ckpt' and '.json' strategy files.

    Examples:
        >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
        >>> merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
    """
    # Validate the input before touching the file system, so that an invalid call
    # does not leave a freshly created destination directory behind.
    if not os.path.isdir(src_strategy_dirs):
        raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))

    # os.path.dirname returns "" for a bare file name; there is nothing to create in that
    # case, and os.path.exists("") is False, so the empty string must be skipped explicitly.
    dst_strategy_dir = os.path.dirname(dst_strategy_file)
    if dst_strategy_dir and not os.path.exists(dst_strategy_dir):
        _make_dir(dst_strategy_dir, "path")

    src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
    src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
    if src_strategy_files_protobuf and src_strategy_files_json:
        raise ValueError("The strategys format should be all '.ckpt' or all '.json'")

    # Any '.ckpt' file selects the protobuf merger; otherwise the json merger is used
    # (also when the directory holds no strategy file at all, as before).
    if src_strategy_files_protobuf:
        _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
    else:
        _merge_json_strategy(src_strategy_files_json, dst_strategy_file)
def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """
    List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id
    during the distributed checkpoint conversion. For more details about converting distributed Checkpoint,
    please refer to `Distributed Resilience Training and Inference
    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Args:
        rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'src_strategy_file' is None, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default:None.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'dst_strategy_file' is None, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default:None.

    Returns:
        List, the rank list required for converting the distributed checkpoint of rank_id, sorted in
        descending order. It is `[0]` when `src_strategy_file` is None.

    Raises:
        ValueError: `src_strategy_file` or dst_strategy_file is incorrect.
        TypeError: `src_strategy_file` or dst_strategy_file is not a string.
        TypeError: `rank_id` is not a int.

    Examples:
        >>> rank_id = 0
        >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        >>> checkpoint_files_map = {}
        >>> for rank in rank_list:
        ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
    """
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if src_strategy_file is None:
        # No source strategy: only the checkpoint of rank 0 is requested.
        return [0]

    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    # Reject an empty source layout BEFORE indexing its first entry; otherwise an empty
    # layout would surface as an IndexError instead of this ValueError.
    if not src_strategy_list:
        raise ValueError("The src_strategy_file is empty.")

    # The device count of one pipeline stage is the product of element 0 of the first layout entry.
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0])
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
        is not None else 1

    local_rank_id = rank_id % dst_stage_device_num if dst_stage_device_num > 1 else rank_id
    needed_rank_list_in_local_stage = _rank_list_for_transform_parallel_checkpoint(local_rank_id, src_strategy_list,
                                                                                   dst_strategy_list)

    result_set = set()
    handled_pipeline_stage = set()
    for layout in src_strategy_list.values():
        # layout[6] is iterated as the source pipeline stage ids of this entry.
        for src_pipeline_stage_id in layout[6]:
            if src_pipeline_stage_id in handled_pipeline_stage:
                continue
            handled_pipeline_stage.add(src_pipeline_stage_id)
            # Shift the stage-local ranks by the first global rank of that stage.
            src_rank_id_start = src_pipeline_stage_id * src_stage_device_num
            result_set.update(src_rank_id_start + rank for rank in needed_rank_list_in_local_stage)
    return sorted(result_set, reverse=True)
def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                 src_strategy_file=None, dst_strategy_file=None):
    """
    Convert the distributed checkpoint of one rank from the source sharding strategy to the
    destination sharding strategy and save the result. For more details about converting distributed
    Checkpoint, please refer to `Distributed Resilience Training and Inference
    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Args:
        rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
        checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is
                                     the checkpoint file name.
        save_checkpoint_file_name (str): The file name to save the converted checkpoint.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'src_strategy_file' is None, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default:None.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'dst_strategy_file' is None, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default:None.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        ValueError: item in `checkpoint_files_map` is incorrect.
        ValueError: `save_checkpoint_file_name` is not end with ".ckpt".
        TypeError: `checkpoint_files_map` is not a dict.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not a int.
        TypeError: `save_checkpoint_file_name` is not a string.

    Examples:
        >>> dst_device_num = 8
        >>> for rank_id in range(dst_device_num):
        ...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        ...     checkpoint_files_map = {}
        ...     for rank in rank_list:
        ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
        ...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
        ...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
        ...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
    """
    # Argument validation; nothing is loaded or written until every check has passed.
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("The checkpoint_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("The save_checkpoint_file_name should be a str.")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
    if dst_strategy_file:
        dst_strategy_dir = os.path.dirname(dst_strategy_file)
        # A bare file name has no directory part, so there is nothing to check for it.
        if dst_strategy_dir and not os.path.exists(dst_strategy_dir):
            raise ValueError("The director of dst_strategy_file: {} is not exists.".format(dst_strategy_dir))
    for ckpt_rank, ckpt_path in checkpoint_files_map.items():
        if not os.path.exists(ckpt_path):
            raise ValueError("Checkpoint file {} in rank {} not exits: ".format(ckpt_path, ckpt_rank))

    src_layouts, dst_layouts = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)

    # Device count of one pipeline stage: product of element 0 of the first layout entry,
    # or 1 when no layout map is available for that side.
    src_stage_device_num = 1
    if src_layouts is not None:
        first_src_name = list(src_layouts.keys())[0]
        src_stage_device_num = np.prod(src_layouts.get(first_src_name)[0])
    dst_stage_device_num = 1
    if dst_layouts is not None:
        first_dst_name = list(dst_layouts.keys())[0]
        dst_stage_device_num = np.prod(dst_layouts.get(first_dst_name)[0])

    full_dst_layouts = _extract_layout_map(dst_strategy_file)
    full_src_layouts = _extract_layout_map(src_strategy_file)

    # parameter name -> {stage-local source rank -> value}
    gathered_values = defaultdict(dict)
    gathered_attrs = defaultdict(dict)
    for ckpt_rank, ckpt_path in checkpoint_files_map.items():
        loaded_params = ms.load_checkpoint(ckpt_path)
        for name, parameter in loaded_params.items():
            # Skip a parameter only when it is reported as outside the local stage
            # for both the source and the destination strategy.
            outside_src_stage = _parameter_not_in_local_stage(name, full_src_layouts, src_layouts)
            if outside_src_stage and _parameter_not_in_local_stage(name, full_dst_layouts, dst_layouts):
                continue
            # src rank => local rank inside pipeline stage
            stage_local_rank = ckpt_rank % src_stage_device_num
            gathered_values[name][stage_local_rank] = parameter.data.asnumpy()
            gathered_attrs[name][stage_local_rank] = (parameter.requires_grad, parameter.layerwise_parallel)

    dst_local_rank = rank_id % dst_stage_device_num
    converted_params = _transform_parallel_checkpoint(dst_local_rank, gathered_values, gathered_attrs,
                                                      src_layouts, dst_layouts)
    ms.save_checkpoint(converted_params, save_checkpoint_file_name)
def _collect_rank_checkpoint_files(src_checkpoints_dir):
    """Map each rank id found under `src_checkpoints_dir` to the checkpoint file selected for it.

    Sub-directories matching "rank_[0-9]*" are scanned. Entries that are not directories, or whose
    suffix after the last "rank_" is not purely digits, are skipped with a warning. Inside a rank
    directory the "*.ckpt" files are visited in sorted order and the last regular file is kept.

    Args:
        src_checkpoints_dir (str): The source checkpoints directory.

    Returns:
        dict, rank id (int) -> checkpoint file path (str).
    """
    rank_checkpoint_files = {}
    for checkpoint_dir in glob.glob(os.path.join(src_checkpoints_dir, "rank_[0-9]*")):
        if not os.path.isdir(checkpoint_dir):
            ms.log.warning("{} is not a directory.".format(checkpoint_dir))
            continue
        rank_id_str = checkpoint_dir.split('rank_')[-1]
        if not rank_id_str.isdigit():
            ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....".
                           format(checkpoint_dir))
            continue
        rank_id = int(rank_id_str)
        for checkpoint_file in sorted(glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))):
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            # Later files overwrite earlier ones, so the last file in sorted order wins.
            rank_checkpoint_files[rank_id] = checkpoint_file
    return rank_checkpoint_files


def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                          dst_strategy_file=None):
    """
    Transform distributed checkpoint from source sharding strategy to destination sharding strategy for
    every rank of the destination strategy. For more details about converting distributed Checkpoint,
    please refer to `Distributed Resilience Training and Inference
    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Note:
        The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt",
        the rank number should be set to a subdirectory and the checkpoint file is stored in this subdirectory.
        If multiple files exist in a rank directory, the last file in the lexicographic order would be selected.
        Each converted checkpoint is saved as "dst_checkpoints_dir/rank_{i}/{ckpt_prefix}{i}.ckpt".

    Args:
        src_checkpoints_dir (str): The source checkpoints directory.
        dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
        ckpt_prefix (str): The destination checkpoint name prefix.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'src_strategy_file' is None, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default:None.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the 'dst_strategy_file' is None, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default:None.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory.
        ValueError: The checkpoint file is missing in `src_checkpoints_dir`.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `ckpt_prefix` is not a string.

    Examples:
        >>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
        ...                       "./src_strategy.ckpt", "./dst_strategy.ckpt")
    """
    if not os.path.isdir(src_checkpoints_dir):
        raise NotADirectoryError("src_checkpoints_dir {} is not a directory.".format(src_checkpoints_dir))
    # Check the remaining arguments before _make_dir, so that a rejected call
    # does not leave a newly created destination directory behind.
    if not isinstance(ckpt_prefix, str):
        raise TypeError("The ckpt_prefix should be a str.")
    _make_dir(dst_checkpoints_dir, "path")

    all_checkpoint_files_map = _collect_rank_checkpoint_files(src_checkpoints_dir)

    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file)
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
    dst_device_num = dst_stage_device_num * dst_stage_num
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)

    # Group the destination ranks by the exact list of source ranks they need, so that each
    # group of source checkpoints is loaded only once.
    needed_rank_list_map = defaultdict(list)
    for rank in range(dst_device_num):
        needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file)
        for needed_rank in needed_rank_list:
            if needed_rank not in all_checkpoint_files_map:
                raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, "
                                 "but it is missing.".format(needed_rank, rank))
        needed_rank_list_map[tuple(needed_rank_list)].append(rank)

    for needed_rank_list, transform_rank_list in needed_rank_list_map.items():
        # parameter name -> {stage-local source rank -> value}
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        for needed_rank in needed_rank_list:
            ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
            for param_name, param in ckpt_dict.items():
                src_rank = int(needed_rank) % src_stage_device_num
                param_total_dict[param_name][src_rank] = param.data.asnumpy()
                param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)

        for transform_rank in transform_rank_list:
            # Every destination rank works on its own deep copy of the loaded values.
            param_total_dict_copy = copy.deepcopy(param_total_dict)
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file,
                                                                               dst_strategy_file)
            # cut the parameter not in the pipeline stage.
            for param in list(param_total_dict_copy.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_copy.pop(param)

            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
                                                                  param_attr_dict, src_strategy_list,
                                                                  dst_strategy_list)
            save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
            save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_checkpoint_file_dir):
                _make_dir(save_checkpoint_file_dir, "path")
            save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
            ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
            # Drop the per-rank copy before the next one is made to keep peak memory down.
            del param_total_dict_copy
        del param_total_dict