mindformers.core.callback.callback 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindFormer Self-Define Callback."""
import json
import os
import re
import glob
import sys
import time
import tempfile
import hashlib
import subprocess
import shlex

from collections import OrderedDict, defaultdict
from copy import deepcopy
from datetime import datetime, timedelta
from typing import Callable, Optional, Union, Dict, Tuple, List
import numpy as np

import mindspore as ms
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore._checkparam import args_type_check
from mindspore import (
    Callback,
    ModelCheckpoint,
    CheckpointConfig,
    context,
    Tensor,
    get_auto_parallel_context,
    set_auto_parallel_context
)
from mindspore.train.callback import SummaryCollector
from mindspore.train.callback._checkpoint import CheckpointManager
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from mindspore.train.serialization import _get_merged_param_data
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import Broadcast
from mindspore.common import jit
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
from mindspore.common.api import flops_collection
from mindspore.communication.management import create_group, get_group_size, get_rank, GlobalComm
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.comm_func import all_gather_into_tensor, barrier
from mindspore.profiler import ProfilerLevel, schedule
from mindspore.utils import stress_detect
from mindspore.mint.distributed import all_to_all_single

from mindformers.wrapper.wrapper import get_real_models
from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor
from mindformers.core.context.build_context import get_context, is_legacy_model
from mindformers.tools import get_output_root_path
from mindformers.tools.logger import logger
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.utils.parallel_utils import barrier_world
from mindformers.tools.utils import (
    get_output_subpath,
    get_real_rank,
    get_real_group_size,
    get_real_local_rank,
    get_pipeline_rank_ids,
    is_last_pipeline_stage,
    get_ascend_log_path,
    set_safe_mode_for_file_or_dir
)
from mindformers.utils.parameter_register import parameter_register
from mindformers.utils.tensorboard import get_tensorboard_writer, get_tensorboard_args
from mindformers.version_control import is_version_ge, check_arf_status, check_tft_valid
from mindformers.parallel_core.training_graph.loss_func import (
    get_device_local_loss,
    reset_device_local_loss,
    check_device_local_loss
)
from mindformers.checkpoint.checkpoint import AsyncSaveManager, CommonInfo, save_checkpoint
from mindformers.checkpoint.utils import CkptHealthStatus
from mindformers.parallel_core.transformer_config import TransformerConfig
from mindformers.models.utils import num_floating_point_operations, convert_transformer_config_to_args_for_tflops

# pylint: disable=import-outside-toplevel
# Names exported by `from mindformers.core.callback.callback import *`.
__all__ = ['MFLossMonitor', 'CheckpointMonitor', 'SummaryMonitor', 'ProfileMonitor', 'EvalCallBack']

# Working directory captured once at import time; SAVE_DIR starts out as the same path.
_cur_dir = os.getcwd()
SAVE_DIR = _cur_dir

# NOTE(review): presumably the device error code reported for a voltage-related fault — confirm against its users.
VOLTAGE_ERROR_CODE = 574007



class ExpertParallelManager:
    """
    Expert Parallel Manager for managing expert parallel communication.

    This class handles expert-to-rank mapping and communication pattern calculation
    for expert parallel training in mixture of experts (MoE) models.
    """

    @args_type_check(rank_to_expert=Optional[Dict[int, List[int]]])
    def __init__(self,
                 ep_group: List[int],
                 current_rank: int,
                 expert_nums: int,
                 rank_to_expert: Optional[Dict[int, List[int]]] = None
                 ):
        """
        Initialize Expert Parallel Manager

        Args:
            ep_group (List[int]): List of global ranks in the expert parallel group
            current_rank (int): Global rank of the current process
            expert_nums (int): Total number of global experts
            rank_to_expert (Dict[int, List[int]], optional): Initial mapping from rank to expert list. Default: ``None``

        Raises:
            ValueError: If ``current_rank`` is not a member of ``ep_group``.
        """
        # Validate membership before anything else: ``list.index`` would otherwise raise its own
        # generic ValueError (and an empty group would raise ZeroDivisionError) before this
        # descriptive error could ever be reached.
        if current_rank not in ep_group:
            raise ValueError(f"Current rank {current_rank} not in ep group: {ep_group}")

        self.ep_group = ep_group
        self.current_rank = current_rank
        self.expert_nums = expert_nums
        self.rank_to_expert = rank_to_expert or {}
        # Size of the expert parallel group.
        self.ep = len(ep_group)
        # Number of experts hosted by each member of the group (integer division).
        self.local_expert_num = expert_nums // self.ep
        # Position of the current process inside the expert parallel group.
        self.current_ep_rank = ep_group.index(current_rank)

    def update_communication_info(self, q_mapping: Dict[int, List[int]]):
        """
        Update communication information and calculate send/receive data volume

        Args:
            q_mapping (Dict[int, List[int]]): New mapping dictionary from rank to expert list.
                Entries are looked up by position inside ``ep_group`` (``0 .. ep - 1``); a missing
                position is treated as an empty expert list.

        Returns:
            Tuple[List[int], List[int]]: ``(send_size_list, recv_size_list)``. ``recv_size_list[i]`` is the
            number of experts in the current position's new list that are hosted at position ``i``;
            ``send_size_list[i]`` is the number of experts hosted at the current position that appear in
            position ``i``'s new list.
        """
        # recv_count_map[dst][src]: how many experts position ``dst`` pulls from position ``src``.
        recv_count_map = []
        for ep_rank in range(self.ep):
            recv_expert_list = q_mapping.get(ep_rank, [])
            recv_count_list = [0] * self.ep

            for recv_expert_id in recv_expert_list:
                # Expert ids are laid out contiguously, ``local_expert_num`` per position.
                target_ep_rank = recv_expert_id // self.local_expert_num
                recv_count_list[target_ep_rank] += 1

            recv_count_map.append(recv_count_list)

        recv_size_list = recv_count_map[self.current_ep_rank]

        # The transpose turns "what each position receives" into "what each position sends".
        send_count_map = np.array(recv_count_map).T
        send_size_list = send_count_map[self.current_ep_rank].tolist()

        if not self.rank_to_expert:
            # First update: the incoming mapping is already expressed in original expert ids.
            self.rank_to_expert = q_mapping
        else:
            # Later updates address experts by their slot in the current layout, so each id is
            # translated back through the existing mapping to keep tracking original expert ids.
            new_rank_to_expert = {}
            for rank, expert_list in q_mapping.items():
                new_expert_list = []
                for expert_id in expert_list:
                    ep_rank = expert_id // self.local_expert_num
                    expert_idx_in_ep = expert_id % self.local_expert_num

                    original_expert_id = self.rank_to_expert[ep_rank][expert_idx_in_ep]
                    new_expert_list.append(original_expert_id)

                new_rank_to_expert[rank] = new_expert_list
            self.rank_to_expert = new_rank_to_expert

        return send_size_list, recv_size_list

    def restore_communication_info(self):
        """
        Reset the tracked rank-to-expert mapping to an empty dict.

        After this call the next ``update_communication_info`` stores its mapping as-is.
        """
        self.rank_to_expert = {}


class AllReduceNet(Cell):
    """
    Cell that applies a SUM all-reduce over a communication group.

    Used to accumulate flops in pipeline parallel.
    """

    def __init__(self, group_name):
        """Build the SUM all-reduce operator bound to ``group_name``."""
        super().__init__()
        sum_reduce = P.AllReduce(op=P.ReduceOp.SUM, group=group_name)
        self.allreduce_sum = sum_reduce
        self.add_flags(skip_auto_parallel_compile=True)

    def construct(self, x):
        """Return ``x`` summed across the group."""
        reduced = self.allreduce_sum(x)
        return reduced


def _check_mspti_is_on():
    """Report whether ``libmspti.so`` appears in the ``LD_PRELOAD`` environment variable."""
    preload_value = os.getenv("LD_PRELOAD")
    # An unset variable yields None, which means mspti cannot be preloaded.
    if not isinstance(preload_value, str):
        return False
    return "libmspti.so" in preload_value


def _get_separate_loss():
    """Fetch the separately registered loss values as numpy data, then clear them from the register.

    Returns:
        tuple: ``(lm_loss, aux_loss, mtp_loss, indexer_loss)``.
    """
    loss_names = ("aux_loss", "mtp_loss", "lm_loss", "indexer_loss")
    # Read every entry first; the register is only cleared once all reads have succeeded.
    fetched = {name: parameter_register.get(name).asnumpy() for name in loss_names}
    for name in loss_names:
        parameter_register.clear(name)
    return fetched["lm_loss"], fetched["aux_loss"], fetched["mtp_loss"], fetched["indexer_loss"]


def _log_grouped_lr_info():
    """Log the current default learning rate and the learning rate of every grouped parameter set.

    Nothing is logged when no grouped parameters are registered.
    """
    from mindformers.trainer.optimizer_grouped_parameters import GROUPED_PARAMS

    if GROUPED_PARAMS:
        # Default learning rate, as currently stored in the parameter registry.
        default_lr_param = parameter_register.get("current_default_lr")
        default_lr = default_lr_param.asnumpy()
        logger.info(f"default_lr: {default_lr:.6e}, equal to `lr:` above.")

        # Per-group learning rates, indexed by the group's position in GROUPED_PARAMS.
        grouped_lr_param = parameter_register.get("current_grouped_lr")
        grouped_lr = grouped_lr_param.asnumpy()
        for group_id, params in enumerate(GROUPED_PARAMS):
            logger.info(f"group_{group_id}_lr: {grouped_lr[group_id]:.6e}, group_{group_id}_params: {params}")


def _get_loss_output(output):
    """Get output of task for MFLossMonitor.

    Args:
        output: Raw step output. Either a single value, or a tuple/list. A sequence of length
            3, 4, 5 or 7 is unpacked as ``(loss, overflow, scaling_sens, *extras)`` where the extras
            are ``(learning_rate,)``, ``(learning_rate, global_norm)`` or
            ``(learning_rate, global_norm, local_norm, norm_size)``. For any other length the first
            element is used as the loss when it is a Tensor; otherwise the sequence itself is
            returned as the loss.

    Returns:
        tuple: ``(loss, overflow, scaling_sens, learning_rate, global_norm)``. Tensor values are
        converted to numpy; a Tensor loss is reduced to its mean.
    """
    overflow = False
    scaling_sens = False
    loss = output
    learning_rate = None
    global_norm = None

    if isinstance(output, (tuple, list)):
        if len(output) in (3, 4, 5, 7):
            loss, overflow, scaling_sens, *res = output
            if len(res) == 4:
                learning_rate, global_norm, local_norm, norm_size = res
                logger.info(f" norm_size: {norm_size}\nlocal_norm:\n{local_norm}")
            elif len(res) == 2:
                learning_rate, global_norm = res
            elif len(res) == 1:
                learning_rate = res[0]
            if isinstance(scaling_sens, ms.Tensor):
                scaling_sens = scaling_sens.asnumpy()
        elif output and isinstance(output[0], ms.Tensor):
            # ``output and ...`` guards the empty sequence, which previously raised IndexError.
            loss = output[0]

    if isinstance(global_norm, ms.Tensor):
        global_norm = global_norm.asnumpy()

    if isinstance(loss, ms.Tensor):
        # Convert once: every ``asnumpy`` call copies the tensor to host, and its result is
        # always a numpy array, so no further type check is needed.
        loss = np.mean(loss.asnumpy())

    if isinstance(overflow, ms.Tensor):
        overflow = overflow.asnumpy()

    if isinstance(learning_rate, ms.Tensor):
        learning_rate = learning_rate.asnumpy()

    return loss, overflow, scaling_sens, learning_rate, global_norm


def _get_weight_norm(network):
    """Compute the L2 norm over all trainable parameters of ``network``.

    Returns:
        float: Norm of the stacked per-parameter norms, or ``0.0`` when there is no trainable parameter.
    """
    per_param_norms = [weight.to(ms.float32).norm() for weight in network.trainable_params()]
    if len(per_param_norms) == 0:
        return 0.0
    # The norm of the vector of per-parameter norms equals the norm over all parameters together.
    stacked = F.stack(per_param_norms)
    return float(stacked.norm().item())


def _get_max_eigenvalue(input_tensor, num_iter):
    """
    Calculate max eigenvalue of ``input_tensor^T @ input_tensor`` by repeated multiply-and-normalize
    iterations starting from a random unit vector.
    https://www.cnblogs.com/zxn-share/p/17392450.html

    Args:
        input_tensor: Tensor of shape (m, n) or (b, m, n).
        num_iter (int): Number of iterations to run; must be at least 1.

    Returns:
        The eigenvalue estimate from the last iteration as a Tensor, or the float ``0.0`` when no
        random start vector with a non-zero norm could be drawn.

    Raises:
        ValueError: If ``num_iter`` is smaller than 1.
    """
    # Without at least one iteration there is no estimate to return; fail with a clear message
    # instead of an UnboundLocalError at the return statement.
    if num_iter < 1:
        raise ValueError(f"num_iter must be at least 1, but got {num_iter}")

    input_tensor = input_tensor.astype(ms.float32)  # (m,n) or (b,m,n)
    in_features = input_tensor.shape[-1]  # (n)
    u_tensor = None
    u_norm = None
    # Retry a few times in case the random start vector happens to have zero norm.
    for _ in range(5):
        u_tensor = ms.ops.randn(in_features)  # (n)
        u_norm = u_tensor.norm()
        if u_norm.asnumpy() > 0:
            break
    else:
        logger.warning("Calculate max eigenvalue: the norm of a randomly generated vector is 0")
        return 0.0
    u_tensor = u_tensor / u_norm  # (n), reuses the norm computed in the loop above
    input_seq = ms.ops.matmul(input_tensor.transpose(-2, -1), input_tensor)  # (n,n) or (b,n,n)
    if input_tensor.ndim == 2:
        input_seq = ms.ops.unsqueeze(input_seq, 0)  # (1,n,n)
    u_tensor = ms.ops.unsqueeze(u_tensor, 1)  # (n,1)
    eigenvalue = None
    for _ in range(num_iter):
        v_tensor = ms.ops.matmul(input_seq, u_tensor)  # (b,n,n) * (n,1) = (b,n,1)
        eigenvalue = ms.ops.matmul(v_tensor.transpose(-2, -1), u_tensor).squeeze()  # (b,1,n) * (b,n,1) = b
        v_norm = v_tensor.norm(dim=1, keepdim=True)  # (b,1,1)
        # Replace zero norms by one so the normalization below never divides by zero.
        v_norm_safe = ms.ops.select(v_norm == 0, ms.ops.ones_like(v_norm), v_norm)
        u_tensor = v_tensor / v_norm_safe
    return eigenvalue


def _get_stable_rank(weight, num_iter):
    """Compute the stable rank of ``weight`` together with its max eigenvalue estimate.

    Returns:
        tuple: ``(stable_rank, max_eigenvalue)`` as numpy values, or ``(0.0, 0.0)`` when the
        eigenvalue computation fails or yields the float ``0.0``.
    """
    # pylint: disable=W0718
    try:
        max_eig = _get_max_eigenvalue(weight, num_iter)
    except Exception as e:
        logger.warning(f"{weight.name} calculate max eigenvalue failed: {e}")
        return 0.0, 0.0

    # A plain float is only returned for the degenerate zero case; short-circuit it here.
    eig_is_zero_float = isinstance(max_eig, float) and np.isclose(max_eig, 0.0, atol=0.0, rtol=0.0)
    if eig_is_zero_float:
        return 0.0, 0.0

    frobenius_norm = ms.ops.norm(weight, ord='fro', dim=(-2, -1))
    f_norm_square = ms.ops.square(frobenius_norm)
    ratio = f_norm_square / max_eig
    # Entries whose eigenvalue is zero get a stable rank of zero instead of a division artefact.
    stable_rank = ms.ops.select(max_eig != 0, ratio, ms.ops.zeros_like(max_eig))
    return stable_rank.asnumpy(), max_eig.asnumpy()


def _get_optimizer_state(optim_params, filter_fn: Callable = None):
    """Collect the L2 norm of each selected optimizer parameter.

    Args:
        optim_params: Iterable of optimizer parameters exposing ``name``.
        filter_fn (Callable, optional): Predicate on the parameter name; when ``None`` every
            parameter is selected.

    Returns:
        dict: Parameter name mapped to its L2 norm as a float.
    """
    def _is_selected(param_name):
        return filter_fn is None or filter_fn(param_name)

    return {
        state.name: float(state.to(ms.float32).norm().item())
        for state in optim_params
        if _is_selected(state.name)
    }


def _is_positive_natural_number(x):
    """Check if it is a positive natural number"""
    return isinstance(x, int) and x > 0


@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class MFLossMonitor(Callback):
    """
    Monitor loss and other parameters in training process.

    Args:
        learning_rate (Union[float, LearningRateSchedule], optional): The learning rate schedule.
            Default: ``None``.
        per_print_times (int, optional): Every how many steps to print the log information.
            Default: ``1``.
        micro_batch_num (int, optional): MicroBatch size for Pipeline Parallel. Default: ``1``.
        micro_batch_interleave_num (int, optional): split num of batch size. Default: ``1``.
        origin_epochs (int, optional): Training epochs. Default: ``None``.
        dataset_size (int, optional): Training dataset size. Default: ``None``.
        initial_epoch (int, optional): The beginning epoch. Default: ``0``.
        initial_step (int, optional): The beginning step. Default: ``0``.
        global_batch_size (int, optional): The total batch size. Default: ``0``.
        gradient_accumulation_steps (int, optional): The gradient accumulation steps. Default: ``1``.
        check_for_nan_in_loss_and_grad (bool, optional): Whether to check loss and norm of grad is Nan.
            Default: ``False``.
        calculate_per_token_loss (bool, optional): Whether to calculate the loss of each token.
            Default: ``False``.
        print_separate_loss (bool, optional): Whether to print loss separately. Default: ``False``.
            It is switched off again when the legacy model path is used, or when the model is
            neither a MoE nor an MTP model.
        kwargs: Extra flags read by this callback: ``is_moe_model`` and ``is_mtp_model``
            (both default to ``False``).

    Examples:
        >>> from mindformers.core import MFLossMonitor
        >>> lr = [0.01, 0.008, 0.006, 0.005, 0.002]
        >>> monitor = MFLossMonitor(learning_rate=lr, per_print_times=10)
    """

    def __init__(self,
                 learning_rate: Optional[Union[float, LearningRateSchedule]] = None,
                 per_print_times: int = 1,
                 micro_batch_num: int = 1,
                 micro_batch_interleave_num: int = 1,
                 origin_epochs: int = None,
                 dataset_size: int = None,
                 initial_epoch: int = 0,
                 initial_step: int = 0,
                 global_batch_size: int = 0,
                 gradient_accumulation_steps: int = 1,
                 check_for_nan_in_loss_and_grad: bool = False,
                 calculate_per_token_loss: bool = False,
                 print_separate_loss: bool = False,
                 **kwargs):
        super().__init__()
        self.per_print_times = per_print_times
        # Copy so that later reassignment in on_train_step_end never touches the caller's object.
        self.learning_rate = deepcopy(learning_rate)
        self.last_print_time = 0
        # (sic) attribute name kept as-is; it holds the micro batch number.
        self.mirco_size = micro_batch_num
        self.print_warning_flag = True
        self.loss_list = []
        self.step_time = time.time()
        self.epoch_time = time.time()
        self.run_context = None
        self.steps_per_epoch = dataset_size
        self.micro_batch_interleave_num = micro_batch_interleave_num
        self.origin_epochs = origin_epochs
        self.initial_epoch = initial_epoch
        self.initial_step = initial_step
        self.global_batch_size = global_batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.device_num = get_real_group_size()
        # mf_support: None = not probed yet, True/False = result of the legacy FLOPs probe.
        self.mf_support = None
        self.mf_calculated = False
        self.current_phase = None
        self.full_model_flops = 0.0
        self.tensor_writer = get_tensorboard_writer()
        self.tensorboard = get_tensorboard_args()
        self.check_for_nan_in_loss_and_grad = check_for_nan_in_loss_and_grad
        self.calculate_per_token_loss = calculate_per_token_loss
        self.mstx_range_id = None
        self.mstx_enabled = _check_mspti_is_on()
        self.print_separate_loss = print_separate_loss
        self.is_moe_model = kwargs.get("is_moe_model", False)
        self.is_mtp_model = kwargs.get("is_mtp_model", False)
        if self.print_separate_loss and is_legacy_model():
            logger.warning("print_separate_loss = True, is not supported when use_legacy = True.")
            self.print_separate_loss = False

        # Separate losses are only reported for MoE and/or MTP models.
        if self.print_separate_loss and not self.is_moe_model and not self.is_mtp_model:
            self.print_separate_loss = False

    def on_train_epoch_begin(self, run_context):
        """
        Record time at the beginning of epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.loss_list = []
        self.epoch_time = time.time()
        self.run_context = run_context

    def on_train_step_begin(self, run_context):
        """
        Record time at the beginning of step.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.step_time = time.time()
        self.run_context = run_context
        if self.mstx_enabled:
            # Open a profiler range for this step; it is closed in on_train_step_end.
            cb_params = run_context.original_args()
            step_num = cb_params.cur_step_num
            self.mstx_range_id = ms.profiler.mstx.range_start(f'step {step_num}', ms.runtime.current_stream())

    def on_train_step_end(self, run_context):
        """
        Print training info at the end of step.

        While this method runs under (semi-)auto parallel mode, the global auto-parallel context
        is temporarily switched to ``data_parallel`` with ``full_batch=False`` and is always
        restored on exit, including when an exception is raised.

        Args:
            run_context (RunContext): Context of the process running.
        """
        parallel_mode = get_auto_parallel_context("parallel_mode")
        full_batch = get_auto_parallel_context("full_batch")
        auto_parallel = parallel_mode in ['semi_auto_parallel', 'auto_parallel']
        if auto_parallel:
            set_auto_parallel_context(parallel_mode='data_parallel', full_batch=False)
        try:
            cb_params = run_context.original_args()
            # Despite its name, step_seconds is measured in milliseconds.
            step_seconds = (time.time() - self.step_time) * 1000
            if self.mstx_enabled:
                ms.profiler.mstx.range_end(self.mstx_range_id)
            net_outputs = cb_params.net_outputs
            loss, overflow, scaling_sens, learning_rate, global_norm = _get_loss_output(net_outputs)
            if learning_rate is not None:
                self.learning_rate = learning_rate
            loss = self._fix_loss_for_parallel(loss)
            self.loss_list.append(loss)

            lm_loss, aux_loss, mtp_loss, indexer_loss = None, None, None, None
            if self.print_separate_loss:
                lm_loss, aux_loss, mtp_loss, indexer_loss = _get_separate_loss()
                lm_loss = self._fix_loss_for_parallel(lm_loss, print_warning=False)
                aux_loss = self._fix_loss_for_parallel(aux_loss, print_warning=False)
                mtp_loss = self._fix_loss_for_parallel(mtp_loss, print_warning=False)
                indexer_loss = self._fix_loss_for_parallel(indexer_loss, print_warning=False)

            if not overflow:
                overflow = "False"
            if not scaling_sens:
                scaling_sens = "unavailable"

            # Choose calculation method based on legacy model status
            if is_legacy_model():
                self._calculate_flops_legacy(cb_params)
            else:
                self._calculate_flops_mcore(cb_params)
                # NOTE(review): kept on the mcore branch because it relies on
                # get_gpt_transformer_config(), the same API _calculate_flops_mcore uses -
                # confirm against the upstream layout.
                self._reset_accu_gbs_fi(cb_params)

            origin_epochs = self.origin_epochs
            if cb_params.get('initial_step', None) is not None:
                self.initial_step = cb_params.initial_step
            if cb_params.dataset_sink_mode:
                per_step_seconds = step_seconds / cb_params.batch_num
                steps_per_epoch = self.steps_per_epoch
                cur_epoch_num = (cb_params.cur_step_num + self.initial_step - 1) // steps_per_epoch + 1
                cur_step_num = (cb_params.cur_step_num + self.initial_step - 1) % steps_per_epoch + 1
            else:
                per_step_seconds = step_seconds
                steps_per_epoch = cb_params.batch_num
                cur_epoch_num = cb_params.cur_epoch_num
                cur_step_num = (cb_params.cur_step_num + self.initial_step - 1) % steps_per_epoch + 1

            # compute time remaining
            step_remain = (origin_epochs - cur_epoch_num + 1) * steps_per_epoch - cur_step_num
            time_remain = step_remain * per_step_seconds / 1000
            # compute throughput
            throughput = self.global_batch_size / self.device_num / (per_step_seconds / 1000)
            # compute percent
            percent = ((cur_epoch_num - 1) * steps_per_epoch + cur_step_num) / origin_epochs / steps_per_epoch * 100

            # step_diff <= 0 means the step counter went backwards (or did not advance), so print anyway.
            step_diff = cb_params.cur_step_num - self.last_print_time
            if step_diff >= self.per_print_times or step_diff <= 0:
                self.last_print_time = cb_params.cur_step_num
                self.print_output_info(cb_params, cur_epoch_num, origin_epochs, throughput,
                                       cur_step_num, steps_per_epoch, loss, per_step_seconds,
                                       overflow, scaling_sens, time_remain, percent, global_norm,
                                       lm_loss, aux_loss, mtp_loss, indexer_loss)
        finally:
            # Always put the global parallel context back, even if logging raised,
            # so a caught failure cannot leave the process in data_parallel mode.
            if auto_parallel:
                set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)

    def _reset_accu_gbs_fi(self, cb_params):
        """Reset the accumulated FI tensor of each MoE router back to zeros.

        Only acts when the transformer config uses the ``gbs_aux_loss`` load balancing type.

        Raises:
            NotImplementedError: If the model does not provide ``reset_accu_gbs_fi``.
        """
        network = get_real_models(cb_params.train_network)
        transformer_config = network.get_gpt_transformer_config()
        if transformer_config.moe_router_load_balancing_type == "gbs_aux_loss":
            model = network.get_gpt_model()
            if hasattr(model, "reset_accu_gbs_fi"):
                model.reset_accu_gbs_fi()
            else:
                raise NotImplementedError(f"network: {network} does not Implemented function 'reset_accu_gbs_fi'")

    def _fix_loss_for_parallel(self, loss, print_warning=True):
        """Fix loss value in pipeline or double parallel mode.

        Args:
            loss: The loss value to rescale.
            print_warning (bool): Whether the informational warnings may be logged.

        Returns:
            The loss divided by the micro batch number (pipeline parallel), by
            ``micro_batch_interleave_num`` and by ``gradient_accumulation_steps`` where applicable.
            The first and last divisions are skipped when ``calculate_per_token_loss`` is set.
        """
        pipeline_stages = ms.context.get_auto_parallel_context("pipeline_stages")
        self.is_zbv = ms.get_auto_parallel_context("pipeline_scheduler") == "zero_bubble_v"
        if self.is_zbv and self.print_warning_flag and print_warning:
            logger.warning("When zero_bubble_v is enabled, loss is valid only on rank 0")
        else:
            if pipeline_stages > 1 and self.print_warning_flag and print_warning:
                logger.warning("pipeline stages: %s > 1, the loss on the last card is valid.",
                               pipeline_stages)

        # NOTE(review): this warning is treated as independent of the zero_bubble_v branch above -
        # confirm against the upstream layout (it only affects whether a warning is logged).
        if self.micro_batch_interleave_num > 1 and self.print_warning_flag and print_warning:
            logger.warning("micro_batch_interleave_num: %s > 1, multiple copies in parallel is open.",
                           self.micro_batch_interleave_num)

        if pipeline_stages > 1 and not self.calculate_per_token_loss:
            loss = loss / self.mirco_size
        if self.micro_batch_interleave_num > 1:
            loss = loss / self.micro_batch_interleave_num
        if self.gradient_accumulation_steps > 1 and not self.calculate_per_token_loss:
            loss = loss / self.gradient_accumulation_steps

        return loss

    @staticmethod
    def _get_pipeline_group():
        """
        Calculate the communication group between all pipeline stages

        Returns:
            tuple: ``(rank_list, rank_list_str)`` - the ranks that share this rank's position
            within each stage, and the same ranks joined with ``-``.
        """
        rank = get_rank()
        stage_nums = auto_parallel_context().get_pipeline_stages()
        device_nums = get_group_size()
        per_stage_device_nums = device_nums // stage_nums
        local_stage_rank_id = rank % per_stage_device_nums
        group = range(0, stage_nums)
        rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group]
        rank_str_list = [str(r) for r in rank_list]
        rank_list_str = "-".join(rank_str_list)
        return rank_list, rank_list_str

    def _calculate_flops_legacy(self, cb_params):
        """Calculate FLOPs for legacy models using runtime collection API."""
        if self.mf_support is None:
            # Probe only once whether FLOPs collection is possible at all.
            self.mf_support = self._can_calculate_model_flops(cb_params)
        if (not self.mf_calculated or check_arf_status(cb_params)) and self.mf_support:
            self._calculate_model_flops()

    def _calculate_flops_mcore(self, cb_params):
        """Calculate FLOPs for mcore models using configuration-based method."""
        if not self.mf_calculated or check_arf_status(cb_params):
            network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network
            real_network = get_real_models(network)
            config = real_network.get_gpt_transformer_config()
            batch_size = self.global_batch_size
            args = convert_transformer_config_to_args_for_tflops(config)
            self.full_model_flops = num_floating_point_operations(args, batch_size)
            self.mf_calculated = True
            logger.info("Full model flops calculated: %d", self.full_model_flops)

    def _can_calculate_model_flops(self, cb_params):
        """
        Check whether the model flops can be collected for legacy models.

        On success the network's ``current_phase`` is stored in ``self.current_phase``.

        Returns:
            bool: ``True`` when FLOPs collection is supported, ``False`` otherwise.
        """
        if cb_params.mode == 'train':
            network = cb_params.train_network
        elif cb_params.mode == 'eval':
            network = cb_params.eval_network
        else:
            logger.warning('Model Flops computation only support train and eval mode!')
            return False
        if get_context('mode') != ms.GRAPH_MODE:
            logger.warning('Model Flops computation only support graph mode in legacy mode!')
            return False
        if not hasattr(network, 'current_phase'):
            logger.warning('This model dose not support collecting model flops now in legacy mode!')
            return False
        if self.is_moe_model:
            logger.warning("Model Flops computation is not support when using GroupMatMul MoELayer "
                           "in legacy mode, due to dynamic shape")
            return False
        self.current_phase = network.current_phase
        return True

    def _calculate_model_flops(self):
        """
        Calculate the full model flops

        Sets ``self.mf_support`` to ``False`` when collection fails or a dynamic shape is reported.
        """
        try:
            full_model_flops, _, shard_model_flops, _, is_dynamic_shape = flops_collection(self.current_phase)
        except RuntimeError as e:
            logger.warning("%s", e)
            self.mf_support = False
            return
        self.full_model_flops = full_model_flops / 1.0
        self.mf_calculated = True
        if auto_parallel_context().get_pipeline_stages() > 1:
            pipeline_group_list, pipeline_group_name = self._get_pipeline_group()
            # Hash the rank list into a fixed-length group name.
            hashed = hashlib.sha256(pipeline_group_name.encode()).hexdigest()[:48]
            pipeline_group_name = str(hashed)
            create_group(pipeline_group_name, pipeline_group_list)
            # Reduce the dynamic-shape flag over the pipeline group so all stages agree.
            is_dynamic_shape = AllReduceNet(pipeline_group_name)(
                Tensor([int(is_dynamic_shape)], dtype=ms.int32)).asnumpy()[0]
            if is_dynamic_shape > 0:
                logger.warning("Model Flops computation now do not support dynamic shape.")
                self.mf_support = False
                return
            self.full_model_flops = AllReduceNet(pipeline_group_name)(
                Tensor([self.full_model_flops])).asnumpy()[0]
        if is_dynamic_shape:
            logger.warning("Model Flops computation now do not support dynamic shape.")
            self.mf_support = False
            return
        if auto_parallel_context().get_parallel_mode() != "stand_alone":
            self.full_model_flops = self.full_model_flops / get_group_size()
        logger.info("Full model flops is %d, Shard model flops is %d.", full_model_flops, shard_model_flops)

    def print_output_info(self, cb_params, cur_epoch_num, origin_epochs, throughput,
                          cur_step_num, steps_per_epoch, loss, per_step_seconds,
                          overflow, scaling_sens, time_remain, percent, global_norm,
                          main_loss, extra_loss, mtp_loss, indexer_loss):
        """print output information.

        Logs the step summary and a progress bar, and mirrors the metrics to the tensorboard
        writer when one is configured.
        """
        if self.learning_rate is not None:
            if isinstance(self.learning_rate, (float, Tensor, np.ndarray)):
                current_lr = str(self.learning_rate)
            elif isinstance(self.learning_rate, LearningRateSchedule):
                if get_context('device_target') == 'CPU':
                    if self.print_warning_flag:
                        logger.warning(
                            "device target not support CPU when generating the learning rate value, "
                            "please use: mindspore.set_device('Ascend')")
                        self.print_warning_flag = False
                    current_lr = None
                else:
                    if cb_params.optimizer is not None:
                        global_step = cb_params.optimizer.global_step
                    else:
                        global_step = cb_params.network.optimizer.global_step
                    # temporary set_train to avoid error on Atlas 800T A2
                    origin_phase = cb_params.train_network.phase
                    cb_params.train_network.set_train(False)
                    current_lr = self.learning_rate(global_step)
                    cb_params.train_network.set_train(origin_phase)
                    current_lr = np.array2string(current_lr.asnumpy())
            else:
                if self.print_warning_flag:
                    logger.warning(
                        "The current learning rate cannot be calculated in real time."
                        "Only the type of LearningRateSchedule is supported in the callback of MFLossMonitor,"
                        "but the input learning rate function type is %s", type(self.learning_rate)
                    )
                    self.print_warning_flag = False
                current_lr = None
        else:
            if self.print_warning_flag:
                logger.warning(
                    "MFLossMonitor callback is not set learning rate arguments."
                    "To display the learning rate, you must input the arguments, "
                    "which can be LearningRateSchedule or a fixed float"
                )
                self.print_warning_flag = False
            current_lr = None

        global_step = cur_step_num + (cur_epoch_num - 1) * steps_per_epoch
        if self.mf_calculated:
            # per_step_seconds is in milliseconds, hence the 1e9 divisor for the "T" unit below.
            if is_legacy_model():
                throughput_per_npu = self.full_model_flops / per_step_seconds / 1e9
            else:
                throughput_per_npu = self.full_model_flops / per_step_seconds / 1e9 / self.device_num
            throughput_info = f', train_throughput_per_npu: {throughput_per_npu:.3f}T'
            if self.tensor_writer is not None:
                self.tensor_writer.add_scalar('model-flops-throughput-per-npu',
                                              float(throughput_per_npu),
                                              global_step=global_step)
        else:
            throughput_info = ''

        if cb_params.dataset_sink_mode:
            loss_info = f"loss: {loss:5.6f}, "
        else:
            loss_info = f"loss:[{loss:5.6f}/{np.mean(self.loss_list):5.6f}], "

        # NOTE(review): the moe / mtp / indexer checks are treated as siblings - confirm against
        # the upstream layout.
        if self.print_separate_loss:
            separate_loss = f"lm_loss: {main_loss[0]:5.6f}, "
            if self.is_moe_model and np.all(extra_loss > 0):
                separate_loss += f"load_balancing_loss: {extra_loss[0]:5.6f}, "
            if self.is_mtp_model:
                separate_loss += f"mtp_loss: {mtp_loss[0]:5.6f}, "
            if np.all(indexer_loss > 0):
                separate_loss += f"indexer_loss: {indexer_loss[0]:5.6f}, "
        else:
            separate_loss = ""

        if current_lr is not None:
            logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], " + loss_info + separate_loss +
                        "per_step_time: %dms, lr: %s, overflow cond: %s, loss_scale: %s, global_norm: %s%s",
                        cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch,
                        int(per_step_seconds), current_lr, overflow, scaling_sens, global_norm, throughput_info)
            if self.tensor_writer is not None:
                self.tensor_writer.add_scalar('learning-rate', float(current_lr), global_step=global_step)
                self.tensor_writer.add_scalar('learning-rate vs samples', float(current_lr),
                                              global_step=global_step * self.global_batch_size)
        else:
            logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], " + loss_info + separate_loss +
                        "per_step_time: %dms, overflow cond: %s, loss_scale: %s, global_norm: %s%s",
                        cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch,
                        int(per_step_seconds), overflow, scaling_sens, global_norm, throughput_info)

        # print progress bar
        bar = int(50 * percent / 100) * "█"
        show_str = f"|{bar:<50}|"
        logger.info(" %4.1f%% %s %.5f samples/s/p %s }", percent, show_str, throughput,
                    timedelta(seconds=int(time_remain)))

        # log grouped lr info if enabled
        _log_grouped_lr_info()

        # write tensorboard
        if self.tensor_writer is not None:
            self.tensor_writer.add_scalar('batch-size', self.global_batch_size, global_step=global_step)
            self.tensor_writer.add_scalar('batch-size vs samples', self.global_batch_size,
                                          global_step=global_step * self.global_batch_size)
            self.tensor_writer.add_scalar('loss', loss, global_step=global_step)
            self.tensor_writer.add_scalar('loss vs samples', loss,
                                          global_step=global_step * self.global_batch_size)
            if self.tensorboard.get('log_loss_scale_to_tensorboard', False):
                self.tensor_writer.add_scalar('loss-scale', scaling_sens, global_step=global_step)
                self.tensor_writer.add_scalar('loss-scale vs samples', scaling_sens,
                                              global_step=global_step * self.global_batch_size)
            # NOTE(review): grad-norm is written unconditionally, while the throughput scalars are
            # grouped under the timers flag - confirm both against the upstream layout.
            self.tensor_writer.add_scalar('grad-norm', global_norm, global_step=global_step)
            self.tensor_writer.add_scalar('grad-norm vs samples', global_norm,
                                          global_step=global_step * self.global_batch_size)
            if self.tensorboard.get('log_timers_to_tensorboard', False):
                self.tensor_writer.add_scalar('iteration-time', int(per_step_seconds), global_step=global_step)
                self.tensor_writer.add_scalar('iteration-time vs samples', int(per_step_seconds),
                                              global_step=global_step * self.global_batch_size)
                self.tensor_writer.add_scalar('throughput', throughput, global_step=global_step)
                seconds_per_day = 86400
                billion_samples_per_day = throughput * get_group_size() * seconds_per_day / 1e9
                self.tensor_writer.add_scalar('B-samples-per-day', billion_samples_per_day,
                                              global_step=global_step)
                self.tensor_writer.add_scalar('throughput vs samples', throughput,
                                              global_step=global_step * self.global_batch_size)
            if self.print_separate_loss:
                self.tensor_writer.add_scalar('lm-loss', main_loss, global_step=global_step)
                if self.is_mtp_model:
                    self.tensor_writer.add_scalar('mtp-loss', mtp_loss, global_step=global_step)
                if self.is_moe_model:
                    self.tensor_writer.add_scalar('load-balancing-loss', extra_loss, global_step=global_step)
                if np.all(indexer_loss > 0):
                    self.tensor_writer.add_scalar('indexer-loss', indexer_loss, global_step=global_step)
[文档]@MindFormerRegister.register(MindFormerModuleType.CALLBACK) class TrainingStateMonitor(Callback): """ Monitor metrics such as local norm and local loss in training process. Args: origin_epochs (int): Required. Training epoches. config (dict, optional): The config specified how to display metrics. Keys are shown below. Default: ``None``, mean that keys will be set as the default values as below. - target: Specify the name or regular expression of params to monitor. Must be list of str, e.g. ["layers.[01]", "attention"]. Default: ['*'] , all params are selected. - invert: Whether to invert `target`, i.e. params in `target` won't be monitored. Must be `bool`. Default: `False` - local_norm_format: Determine where to display the local norm. Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Only params specified will be monitored. may cause a large amount of print info if 'log' is selected. Set to ``None`` to ignore this metric. Default: ``None``. - device_local_norm_format: Determine where to display the device local norm. Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``. - local_loss_format: Determine where to display the local loss. Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``. - device_local_loss_format: Determine where to display the device local loss. Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``. - optimizer_state_format: Determine where to display the optimizer state. 
Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Only the optimizer state of params specified will be monitored, may cause a large amount of print info if 'log' is selected. Set to ``None`` to ignore this metric. Default: ``None``. - weight_state_format: Determine where to display the weight L2-norm. Should be a `str` in ['tensorboard', 'log'] (mean that write data to tensorboard or log file), or a `list` containing them, or ``None``. Set to ``None`` to ignore this metric. Default: ``None``. - throughput_baseline: The model throughput baseline to calculate linearity. Must be a positive number. Will be displayed both to tensorboard and log. Set to ``None`` to ignore this metric. Default: ``None``. - print_struct: Whether to print the structure of model. If ``True``, callback will print the names of all trainable params at the first step and then quit training process. Default: ``False``. step_interval (int, optional): Every how many steps to display metrics. Default: ``1``. dataset_size (int, optional): Required in sink mode. Training dataset size. Default: ``None``. initial_epoch (int, optional): The beginning epoch. Default: ``0``. initial_step (int, optional): The beginning step. Default: ``0``. micro_batch_num (int, optional): MicroBatch size for Pipeline Parallel. Default: ``0``. global_batch_size (int, optional): The total batch size. Default: ``0``. tensor_model_parallel_size (int, optional): Tensor model parallel size. Default: ``0``. check_for_nan_in_loss_and_grad (bool, optional): Whether to check loss and norm of grad is Nan. Default: ``False``. use_skip_data_by_global_norm (bool, optional): Whether to use the skip data function by global norm. Default: ``False``. embedding_size (int, optional): The size of embedding norm which is get by hidden_size * vocab_size. Default: ``4096``. use_local_norm (bool, optional): Whether to turn on the local norm. Default: ``False``. 
""" @args_type_check(embedding_size=int, use_skip_data_by_global_norm=bool) def __init__(self, origin_epochs: int, config: dict = None, step_interval: int = 1, dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, micro_batch_num: int = 0, global_batch_size: int = 0, tensor_model_parallel_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False): super().__init__() if not _is_positive_natural_number(step_interval): raise TypeError("The value of 'monitor_config.step_interval' should be positive integer, " f"but get {step_interval}.") self.step_interval = step_interval self.last_print_time = 0 self.step_time = time.time() self.epoch_time = time.time() self.run_context = None self.steps_per_epoch = dataset_size self.origin_epochs = origin_epochs self.initial_epoch = initial_epoch self.initial_step = initial_step self.micro_batch_num = micro_batch_num self.global_batch_size = global_batch_size self.tensor_model_parallel_size = tensor_model_parallel_size self.global_norm_spike_count = 0 self.use_skip_data_by_global_norm = use_skip_data_by_global_norm self.embedding_size = embedding_size self.use_local_norm = use_local_norm self.device_num = get_real_group_size() self.tensor_writer = get_tensorboard_writer() self.outputer = {'tensorboard': self._to_tensorboard, 'log': self._to_log} self._init_config(config) self.dump_path = None self.check_for_nan_in_loss_and_grad = check_for_nan_in_loss_and_grad if get_auto_parallel_context("dump_local_norm_path"): self.dump_path = os.path.join(get_auto_parallel_context("dump_local_norm_path"), f'rank_{get_real_rank()}') self.dump_key = {0: -1} self.dump_step = step_interval self.dump_name_mode = 0 self.finish_pattern = 'finish_step_*_*' self.local_loss_pattern = re.compile('(local_loss)__(.+)_[a-z]+[0-9]+_([0-9]+)') self.local_norm_pattern = re.compile('(local_norm)__(.+)_[a-z]+[0-9]+_([0-9]+)') 
self.device_local_norm_pattern = re.compile('(device_local_norm)_[a-z]+[0-9]+_([0-9]+)') # when pipeline_stages > 2, param aggregation is not supported for now pp_parallel = context.get_auto_parallel_context("pipeline_stages") > 1 if pp_parallel and self.sr_format and self.do_aggregation: raise TypeError("When pipeline_stages > 1, weight aggregation is not supported") def on_train_epoch_begin(self, run_context): """ Record time at the beginning of epoch. Args: run_context (RunContext): Context of the process running. """ self.epoch_time = time.time() self.run_context = run_context def on_train_step_begin(self, run_context): """ Record time at the beginning of step. Args: run_context (RunContext): Context of the process running. """ self.step_time = time.time() self.run_context = run_context if self.print_struct: network = run_context.original_args().network if isinstance(network, ms.nn.TrainOneStepCell): network = network.network for param in network.trainable_params(): logger.info(param.name) self.run_context.request_stop() def on_train_step_end(self, run_context): """ Print training info at the end of step. Args: run_context (RunContext): Context of the process running. 
""" if self.print_struct: self._clear_dump_path() return step_seconds = (time.time() - self.step_time) * 1000 parallel_mode = get_auto_parallel_context("parallel_mode") full_batch = get_auto_parallel_context("full_batch") auto_parallel = parallel_mode in ['semi_auto_parallel', 'auto_parallel'] if auto_parallel: set_auto_parallel_context(parallel_mode='data_parallel', full_batch=False) cb_params = run_context.original_args() if cb_params.dataset_sink_mode: per_step_seconds = step_seconds / cb_params.batch_num else: self.steps_per_epoch = cb_params.batch_num per_step_seconds = step_seconds step_diff = cb_params.cur_step_num - self.last_print_time if step_diff >= self.step_interval or step_diff <= 0: self.last_print_time = cb_params.cur_step_num if get_auto_parallel_context("dump_local_norm_path"): self._dump_data_in_step(cb_params.cur_step_num) if self.optimizer_state_format: self._dump_optimizer_state(cb_params) if self.max_attention_logit_format: self._dump_max_attention_logit(cb_params) if self.weight_state_format: self._calc_weight_state(cb_params) if self.throughput_baseline is not None: self._calc_throughput_linearity(cb_params, per_step_seconds) if self.device_local_loss_format: self._calc_device_local_loss(cb_params) if self.sr_format: # cal stable rank and max eigenvalue self._do_stable_rank(cb_params) if auto_parallel: set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch) if self.use_local_norm and self.embedding_size is not None: embedding_local_norm = get_embedding_info(cb_params, self.embedding_size) logger.info("embedding_local_norm: %s", embedding_local_norm) self.abnormal_global_norm_check(cb_params) # Boundary check. 
if self.check_for_nan_in_loss_and_grad: self._boundary_check(cb_params) if self.device_local_loss_format: reset_device_local_loss() def abnormal_global_norm_check(self, cb_params): """Check the abnormal global_norm and raise error""" if cb_params.get('initial_step', None) is not None: self.initial_step = cb_params.initial_step if cb_params.dataset_sink_mode: steps_per_epoch = self.steps_per_epoch cur_epoch_num = (cb_params.cur_step_num + self.initial_step - 1) // steps_per_epoch + 1 cur_step_num = (cb_params.cur_step_num + self.initial_step - 1) % steps_per_epoch + 1 else: steps_per_epoch = cb_params.batch_num cur_epoch_num = cb_params.cur_epoch_num cur_step_num = (cb_params.cur_step_num + self.initial_step - 1) % steps_per_epoch + 1 net_outputs = cb_params.net_outputs global_norm = self._get_loss_output(net_outputs)[1] global_step = cur_step_num + (cur_epoch_num - 1) * steps_per_epoch if self.check_for_global_norm and self.use_skip_data_by_global_norm: raise ValueError("The check_for_global_norm and use_skip_data_by_global_norm" " cannot be turned on at the same time, please choose one.") if self.check_for_global_norm and global_norm >= self.global_norm_spike_threshold: if str(global_step) not in self.abnormal_global_norms: # Because json cannot use number as key, so we convert it to string self.abnormal_global_norms[str(global_step)] = [global_norm.item()] if get_rank() == 0: parent_dirs = os.path.dirname(self.global_norm_record_path) if not os.path.exists(parent_dirs): os.makedirs(parent_dirs) with open(self.global_norm_record_path, 'w', encoding="utf-8") as file: json.dump(self.abnormal_global_norms, file) set_safe_mode_for_file_or_dir(self.global_norm_record_path) logger.info(f"Current global norm {global_norm} is greater equal than " f"threshold {self.global_norm_spike_threshold}, stop training...") barrier_world() logger.info("Call barrier before throw TREError.") ms.runtime.synchronize() logger.info("All stream execution completed.") raise 
RuntimeError("TREError occurred......") self.abnormal_global_norms[str(global_step)].append(global_norm.item()) logger.info(f"The global norm {global_norm} of step {global_step} is still greater or equal " f"than threshold {self.global_norm_spike_threshold}, continue training.") if self.use_skip_data_by_global_norm: opt_global_step = cb_params.optimizer.global_step \ if cb_params.optimizer is not None else cb_params.network.optimizer.global_step is_skip = global_norm >= self.global_norm_spike_threshold if is_skip: logger.info("opt_global_step: %d, skip_data_grad_norm_threshold: %s, is_skip: %s", opt_global_step, self.global_norm_spike_threshold, is_skip) self.global_norm_spike_count += 1 if self.global_norm_spike_count < self.global_norm_spike_count_threshold: logger.info(f"Current global norm {global_norm} of step {global_step} " f"has been {self.global_norm_spike_count} " "consecutive times greater than threshold: " f"{self.global_norm_spike_threshold}") else: raise ValueError( f"Current global norm {global_norm} of step {global_step} " f"has been {self.global_norm_spike_count_threshold} " "consecutive times greater than threshold " f"{self.global_norm_spike_threshold}, stop training...") else: self.global_norm_spike_count = 0 def _calc_weight_state(self, cb_params): """calculate local weight_state""" network = cb_params.network if isinstance(network, ms.nn.TrainOneStepCell): network = network.network weight_norm = _get_weight_norm(network) self._output('weight_norm', weight_norm, cb_params.cur_step_num, self.weight_state_format) def _calc_throughput_linearity(self, cb_params, per_step_seconds): """calculate throughput_linearity""" throughput = self.global_batch_size / self.device_num / (per_step_seconds / 1000) linearity = throughput / self.throughput_baseline self._output('throughput_linearity', linearity, cb_params.cur_step_num, ['log', 'tensorboard']) def _calc_device_local_loss(self, cb_params): """calculate device local loss""" for loss_tag, 
device_local_loss in get_device_local_loss(None).items(): device_local_loss = np.mean(device_local_loss.asnumpy()) self._output(f'device_accum_local_{loss_tag}_loss', device_local_loss, cb_params.cur_step_num, self.device_local_loss_format) def _do_stable_rank(self, cb_params): """do stable_rank""" ms.runtime.empty_cache() sr_step_diff = cb_params.cur_step_num - self.sr_last_print_time if sr_step_diff >= self.sr_step_interval or sr_step_diff <= 0: self.sr_last_print_time = cb_params.cur_step_num self._calc_stable_rank(cb_params) def _calc_stable_rank(self, cb_params): """calculate stable_rank""" network = cb_params.train_network parallel_mode = context.get_auto_parallel_context("parallel_mode") if parallel_mode == "stand_alone": for param in network.network.trainable_params(): self._print_stable_rank(param.name, param, cb_params.cur_step_num) return save_param_names = self._get_remove_redundancy_param_names(network) if save_param_names is None: return if self.do_aggregation: parameter_layout_dict = network.parameter_layout_dict for param in network.network.trainable_params(): if (not self._check_sr_target(param.name)) or param.name.startswith("accu_grads"): continue # do aggregation if param.name in parameter_layout_dict: param_data = Tensor(param.data.asnumpy()) param_data = _get_merged_param_data(network, parameter_layout_dict, param.name, param_data, True) # within a communication group, only the first rank data is to be calculated redundancy_layout = self._get_single_params(network) if self._get_redundancy_removed_print(redundancy_layout, param.name): self._print_stable_rank(param.name, param_data, cb_params.cur_step_num) else: for param in network.network.trainable_params(): if self._check_sr_target(param.name) and param.name in save_param_names: self._print_stable_rank(param.name, param, cb_params.cur_step_num) def _boundary_check(self, cb_params): """boundary check""" loss, global_norm, local_norm = self._get_loss_output(cb_params.net_outputs) 
check_device_local_loss() self._check_nan_or_inf(loss, 'loss') self._check_nan_or_inf(global_norm, 'global_norm') self._check_nan_or_inf(local_norm, 'local_norm') def _get_redundancy_removed_print(self, redundancy_layout, name): """get parameter which has redundancy removed to print""" if redundancy_layout is None: return False for rankid, group in redundancy_layout.items(): if name in group: if rankid == get_real_rank(): return True logger.info(f"stable_rank for {name} is printed in rank{rankid}.") return False return False def _get_single_params(self, network): """get non-redundancy parameters dict""" parameter_layout_dict = network.parameter_layout_dict if parameter_layout_dict is None: return None device_num = get_real_group_size() stage_num = get_auto_parallel_context("pipeline_stages") chunk_size = device_num // stage_num rank_id = get_real_rank() initial_rank = (rank_id // chunk_size) * chunk_size param_redundancy_dict = get_parameter_redundancy(parameter_layout_dict, initial_rank) single_params = remove_param_redundancy(param_redundancy_dict) return single_params def _get_remove_redundancy_param_names(self, network): """remove redundancy parameters for this rank""" single_params = self._get_single_params(network) if single_params is None: return None rank_id = get_real_rank() save_param_names = single_params.get(rank_id) return save_param_names def _check_sr_target(self, param_name): if self.sr_target_cache.get(param_name) is None: for pattern in self.sr_target: if re.search(pattern, param_name) is not None: self.sr_target_cache[param_name] = True return True self.sr_target_cache[param_name] = False return False return self.sr_target_cache[param_name] def _init_stable_rank_config(self, config): """Initialize stable rank config""" if hasattr(config.get('stable_rank_config'), "get"): self.sr_format = config.get('stable_rank_config').get('format', None) self.sr_step_interval = config.get('stable_rank_config').get('step_interval', 100) if not 
_is_positive_natural_number(self.sr_step_interval): raise TypeError("'monitor_config.stable_rank_config.step_interval' should be positive integer," f"but get {self.sr_step_interval}.") self.sr_last_print_time = 0 self.sr_target = config.get('stable_rank_config').get('target') or ['.*'] self.sr_target_cache = {} self.do_aggregation = config.get('stable_rank_config').get('do_aggregation', False) self.moe_show_mode = config.get('stable_rank_config').get('moe_show_mode') or ["all"] self.power_iteration_num = config.get('stable_rank_config').get('power_iteration_num', 5) if not _is_positive_natural_number(self.power_iteration_num): raise TypeError("'monitor_config.stable_rank_config.power_iteration_num' should be positive integer," f"but get {self.power_iteration_num}.") else: self.sr_format = None def _init_global_norm_monitor_config(self, config): """Initialize global norm monitor config""" self.check_for_global_norm = config.get('check_for_global_norm') self.global_norm_record_path = os.path.join(get_output_root_path(), "abnormal_global_norm.json") self.global_norm_spike_threshold = config.get('global_norm_spike_threshold') self.health_checkpoint = config.get('health_checkpoint', None) if self.health_checkpoint: self.global_norm_spike_threshold = self.health_checkpoint.global_norm_spike_threshold self.check_for_global_norm = bool(self.global_norm_spike_threshold) self.global_norm_spike_count_threshold = config.get('global_norm_spike_count_threshold', 10) if not _is_positive_natural_number(self.global_norm_spike_count_threshold): raise TypeError("'monitor_config.global_norm_spike_count_threshold' should be positive integer, " f"but get {self.global_norm_spike_count_threshold}.") self.abnormal_global_norms: dict[str, list[float]] = {} if self.global_norm_record_path and os.path.exists(self.global_norm_record_path): # the data format might be like {"300": [3.3], "600": [4.1, 4.2],} # because json cannot use number as key, we convert it to string with 
open(self.global_norm_record_path, 'r', encoding="utf-8") as file: self.abnormal_global_norms = json.load(file) def _init_config(self, config): """Initialize members from config""" if config is None: logger.warning("The param `config` of TrainingStateMonitor is not set. Will use the default config.") config = {} if not isinstance(config, dict): raise TypeError("The param `config` of TrainingStateMonitor should be a dict.") self.target = config.get('target') or ['.*'] self.invert = config.get('invert') if self.invert is None: self.invert = False self.target_cache = {} self.local_norm_format = config.get('local_norm_format', None) self.local_loss_format = config.get('local_loss_format', None) self.device_local_norm_format = config.get('device_local_norm_format', None) self.device_local_loss_format = \ config.get('device_local_loss_format', None) if is_last_pipeline_stage() else None self.max_attention_logit_format = config.get('max_attention_logit_format', None) self.optimizer_state_format = config.get('optimizer_state_format', None) self.weight_state_format = config.get('weight_state_format', None) self._init_stable_rank_config(config) self._init_global_norm_monitor_config(config) self.throughput_baseline = config.get('throughput_baseline', None) self.print_struct = config.get('print_struct') if self.print_struct is None: self.print_struct = False if not (isinstance(self.target, list) and self.target and all(isinstance(i, str) for i in self.target)): raise TypeError("The value of 'target' should be a list of str.") if not isinstance(self.invert, bool): raise TypeError("The value of 'invert' should be bool.") if (self.throughput_baseline is not None and not (isinstance(self.throughput_baseline, (int, float)) and self.throughput_baseline > 0)): raise ValueError("The value of 'throughput_baseline' should be None or positive number.") if not isinstance(self.print_struct, bool): raise TypeError("The value of 'print_struct' should be bool.") attrs = ['local_norm_format', 
'local_loss_format', 'device_local_norm_format', 'device_local_loss_format', 'optimizer_state_format', 'weight_state_format', 'max_attention_logit_format'] for attr in attrs: self._check_attr_formats(attr) def _print_stable_rank(self, name, param, cur_step_num): """output stable rank and max eigenvalues""" if param.ndim > 3 or param.ndim < 2: logger.warning(f"Calculate {name} stable rank: input tensor should be 2/3-dimensional," f"actual dim: {param.ndim}") elif param.ndim == 2: stable_rank, eigenvalue = _get_stable_rank(param, self.power_iteration_num) if not isinstance(stable_rank, np.ndarray) and np.isclose(stable_rank, 0.0, atol=0.0, rtol=0.0): logger.info(f"{name}'s stable rank is 0.0 or some exception happened, check warning above.") self._output(f'weight_stable_rank/{name}', stable_rank, cur_step_num, self.sr_format) self._output(f'weight_eigenvalue/{name}', eigenvalue, cur_step_num, self.sr_format) else: stable_rank, eigenvalue = _get_stable_rank(param, self.power_iteration_num) if not isinstance(stable_rank, np.ndarray) and np.isclose(stable_rank, 0.0, atol=0.0, rtol=0.0): logger.info(f"{name}'s stable rank some exception happened, check warning above.") return if self.moe_show_mode in ('all', 'full'): for index, sr in enumerate(stable_rank): self._output(f'weight_stable_rank/{name}/expert_{index}', sr, cur_step_num, self.sr_format) for index, ev in enumerate(eigenvalue): self._output(f'weight_eigenvalue/{name}/expert_{index}', ev, cur_step_num, self.sr_format) if self.moe_show_mode in ('all', 'statistics'): sr_max = np.max(stable_rank) self._output(f'weight_stable_rank/{name}/expert_max', sr_max, cur_step_num, self.sr_format) sr_min = np.min(stable_rank) self._output(f'weight_stable_rank/{name}/expert_min', sr_min, cur_step_num, self.sr_format) sr_mean = np.mean(stable_rank) self._output(f'weight_stable_rank/{name}/expert_mean', sr_mean, cur_step_num, self.sr_format) ev_max = np.max(eigenvalue) self._output(f'weight_eigenvalue/{name}/expert_max', ev_max, 
cur_step_num, self.sr_format) ev_min = np.min(eigenvalue) self._output(f'weight_eigenvalue/{name}/expert_min', ev_min, cur_step_num, self.sr_format) ev_mean = np.mean(eigenvalue) self._output(f'weight_eigenvalue/{name}/expert_mean', ev_mean, cur_step_num, self.sr_format) def _check_attr_formats(self, attr): """Check the validation of formats in config""" if getattr(self, attr): if not isinstance(getattr(self, attr), (str, list)): raise TypeError(f"The value of {attr} should be a `str` in 'tensorboard' or 'log', " f"or a list containing them, or None, but get type {type(getattr(self, attr))}") if isinstance(getattr(self, attr), str): setattr(self, attr, set([getattr(self, attr)])) else: setattr(self, attr, set(getattr(self, attr))) diff = getattr(self, attr) - {'tensorboard', 'log'} if diff: raise ValueError(f"The value of {attr} should be a `str` in 'tensorboard' or 'log', " f"or a list containing them, or None, but get unexpected value {diff}") else: setattr(self, attr, None) if self.tensor_writer is None and getattr(self, attr): logger.warning("Tensorboard config is unset. 
'%s' will use 'log' only.", attr) getattr(self, attr).discard('tensorboard') getattr(self, attr).add('log') def _parse_step(self): """record the finish dump id of each step""" def check_step(pattern, id_pos): search_path = glob.glob(os.path.join(self.dump_path, f'{pattern}.npy')) if not search_path: return None step_ids = [] for f in search_path: tag, _ = os.path.splitext(os.path.basename(f)) step_ids.append(int(tag.split('_')[id_pos])) os.remove(f) step_ids.sort() return step_ids step_ids = check_step(self.finish_pattern, self.dump_name_mode - 1) if step_ids is None: return cur_steps = len(self.dump_key) for i, step_id in enumerate(step_ids): self.dump_key[cur_steps + i] = step_id def _dump_data_in_step(self, global_step): """write the dumped data each step to tensorboard""" def match_pattern(pattern, filename): name, _ = os.path.splitext(os.path.basename(filename)) parsed = re.fullmatch(pattern, name) if parsed is None: return None, None, None groups = parsed.groups() dump_id = int(groups[self.dump_name_mode - 1]) prefix = groups[self.dump_name_mode] suffix = None if len(groups) < 3 else groups[self.dump_name_mode + 1] return dump_id, prefix, suffix self._parse_step() while self.dump_step <= global_step and self.dump_key.get(self.dump_step) is not None: begin_id = self.dump_key[self.dump_step - 1] end_id = self.dump_key[self.dump_step] file_list = os.listdir(self.dump_path) local_losses = {} for f in file_list: parsed_name = None, None, None if self.local_norm_format: parsed_name = match_pattern(self.local_norm_pattern, f) if not any(parsed_name) and self.device_local_norm_format: parsed_name = match_pattern(self.device_local_norm_pattern, f) if not any(parsed_name) and self.local_loss_format: parsed_name = match_pattern(self.local_loss_pattern, f) if not any(parsed_name): continue dump_id, prefix, suffix = parsed_name if not begin_id < dump_id < end_id: continue data = np.load(os.path.join(self.dump_path, f), allow_pickle=False) if prefix == 'device_local_norm': 
self._output('device_local_norm', data, self.dump_step, self.device_local_norm_format) elif prefix == 'local_loss': # collect all local loss if there are more than one local loss within one step local_losses[suffix] = local_losses.get(suffix, []) local_losses[suffix].append(data) elif prefix == 'local_norm' and self._check_param_name(suffix): self._output(f'local_norm/{suffix}', data, self.dump_step, self.local_norm_format) if local_losses and self.local_loss_format: self._dump_local_loss(local_losses) self._clear_dump_path() self.dump_step += self.step_interval def _dump_local_loss(self, local_losses): """write the local loss to log/tensorboard""" # log local loss of each micro for loss_tag, loss_list in local_losses.items(): if 'log' in self.local_loss_format: for local_loss in loss_list: self._output(f'micro_local_{loss_tag}_loss', local_loss, self.dump_step, ['log']) if 'tensorboard' in self.local_loss_format: self._output(f'local_{loss_tag}_loss', np.mean(loss_list), self.dump_step, ['tensorboard']) def _dump_max_attention_logit(self, cb_params): """write the max attention logit to log/tensorboard""" network = cb_params.train_network network = get_real_models(network) params = network.get_max_attention_logit() if not params: return step = cb_params.cur_step_num vals = [] for param_name, param in params.items(): v = param.asnumpy() tag = f"max_attention_logit/{param_name}" if 'log' in self.max_attention_logit_format: self._output(tag, v.tolist(), step, ['log']) if 'tensorboard' in self.max_attention_logit_format: tp_id = get_rank() // self.tensor_model_parallel_size head_start = tp_id * len(v) data = {f"head_{head_start + i}": max_attention_logit for i, max_attention_logit in enumerate(v)} self._output(tag, data, step, ['tensorboard']) vals.extend(v) if vals: mean_v = float(np.mean(vals)) max_v = float(np.max(vals)) self._output('max_attention_logit/mean', mean_v, step, self.max_attention_logit_format) self._output('max_attention_logit/max', max_v, step, 
self.max_attention_logit_format) def _dump_optimizer_state(self, cb_params): """write the optimizer state to tensorboard""" optimizer = cb_params.optimizer if optimizer is None: optimizer = getattr(cb_params.network, "optimizer", None) if hasattr(optimizer, "moment1") and hasattr(optimizer, "moment2"): adam_m, adam_v = optimizer.moment1, optimizer.moment2 elif hasattr(optimizer, "moments1") and hasattr(optimizer, "moments2"): adam_m, adam_v = optimizer.moments1, optimizer.moments2 elif hasattr(optimizer, "exp_avg") and hasattr(optimizer, "exp_avg_sq"): adam_m, adam_v = optimizer.exp_avg, optimizer.exp_avg_sq else: return global_step = cb_params.cur_step_num adam_m_norms = _get_optimizer_state(adam_m, self._check_param_name) adam_v_norms = _get_optimizer_state(adam_v, self._check_param_name) for param_name, adam_m_norm in adam_m_norms.items(): param_name = param_name.split('.', maxsplit=1)[1] self._output(f'adam_m_norm/{param_name}', adam_m_norm, global_step, self.optimizer_state_format) for param_name, adam_v_norm in adam_v_norms.items(): param_name = param_name.split('.', maxsplit=1)[1] self._output(f'adam_v_norm/{param_name}', adam_v_norm, global_step, self.optimizer_state_format) def _check_param_name(self, param_name): if self.target_cache.get(param_name) is None: for pattern in self.target: if re.search(pattern, param_name) is not None: self.target_cache[param_name] = not self.invert return not self.invert self.target_cache[param_name] = self.invert return self.invert return self.target_cache[param_name] def _clear_dump_path(self): if not self.dump_path: return file_list = os.listdir(self.dump_path) for f in file_list: if f.startswith(".nfs"): continue os.remove(os.path.join(self.dump_path, f)) def _to_tensorboard(self, tag, data, global_step): """Write data to tensorboard if possible""" if self.tensor_writer is not None: if isinstance(data, dict): self.tensor_writer.add_scalars(tag, data, global_step=global_step) else: self.tensor_writer.add_scalar(tag, data, 
global_step=global_step) def _to_log(self, tag, data, global_step): """Write data to log file""" cur_epoch_num = (global_step + self.initial_step - 1) // self.steps_per_epoch + 1 cur_step_num = (global_step + self.initial_step - 1) % self.steps_per_epoch + 1 logger.info( "Epoch:[%3d/%3d], step:[%5d/%5d] %s: %s", cur_epoch_num, self.origin_epochs, cur_step_num, self.steps_per_epoch, tag, data ) def _output(self, tag, data, global_step, formats): """Write data in specified formats""" if formats: for fmt in formats: self.outputer[fmt](tag, data, global_step) def _get_loss_output(self, output): """Get loss, global/local norm""" loss = output global_norm = None local_norm = None if isinstance(output, (tuple, list)): if len(output) == 7: loss, global_norm, local_norm = output[0], output[4], output[5] elif len(output) == 5: loss, global_norm = output[0], output[4] elif isinstance(output[0], ms.Tensor) and isinstance(output[0].asnumpy(), np.ndarray): loss = output[0] if isinstance(global_norm, ms.Tensor): global_norm = global_norm.asnumpy() if isinstance(local_norm, ms.Tensor): local_norm = local_norm.asnumpy() if isinstance(loss, ms.Tensor) and isinstance(loss.asnumpy(), np.ndarray): loss = np.mean(loss.asnumpy()) return loss, global_norm, local_norm @staticmethod def _check_nan_or_inf(indicator, indicator_name): """Check if Nan or Inf in indicator then terminate training""" if indicator is not None: if np.any(np.isnan(indicator)): raise ValueError(f"There is nan in {indicator_name} with value {indicator}, terminate training.") if np.any(np.isinf(indicator)): raise ValueError(f"There is inf in {indicator_name} with value {indicator}, terminate training.")
@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class SummaryMonitor:
    """
    Summary Monitor helps you collect common training information, such as loss,
    learning rate, computational graph and so on.

    Instantiating this class does not produce a ``SummaryMonitor`` object:
    ``__new__`` forwards every argument to ``SummaryCollector`` and returns the
    resulting collector.

    Note:
        referring to
        `note <https://www.mindspore.cn/docs/en/r2.9.0/api_python/mindspore/mindspore.SummaryCollector.html>`_ .

    Args:
        summary_dir (str, optional): Directory in which the collected data is persisted.
            If the directory does not exist, it will be created automatically.
            When ``None``, the path returned by ``get_output_subpath('summary', rank_id)``
            for the current real rank is used. Default: ``None``.
        collect_freq (int, optional): Frequency of data collection, in units of `step`.
            It should be greater than zero. Default: ``10``.
        collect_specified_data (Union[None, dict], optional): Perform custom operations on the
            collected data. Default: ``None``.
        keep_default_action (bool, optional): This field affects the collection behavior of the
            'collect_specified_data' field. Default: ``True``.
        custom_lineage_data (Union[dict, None], optional): Allows you to customize the data.
            In the custom data, the type of the key supports str, and the type of value supports
            str, int and float. Default: ``None`` , it means there is no custom data.
        collect_tensor_freq (Optional[int], optional): The same semantics as `collect_freq`,
            but controls TensorSummary only. Default: ``None``.
        max_file_size (Optional[int], optional): The maximum size in bytes of each file that can be
            written to the disk. For example, to write not larger than 4GB,
            specify max_file_size=4*1024**3. Default: ``None``, which means no limit.
        export_options (Union[None, dict], optional): Perform custom operations on the export data.
            Default: ``None``, it means that the data is not exported.

    Examples:
        >>> from mindformers.core import SummaryMonitor
        >>> monitor = SummaryMonitor(summary_dir='./summary_dir')
    """

    def __new__(cls,
                summary_dir=None,
                collect_freq=10,
                collect_specified_data=None,
                keep_default_action=True,
                custom_lineage_data=None,
                collect_tensor_freq=None,
                max_file_size=None,
                export_options=None):
        # No explicit directory given: use the per-rank 'summary' output sub-path.
        if summary_dir is None:
            summary_dir = get_output_subpath('summary', get_real_rank())

        # Hand every option over unchanged; the caller receives the collector itself.
        return SummaryCollector(
            summary_dir=summary_dir,
            collect_freq=collect_freq,
            collect_specified_data=collect_specified_data,
            keep_default_action=keep_default_action,
            custom_lineage_data=custom_lineage_data,
            collect_tensor_freq=collect_tensor_freq,
            max_file_size=max_file_size,
            export_options=export_options,
        )
[文档]@MindFormerRegister.register(MindFormerModuleType.CALLBACK) class CheckpointMonitor(ModelCheckpoint): """ Checkpoint Monitor For Save LossScale. Args: prefix (str, optional): The prefix name of checkpoint files. Default: ``'CKP'``. directory (str, optional): The path of the folder which will be saved in the checkpoint file. Default: ``None``. config (CheckpointConfig, optional): Checkpoint strategy configuration. Default: ``None``. save_checkpoint_steps (int, optional): Steps to save checkpoint. Default: ``1``. save_checkpoint_seconds (int, optional): Seconds to save checkpoint. Can't be used with save_checkpoint_steps at the same time. Default: ``0``. keep_checkpoint_max (int, optional): Maximum number of checkpoint files can be saved. Default: ``5``. keep_checkpoint_per_n_minutes (int, optional): Save the checkpoint file every "keep_checkpoint_per_n_minutes" minutes. Can't be used with keep_checkpoint_max at the same time. Default: ``0``. integrated_save (bool, optional): Whether to merge and save the split Tensor in the automatic parallel scenario. Integrated save function is only supported in automatic parallel scene. Default: ``True``. save_network_params (bool, optional): Whether to only save network weights additionally. Default: ``False``. save_trainable_params (bool, optional): Whether to save only weights of trainable parameters. Default: ``False``. async_save (bool, optional): Whether asynchronous execution saves the checkpoint to a file. Default: ``False``. saved_network (Cell, optional): Network to be saved in checkpoint file. Default: ``None``. append_info (list, optional): The information save to checkpoint file. Support "epoch_num", "step_num" and dict. Default: ``None``. enc_key (Union[None, bytes], optional): Byte type key used for encryption. Default: ``None``. enc_mode (str, optional): This parameter is valid only when "enc_key" is not set to None. Specifies the encryption mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. 
Default: ``'AES-GCM'``. exception_save (bool, optional): Whether to save the current checkpoint when an exception occurs. Default: ``False``. global_batch_size (int, optional): The total batch size. Default: ``0``. checkpoint_format (str, optional): The format of checkpoint to save. Support 'ckpt' or 'safetensors'. Default: ``'ckpt'``. remove_redundancy (bool, optional): Whether to remove redundancy when saving checkpoint. Default: ``False``. embedding_size (int, optional): The size of embedding norm which is get by hidden_size * vocab_size. Default: ``4096``. use_checkpoint_health_monitor (bool, optional): Whether to use the checkpoint health monitor function by embedding norm. Default: ``False``. embedding_local_norm_threshold (float, optional): The threshold of the embedding norm. Default: ``1.0``. health_ckpts_record_dir (str, optional): The path of the file which is used to record the health of checkpoint. Default: ``./output``. use_legacy_format (bool, optional): Whether to use the legacy 'save_checkpoint' process, Default: ``True``. save_optimizer (bool, optional): Whether to save optimizer weights, only used in megatron-format weight save scene. Legacy scene will be set to ``None``. Default: ``True``. Raises: ValueError: If `prefix` is not str or contains the '/' character. ValueError: If `directory` is not str. TypeError: If the config is not CheckpointConfig type. 
Examples: >>> from mindformers.core import CheckpointMonitor >>> monitor = CheckpointMonitor(directory='./checkpoint_dir') """ @args_type_check(embedding_local_norm_threshold=float, use_checkpoint_health_monitor=bool) def __init__(self, prefix='CKP', directory=None, config=None, save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, save_network_params=False, save_trainable_params=False, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False, global_batch_size=None, checkpoint_format='ckpt', remove_redundancy=False, embedding_size=4096, embedding_local_norm_threshold=1.0, use_checkpoint_health_monitor=False, health_ckpts_record_dir="./output", use_legacy_format=True, save_optimizer=True): self.config = config self.save_network_params = save_network_params self.save_trainable_params = save_trainable_params self.rank_id = get_real_rank() self.embedding_local_norm_threshold = embedding_local_norm_threshold self.use_checkpoint_health_monitor = use_checkpoint_health_monitor self.embedding_size = embedding_size self.health_ckpts_record_dir = health_ckpts_record_dir self.use_legacy_format = use_legacy_format # Ensure that 'save_optimizer' only use in the sense of 'use_legacy_format == False' self.save_optimizer = save_optimizer if not use_legacy_format else False self.origin_prefix = prefix self.directory = directory self.need_remove_redundancy = remove_redundancy prefix = prefix + f"_rank_{self.rank_id}" self.global_batch_size = global_batch_size # this list records parameters which will be ignored when saving ckpt. 
self.filter_list = ['accu_grads', 'fi_parameter', 'zeros_k_pe', 'zeros_k_nope', 'zeros_value_states', '_cache', '_device_local_norm', '_device_local_loss', 'expert_load', 'max_logits_val'] self.save_info_list = defaultdict( lambda: { 'ckpt': {'ckpt_file_path': None, 'save_start_time': None, 'save_end_time': None}, 'network': {'ckpt_file_path': None, 'save_start_time': None, 'save_end_time': None}, 'trainable_params': {'ckpt_file_path': None, 'save_start_time': None, 'save_end_time': None}, } ) if append_info is None: append_info = [{ "epoch_num": 0, "step_num": 0, "global_step": 0, "loss_scale": 1 }] ckpt_directory = os.path.join(directory, f"checkpoint/rank_{self.rank_id}") \ if directory else get_output_subpath('checkpoint', self.rank_id) self.network_directory = os.path.join(directory, f"checkpoint_network/rank_{self.rank_id}") \ if directory else get_output_subpath('checkpoint_network', self.rank_id) self.trainable_directory = os.path.join(directory, f"checkpoint_trainable/rank_{self.rank_id}") \ if directory else get_output_subpath('checkpoint_trainable', self.rank_id) if context.get_auto_parallel_context('parallel_mode') in \ ['semi_auto_parallel', 'auto_parallel', 'hybrid_parallel']: logger.info("Integrated_save is changed to False when using auto_parallel.") integrated_save = False config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, save_checkpoint_seconds=save_checkpoint_seconds, keep_checkpoint_max=keep_checkpoint_max, keep_checkpoint_per_n_minutes=keep_checkpoint_per_n_minutes, integrated_save=integrated_save, async_save=async_save, saved_network=saved_network, append_info=append_info, enc_key=enc_key, enc_mode=enc_mode, format=checkpoint_format, exception_save=exception_save, remove_redundancy=remove_redundancy) super().__init__(prefix, ckpt_directory if self.use_legacy_format else None, config=config_ck) # Remove empty checkpoint directory created too early to avoid errors when reading empty folder on second launch if 
(ckpt_directory and os.path.exists(ckpt_directory) and not os.listdir(ckpt_directory) and not check_tft_valid()): os.rmdir(ckpt_directory) self._graph_saved = True self.meta_json = os.path.join(self._directory, "meta.json") if self._config.async_save: self.last_epoch_num = None self.last_step_num_in_epoch = None self.last_ckpoint_file = None self.meta_updated = True self.async_save_manager = AsyncSaveManager(self._config.async_save) if self.save_network_params: self._network_manager = CheckpointManager(config_ck.format) if self.save_trainable_params: self._trainable_manager = CheckpointManager(config_ck.format) self.need_remove_extra_ckpt = False self.common_info = CommonInfo() self.save_checkpoint_steps = save_checkpoint_steps self.current_ckpt_step_list = [] def print_savetime(self, record_step, batch_num): """print the time cost of saving checkpoint files.""" epoch = int((record_step - 1) // batch_num + 1) step = int((record_step - 1) % batch_num + 1) def output_if_exists(key): save_info = self.save_info_list[record_step][key] file = save_info['ckpt_file_path'] if file is not None and os.path.exists(file): save_info['save_end_time'] = os.path.getmtime(file) cost_time = save_info['save_end_time'] - save_info['save_start_time'] logger.info(f'Finish saving {key} of epoch {epoch} step {step}' f' using {cost_time:.3f} seconds') save_info['ckpt_file_path'] = None output_if_exists('ckpt') output_if_exists('network') output_if_exists('trainable_params') def _save_ckpt(self, cb_params, force_to_save=False): """Save checkpoint files.""" # if fault occurs, ensure that saving ckpt of the recover node should synchronize with other nodes. 
if check_arf_status(cb_params): self._last_triggered_step = \ int(cb_params.cur_step_num / self.save_checkpoint_steps) * self.save_checkpoint_steps if cb_params.cur_step_num < self._last_triggered_step + self.save_checkpoint_steps: return # pylint: disable=E0203 if cb_params.cur_step_num == self._last_triggered_step: return # if param is cache enable, flush data from cache to host before save_ckpt if self._need_flush_from_cache: self._flush_from_cache(cb_params) save_ckpt = self._check_save_ckpt(cb_params, force_to_save) # if async_save is True, check whether saving processes are completed each step if self._config.async_save: keys = list(self.save_info_list.keys()) for record_step in keys: self.print_savetime(record_step, cb_params.batch_num) if not any(self.save_info_list[record_step][key]['ckpt_file_path'] for key in ['ckpt', 'network', 'trainable_params']): self.save_info_list.pop(record_step) if self._config.async_save and not ms.async_ckpt_thread_status() and \ self.last_epoch_num and self.last_step_num_in_epoch and self.last_ckpoint_file and \ not self.meta_updated: self.record_last_ckpt_to_json(self.last_epoch_num, self.last_step_num_in_epoch, self.last_ckpoint_file) self.meta_updated = True if save_ckpt: # NOTE: origin checkpoint processes are remained here # Create checkpoint directory only before saving weights to avoid empty folder if training stops early if not os.path.exists(self._directory): os.makedirs(self._directory, exist_ok=True) set_safe_mode_for_file_or_dir(self._directory) self.save_checkpoint(cb_params) self.save_checkpoint_network(cb_params) # If async_save is False, output the time cost directly if not self._config.async_save: self.print_savetime(cb_params.cur_step_num, cb_params.batch_num) def get_checkpoint_health_info(self, cb_params): """get the health of checkpoint.""" embedding_local_norm = get_embedding_info(cb_params, self.embedding_size) stage_nums = auto_parallel_context().get_pipeline_stages() device_nums = get_group_size() 
per_stage_device_nums = device_nums // stage_nums health_flag = ms.Tensor([0], dtype=ms.float32) is_health = 0 if stage_nums > 1: parallel_mode = ms.get_auto_parallel_context("parallel_mode") ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE) if get_rank() < per_stage_device_nums: rank_list = list(range(0, per_stage_device_nums)) if embedding_local_norm >= self.embedding_local_norm_threshold: health_flag = ms.Tensor([1], dtype=ms.float32) group_name = self.create_group_pipeline(rank_list) final_health = AllReduceNet(group_name)(health_flag) if final_health.asnumpy() != 0: is_health = 1 ms.set_auto_parallel_context(parallel_mode=parallel_mode) return is_health def create_group_pipeline(self, rank_list): rank_str_list = [str(r) for r in rank_list] rank_list_str = "-".join(rank_str_list) # To make the name of group unique. hashed = hashlib.sha256( rank_list_str.encode()).hexdigest()[:48] pipeline_group_name = str(hashed) create_group(pipeline_group_name, rank_list) return pipeline_group_name def save_checkpoint(self, cb_params): """save checkpoint suitable for resume training.""" logger.info('......Saving ckpt......') self.save_info_list[cb_params.cur_step_num]['ckpt']['save_start_time'] = time.time() step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) cur_ckpoint_file = (f"{self._prefix}-{str(cb_params.cur_epoch_num)}" f"_{str(step_num_in_epoch)}.{self._config.format}") # update checkpoint file list. self._manager.update_ckpoint_filelist(self._directory, self._prefix) # keep checkpoint files number equal max number. 
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: self._manager.remove_oldest_ckpoint_file() self.need_remove_extra_ckpt = True elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: # pylint: disable=E0203 self._cur_time_for_keep = time.time() if (self._cur_time_for_keep - self._last_time_for_keep) \ < self._config.keep_checkpoint_per_n_minutes * 60: self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._cur_time_for_keep) # generate the new checkpoint file and rename it. global SAVE_DIR SAVE_DIR = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num if self.use_checkpoint_health_monitor: is_health = self.get_checkpoint_health_info(cb_params) # check the health of checkpoint and save the record file if get_rank() == 0: dump_health_json_path = os.path.join(self.health_ckpts_record_dir, "health_ckpts.json") health_step_data = { 'is_health': is_health, 'ckpt_name': cur_ckpoint_file } all_step_health_data = [] if os.path.exists(dump_health_json_path): with open(dump_health_json_path, 'r', encoding="utf-8") as file: data = json.load(file) all_step_health_data = list(data) all_step_health_data.append(health_step_data) with open(dump_health_json_path, 'w', encoding="utf-8") as file: json.dump(all_step_health_data, file, indent=4) set_safe_mode_for_file_or_dir(dump_health_json_path) if "epoch_num" in self._append_dict: self._append_dict["epoch_num"] = cb_params.cur_epoch_num if "step_num" in self._append_dict: self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num if cb_params.optimizer is not None: self._append_dict["global_step"] = cb_params.optimizer.global_step else: self._append_dict["global_step"] = cb_params.network.optimizer.global_step if "loss_scale" in self._append_dict: outputs = 
cb_params.net_outputs if isinstance(outputs, (tuple, list)) and len(outputs) >= 3: self._append_dict["loss_scale"] = outputs[2] if self.global_batch_size is not None: self._append_dict["global_batch_size"] = self.global_batch_size logger.info("global_batch_size: %d", self._append_dict["global_batch_size"]) logger.info("epoch_num: %d", self._append_dict["epoch_num"]) logger.info("step_num: %d", self._append_dict["step_num"]) logger.info("global_step: %d", self._append_dict["global_step"]) network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network self.remove_redundancy(network, cur_file, self._append_dict, None) self._latest_ckpt_file_name = cur_file self.save_info_list[cb_params.cur_step_num]['ckpt']['ckpt_file_path'] = cur_file if self._config.async_save: self.last_epoch_num = cb_params.cur_epoch_num self.last_step_num_in_epoch = step_num_in_epoch self.last_ckpoint_file = cur_ckpoint_file self.meta_updated = False else: if "__exception_save__" not in self._append_dict: self.record_last_ckpt_to_json(cb_params.cur_epoch_num, step_num_in_epoch, cur_ckpoint_file) def _get_cur_dp(self, cur_rank, parameter_redundancy_dict): """get the current dp""" value_len = sys.maxsize min_value = () min_value_set = set() for key, value in parameter_redundancy_dict.items(): if key.startswith("accu_grads") or key.startswith("inputs"): continue for item in value: if cur_rank not in item: continue # if item is subset of min_value_set, update min_value_set and min_value if len(item) < value_len: if min_value_set and not set(item).issubset(min_value_set): return (cur_rank,) value_len = len(item) min_value_set = set(item) min_value = item # if value is not smaller than len of min_value len, # check if min_value_set is subset of current item elif not min_value_set.issubset(set(item)): return (cur_rank,) return min_value def _filter_ckpt_not_save(self, x, filter_list): return all(not x.startswith(item) and item not in x for item in 
filter_list) def _tft_save_ckpt(self, param_layout_set, save_param_names, cur_file, append_dict, network): """save checkpoint with remove redundancy for TFT training.""" def choice_func(x): return (x not in param_layout_set or (save_param_names is not None and x in save_param_names)) \ and self._filter_ckpt_not_save(x, self.filter_list) ms.save_checkpoint(network, cur_file, False, False, append_dict, self._config.enc_key, self._config.enc_mode, format=self._config.format, choice_func=choice_func, remove_redundancy=self._config.remove_redundancy) # pylint: disable=W0640 def _do_remove_redundancy_for_tft(self, redundancy_info, cur_file, network, append_dict): """save checkpoint with remove redundancy for TFT training.""" rank_id, param_redundancy_dict, single_params, param_layout = redundancy_info pattern = rf'_(\d+)\.{self._config.format}$' match = re.search(pattern, cur_file) cur_step_in_epoch = int(match.group(1)) parallel_mode = context.get_auto_parallel_context("parallel_mode") cur_dp = self._get_cur_dp(rank_id, param_redundancy_dict) # loop through all ranks in the current dp for rank in cur_dp: save_param_names = single_params.get(rank) if save_param_names == param_layout.keys(): logger.warning("For remove_redundancy save checkpoint, the saved parameters are non-redundant.") param_layout_set = set(param_layout.keys()) if parallel_mode else set() cur_file = re.sub(r'rank_\d+', f'rank_{rank}', cur_file) self._tft_save_ckpt(param_layout_set, save_param_names, cur_file, append_dict, network) append_dict["__exception_save__"] = True self.meta_json = re.sub(r'rank_\d+', f'rank_{rank}', self.meta_json) self.record_last_ckpt_to_json(append_dict["epoch_num"], cur_step_in_epoch, os.path.basename(cur_file)) def _check_if_skip_trainable_params(self, value): """ Checks if a trainable parameter should be skipped based on execution mode and parameter properties. 
""" is_graph_mode = get_context('mode') == context.GRAPH_MODE in_auto_parallel = ms.get_auto_parallel_context("parallel_mode") in [ ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL, ] skip_for_parallel = is_graph_mode and in_auto_parallel and ((not value.sliced) or value.has_init) cur_param_info = value.param_info is_pipeline_shared = getattr(cur_param_info, 'is_pipeline_shared_param', False) return skip_for_parallel or is_pipeline_shared def remove_redundancy(self, network, cur_file, append_dict, train_network): """remove redundancy when saving checkpoint files.""" parallel_mode = context.get_auto_parallel_context("parallel_mode") if self._config.remove_redundancy and parallel_mode != "stand_alone": logger.info('......Removing redundancy......') if train_network: param_layout = train_network.parameter_layout_dict else: param_layout = network.parameter_layout_dict rank_id = get_real_rank() if param_layout: device_num = get_real_group_size() stage_num = get_auto_parallel_context("pipeline_stages") chunk_size = device_num // stage_num initial_rank = (rank_id // chunk_size) * chunk_size param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank) single_params = remove_param_redundancy(param_redundancy_dict) save_param_names = single_params.get(rank_id) param_layout_set = set(param_layout.keys()) if save_param_names == param_layout.keys(): logger.warning("For remove_redundancy save checkpoint, the saved parameters are non-redundant.") def choice_func(x): return (x not in param_layout_set or (save_param_names is not None and x in save_param_names)) \ and self._filter_ckpt_not_save(x, self.filter_list) else: param_redundancy_dict = get_parameter_redundancy(network) single_params = remove_param_redundancy(param_redundancy_dict) save_param_names = single_params.get(rank_id) def choice_func(x): return save_param_names is not None and x in save_param_names \ and self._filter_ckpt_not_save(x, self.filter_list) # __exception_save__ is used to 
indicate that the checkpoint is saved by the TFT process if "__exception_save__" in append_dict: redundancy_info = (rank_id, param_redundancy_dict, single_params, param_layout) self._do_remove_redundancy_for_tft(redundancy_info, cur_file, network, append_dict) return ms.save_checkpoint(network, cur_file, False, self._config.async_save, append_dict, self._config.enc_key, self._config.enc_mode, format=self._config.format, choice_func=choice_func, remove_redundancy=self._config.remove_redundancy) else: ms.save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, append_dict, self._config.enc_key, self._config.enc_mode, format=self._config.format, choice_func=lambda x: self._filter_ckpt_not_save(x, self.filter_list), remove_redundancy=self._config.remove_redundancy) def save_checkpoint_network(self, cb_params): """save checkpoint only network params, which is suitable for train, evaluate and predict.""" save_obj = cb_params.network network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network if hasattr(save_obj, 'optimizer') and save_obj.optimizer is not None: save_obj = save_obj.network step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) if self.save_trainable_params: self.save_info_list[cb_params.cur_step_num]['trainable_params']['save_start_time'] = time.time() save_obj.init_parameters_data() param_dict = OrderedDict() for param in save_obj.trainable_params(): param_dict[param.name] = param param_list = [] for (key, value) in param_dict.items(): if self._check_if_skip_trainable_params(value): continue each_param = {"name": key} param_data = Tensor(value.data.asnumpy()) # in automatic model parallel scenario, some parameters were split to all the devices, # which should be combined before saving if key in save_obj.parameter_layout_dict: param_data = _get_merged_param_data(save_obj, key, param_data, self._config.integrated_save) each_param["data"] = param_data 
param_list.append(each_param) save_obj = param_list cb_cur_ckpoint_file = (f"{self._prefix}-trainable_params-{str(cb_params.cur_epoch_num)}" f"_{str(step_num_in_epoch)}.{self._config.format}") cb_cur_file = os.path.join(self.trainable_directory, cb_cur_ckpoint_file) os.makedirs(self.trainable_directory, exist_ok=True) # update checkpoint file list. self._trainable_manager.update_ckpoint_filelist( self.trainable_directory, f"{self._prefix}-trainable_params" ) # keep checkpoint files number equal max number. if self.need_remove_extra_ckpt: self._trainable_manager.remove_oldest_ckpoint_file() self.remove_redundancy(save_obj, cb_cur_file, {}, network) self.save_info_list[cb_params.cur_step_num]['trainable_params']['ckpt_file_path'] = cb_cur_file return if self.save_network_params: self.save_info_list[cb_params.cur_step_num]['network']['save_start_time'] = time.time() cb_cur_ckpoint_file = (f"{self._prefix}-network-{str(cb_params.cur_epoch_num)}" f"_{str(step_num_in_epoch)}.{self._config.format}") cb_cur_file = os.path.join(self.network_directory, cb_cur_ckpoint_file) os.makedirs(self.network_directory, exist_ok=True) # update checkpoint file list. self._network_manager.update_ckpoint_filelist(self.network_directory, f"{self._prefix}-network") # keep checkpoint files number equal max number. if self.need_remove_extra_ckpt: self._network_manager.remove_oldest_ckpoint_file() self.remove_redundancy(save_obj, cb_cur_file, {}, network) self.save_info_list[cb_params.cur_step_num]['network']['ckpt_file_path'] = cb_cur_file self.need_remove_extra_ckpt = False def _save_megatron_ckpt_file_format(self, cb_params): """Save the checkpoints like megatron format.""" # Check whether the checkpoint of current step has been saved. 
if cb_params.cur_step_num == self._last_triggered_step: return # Get current step as iteration iteration = self._append_step_num + cb_params.cur_step_num # Get common info self.common_info.step_num = iteration self.common_info.epoch_num = cb_params.cur_epoch_num self.common_info.global_step = int(cb_params.network.optimizer.global_step) self.common_info.loss_scale = None if isinstance(cb_params.net_outputs, (tuple, list)) and len(cb_params.net_outputs) >= 3: self.common_info.loss_scale = float(cb_params.net_outputs[2]) self.common_info.global_batch_size = self.global_batch_size if self.use_checkpoint_health_monitor: self.common_info.ckpt_status = CkptHealthStatus.NORMAL.value \ if self.get_checkpoint_health_info(cb_params) == 0 else CkptHealthStatus.ABNORMAL.value # Get all sharded tensor info of this network to save 'metadata.json' sharded_tensor_metas = get_all_sharded_tensor( network=cb_params.network, filter_func=(lambda x: x in list( cb_params.network.network.parameters_dict().keys())) if not self.save_optimizer else None ) if get_real_group_size() > 1 else None save_checkpoint( iteration=iteration, network=cb_params.network.network, optimizer=cb_params.network.optimizer if self.save_optimizer else None, async_save_manager=self.async_save_manager if self._config.async_save else None, common_info=self.common_info, keep_max_num=self._config.keep_checkpoint_max, user_prefix=self.origin_prefix, save_checkpoint_path=self.directory, sharded_tensor_metas=sharded_tensor_metas, remove_redundancy=self.need_remove_redundancy, current_ckpt_step_list=self.current_ckpt_step_list, ) # After saving, update the counter of last saved step. 
self._last_triggered_step = cb_params.cur_step_num def record_last_ckpt_to_json(self, epoch, step, ckpt_file): """record last ckpt info to json""" meta_data = { "last_epoch": epoch, "last_step": step, "last_ckpt_file": ckpt_file } with tempfile.NamedTemporaryFile('w', delete=False, dir=self._directory) as temp_file: json.dump(meta_data, temp_file) temp_file_path = temp_file.name os.replace(temp_file_path, self.meta_json) set_safe_mode_for_file_or_dir(self.meta_json) def step_end(self, run_context): """ Save the checkpoint at the end of step. Args: run_context (RunContext): Context of the train running. """ if self.use_legacy_format: super().step_end(run_context) else: cb_params = run_context.original_args() force_to_save = False if cb_params.cur_step_num == self._last_triggered_step: return # If param is cache enable, flush data from cache to host before save_ckpt if self._need_flush_from_cache: self._flush_from_cache(cb_params) cur_step_need_save_ckpt = self._check_save_ckpt(cb_params, force_to_save) # Save checkpoint if cur_step_need_save_ckpt: self._save_megatron_ckpt_file_format(cb_params) # If async_save is False, output the time cost directly if not self._config.async_save: self.print_savetime(cb_params.cur_step_num, cb_params.batch_num) self._last_triggered_step = cb_params.cur_step_num def end(self, run_context): """ Save the last checkpoint after training finished. Args: run_context (RunContext): Context of the train running. 
""" if self.use_legacy_format: super().end(run_context) def on_train_step_begin(self, run_context): """Called before each training step.""" super().on_train_step_begin(run_context) if not self.use_legacy_format and self._config.async_save: logger.info("(on_train_step_begin) Try to execute finalize func.") self.async_save_manager.maybe_finalize(wait_finish=False) def on_train_end(self, run_context): """Called after the end of training.""" super().on_train_end(run_context) if not self.use_legacy_format: cb_params = run_context.original_args() # Need to save the last step checkpoint self._save_megatron_ckpt_file_format(cb_params) if self._config.async_save: logger.info("(on_train_end) Wait all ranks and execute finalize func.") self.async_save_manager.maybe_finalize(wait_finish=True)
@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class ProfileMonitor(Callback):
    """
    Profile analysis in training.

    Args:
        start_step (int, optional): The step to start profiling. Default: ``1``.
        stop_step (int, optional): The step to stop profiling. Default: ``10``.
        output_path (str, optional): The result of profiling will be saved in this path. Default: ``None``.
        start_profile (bool, optional): Whether to enable profiling. Default: ``True``.
        profile_rank_ids (list, optional): Specify rank ids to enable profiling. Default: ``None`` (All rank ids
            are enabled).
        profile_pipeline (bool, optional): Whether to enable profiling on one card of each parallel stage.
            Default: ``False``.
        profile_communication (bool, optional): Whether to collect communication performance data
            during multi-device training. Default: ``False``.
        profile_memory (bool, optional): Whether to collect Tensor memory data. Default: ``False``.
        config (dict, optional): Configuration items, used to profile relevant configuration information,
            such as parallel configuration. Default: ``None``.
        profiler_level (int, optional): Collection level of profiling data(0, 1, 2). Default: ``0``.

            - 0: The most streamlined level of performance data collection,
              only collecting execution time data for computational operators and
              basic data for large communication operators.
            - 1: In addition to level 0, extra data is collected for CANN layer AscendCL,
              AICORE performance data, and small communication operators.
            - 2: In addition to level 1, extra data is collected for graph compile level O2
              and Runtime in the CANN layer.

        with_stack (bool, optional): Whether to collect Python-side stack trace data. Default: ``False``.
        data_simplification (bool, optional): Whether to enable data simplification, which will delete the FRAMEWORK
            directory and other extraneous data after exporting profiling data. Default: ``True``.
        mstx (bool, optional): Whether to enable mstx step-time recording. Default: ``False``.
        kwargs: Extra keyword arguments forwarded unchanged to the profiler constructor.

    Raises:
        ValueError: If ``profile_pipeline`` is enabled and ``get_pipeline_rank_ids()`` returns ``[-1]``.

    Examples:
        >>> from mindformers.core import ProfileMonitor
        >>> monitor = ProfileMonitor(output_path='./profile_dir')
    """

    def __init__(self, start_step=1, stop_step=10, output_path=None,
                 start_profile=True, profile_rank_ids=None, profile_pipeline=False,
                 profile_communication=False, profile_memory=False, config=None,
                 profiler_level=0, with_stack=False, data_simplification=True, mstx=False, **kwargs):
        super().__init__()
        self.mstx_range_id = None
        self.mstx_enabled = not _check_mspti_is_on()
        self.profile_rank_ids = profile_rank_ids
        self.profile_pipeline = profile_pipeline
        # Defaults for ranks that do not profile, so these attributes always exist.
        self.profiler = None
        self.is_profiler_start = False
        self.run_context = None
        self.output_path = output_path

        # Validate the step range first and use the validated values everywhere below
        # (both for the stored attributes and for the profiler schedule).
        start_step, stop_step = self._check_step(start_step, stop_step)
        self.start_step = start_step
        self.stop_step = stop_step

        # check start_profile
        start_profile = self._check_start_profile(start_profile, start_step)

        if profile_communication:
            if profiler_level == 0:
                profiler_level = 1
                logger.warning(
                    "When profile_communication is True, profiler_level must be greater than 0, reset "
                    "profiler_level to 1")

        # convert profiler_level
        profiler_level = self._get_profiler_level(profiler_level)

        rank_id = get_real_rank()
        self.pipeline_rank_ids = get_pipeline_rank_ids() if self.profile_pipeline else None
        if self.pipeline_rank_ids == [-1]:
            raise ValueError("Device num should be divided by pipeline stage num.")

        if self._is_profile_required(rank_id):
            if not output_path:
                output_path = get_output_subpath('profile', rank_id)
            else:
                output_path = os.path.join(output_path, 'profile', f'rank_{rank_id}')
            logger.info("Profile save path: %s", output_path)

            if get_context("device_target") == "GPU" and profile_memory:
                logger.warning("The parameter profile_memory is not supported on GPU currently, "
                               "so is changed to False. ")
                profile_memory = False

            # get schedule config
            schedule_config = self._get_schedule(start_profile, start_step, stop_step)

            if is_version_ge(ms.__version__, '2.6.0'):
                from mindspore.profiler import profile, _ExperimentalConfig, tensorboard_trace_handler
                experimental_config = _ExperimentalConfig(profiler_level=profiler_level,
                                                          data_simplification=data_simplification,
                                                          mstx=mstx)
                self.profiler = profile(
                    profile_memory=profile_memory,
                    start_profile=False,
                    with_stack=with_stack,
                    schedule=schedule_config,
                    on_trace_ready=tensorboard_trace_handler(dir_name=output_path),
                    experimental_config=experimental_config,
                    **kwargs
                )
                self.is_profiler_start = False
            else:
                # compatible to old version mindspore
                from mindspore.profiler import Profiler, tensor_board_trace_handler
                self.profiler = Profiler(
                    start_profile=False,
                    profile_memory=profile_memory,
                    profiler_level=profiler_level,
                    with_stack=with_stack,
                    data_simplification=data_simplification,
                    mstx=mstx,
                    schedule=schedule_config,
                    on_trace_ready=tensor_board_trace_handler(dir_name=output_path),
                    **kwargs
                )
                self.is_profiler_start = False

            self._record_metadata(config)
            self.run_context = None
            # Store the resolved per-rank output directory.
            self.output_path = output_path

    @staticmethod
    def _check_step(start_step, stop_step):
        """
        Check start_step and stop_step, falling back to the defaults when they are invalid.

        Args:
            start_step: start step number.
            stop_step: stop step number.

        Returns:
            tuple, the validated ``(start_step, stop_step)`` pair.
        """
        # Log the offending value *before* overwriting it so the warning is meaningful.
        if start_step < 0:
            logger.warning("start_step must be greater than 0, but got %s, reset to default 1", start_step)
            start_step = 1
        if stop_step < 0:
            logger.warning("stop_step must be greater than 0, but got %s, reset to default 10", stop_step)
            stop_step = 10
        if start_step > stop_step:
            logger.warning("stop_step must be greater than start_step, but get start_step = %d, stop_step = %d, "
                           "now start_step and stop_step are reset to 1 and 10.", start_step, stop_step)
            start_step = 1
            stop_step = 10

        return start_step, stop_step

    @staticmethod
    def _check_start_profile(start_profile, start_step):
        """
        Check whether start_profile can take effect together with start_step.

        Args:
            start_profile: Whether to collect after initialization.
            start_step: start step number.

        Returns:
            ``False`` when ``start_step`` is not 1 and ``start_profile`` is truthy,
            otherwise ``start_profile`` unchanged.
        """
        if start_step != 1 and start_profile:
            logger.warning("If the parameters start_step and init_start_profile are set simultaneously, "
                           "the init_start_profile parameter will not take effect, reset init_start_profile to False.")
            return False

        return start_profile

    @staticmethod
    def _get_schedule(start_profile, start_step, stop_step):
        """
        Get schedule by start_step and stop_step.

        Args:
            start_profile: Whether to start the profiler from the first step.
            start_step: start step number.
            stop_step: stop step number.

        Returns:
            The object returned by ``schedule(...)``; it is passed to the profiler as ``schedule``.
        """
        if start_profile:
            schedule_config = schedule(wait=0, active=stop_step, warmup=0, repeat=1, skip_first=1)
        else:
            schedule_config = schedule(wait=0, active=stop_step - start_step + 1, warmup=0, repeat=1,
                                       skip_first=start_step)

        return schedule_config

    def on_train_step_begin(self, run_context):
        """
        Start profile at the beginning of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        step_num = cb_params.cur_step_num

        # Start the profiler only once, on the first step seen by this callback.
        if self.profiler and not self.is_profiler_start:
            self.profiler.start()
            self.profiler.step()  # avoid the first step to align with train steps
            self.is_profiler_start = True

        if self.mstx_enabled:
            self.mstx_range_id = ms.profiler.mstx.range_start(f'step {step_num}', ms.runtime.current_stream())

    def on_train_step_end(self, run_context):
        """
        Stop profile at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        step_num = cb_params.cur_step_num

        if self.mstx_enabled:
            ms.profiler.mstx.range_end(self.mstx_range_id)

        if self.profiler:
            self.profiler.step()

        if step_num == self.stop_step and self.profiler:
            logger.info("End of Profiling, please analyze it using MindStudio Insight. "
                        "See https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/"
                        "GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html for details.")

    def _record_metadata(self, config):
        """
        Record metadata from config.

        Args:
            config (dict): config of the train running. Nothing is recorded when it is ``None``.
        """
        if config is None:
            return

        parallel = config.parallel
        parallel_config = config.parallel_config.to_dict()

        try:
            self.profiler.add_metadata_json('distributed_args', json.dumps({
                'tensor_model_parallel_size': parallel_config.get('model_parallel', 1),
                'pipeline_model_parallel_size': parallel_config.get('pipeline_stage', 1),
                'data_parallel_size': parallel_config.get('data_parallel', 1),
                'expert_model_parallel_size': parallel_config.get('expert_parallel', 1),
                'sequence_parallel': parallel_config.get('use_seq_parallel', False),
                'parallel_mode': parallel.get('parallel_mode', None),
                'world_size': parallel.get('device_num', None)
            }))
        except AttributeError as e:
            # Best effort: failing to attach metadata only produces a warning.
            logger.warning("Profiler failed to record distributed args, %s", e)

    def _is_profile_required(self, rank_id):
        """
        Determine whether current rank id needs to enable profiling.

        Args:
            rank_id (int): current rank id.

        Returns:
            bool, ``True`` when neither rank filter is set, or when ``rank_id`` appears in one of them.
        """
        # No filter configured at all: every rank profiles.
        if not self.profile_rank_ids and not self.pipeline_rank_ids:
            return True

        # Filters that are not lists are treated as empty.
        profile_ids = self.profile_rank_ids if isinstance(self.profile_rank_ids, list) else []
        pipeline_ids = self.pipeline_rank_ids if isinstance(self.pipeline_rank_ids, list) else []

        if rank_id in profile_ids or rank_id in pipeline_ids:
            return True

        return False

    @staticmethod
    def _get_profiler_level(level):
        """
        Obtain profiler level based on the level value with integer type.

        Args:
            level (int): the value of profiler_level in MF config.

        Returns:
            The matching ``ProfilerLevel`` member, ``ProfilerLevel.Level0`` when ``level`` is ``None``,
            or ``None`` when ``level`` is out of range.
        """
        if level is None:
            return ProfilerLevel.Level0

        max_level = len(ProfilerLevel.__members__) - 1
        if level < 0 or level > max_level:
            logger.warning("Invalid profiler_level: %s, return None.", level)
            return None
        profiler_level = getattr(ProfilerLevel, f"Level{level}")
        return profiler_level
@MindFormerRegister.register(MindFormerModuleType.CALLBACK)
class EvalCallBack(Callback):
    """
    Callback that runs a user-supplied evaluation function during training.

    Args:
        eval_func (Callable): Zero-argument function that evaluates the model; its return value is logged.
            It can be customized according to the specific task.
        step_interval (int, optional): Number of steps between two evaluations. Default ``100``.
            A value less than or equal to 0 disables step-end evaluation.
            Note that it will not take effects when running in data sink mode.
        epoch_interval (int, optional): Number of epochs between two evaluations. Default ``-1``.
            A value less than or equal to 0 disables epoch-end evaluation.

    Examples:
        >>> from mindformers.core.callback import EvalCallBack
        >>> def eval_func():
        ...     print("output result")
        >>> eval_callback = EvalCallBack(eval_func=eval_func)
        >>> type(eval_callback)
        <class 'mindformers.core.callback.callback.EvalCallBack'>
    """

    def __init__(self, eval_func: Callable, step_interval: int = 100, epoch_interval: int = -1):
        self.eval_func = eval_func
        self.step_interval = step_interval
        self.epoch_interval = epoch_interval

    def on_train_epoch_end(self, run_context):
        """Evaluate when the current epoch number is a multiple of ``epoch_interval``."""
        # Epoch-end evaluation is only active for a positive interval.
        if self.epoch_interval > 0:
            finished_epoch = run_context.original_args().cur_epoch_num
            if finished_epoch % self.epoch_interval == 0:
                self._execute_eval()

    def on_train_step_end(self, run_context):
        """Evaluate when the current step number is a multiple of ``step_interval``."""
        # Step-end evaluation is only active for a positive interval.
        if self.step_interval > 0:
            finished_step = run_context.original_args().cur_step_num
            if finished_step % self.step_interval == 0:
                self._execute_eval()

    def _execute_eval(self):
        """Call ``eval_func`` once and log its result together with the elapsed wall-clock time."""
        began_at = time.time()
        result = self.eval_func()
        elapsed = time.time() - began_at
        logger.info("Eval result: %s, eval time is %f s.", result, elapsed)
@MindFormerRegister.register(MindFormerModuleType.CALLBACK) class ColdHotExpertMonitor(Callback): """ ColdHotExpertMonitor Callback used in MoE model training progress. Args: config : Read config from configuration file. Examples: >>> from mindformers.core.callback import ColdHotExpertMonitor >>> callback = ColdHotExpertMonitor(config) >>> type(callback) <class 'mindformers.core.callback.callback.ColdHotExpertMonitor'> """ def __init__(self, moe_config=None, hidden_size=None, ffn_hidden_size=None, expert_parallel=None, model_parallel=None, save_checkpoint_steps=None): self.update_step = moe_config.update_step if hasattr(moe_config, "update_step") else 10000 self.expert_num = moe_config.expert_num self.hot_expert_num = moe_config.hot_expert_num self.moe_module_name = moe_config.moe_module_name self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size self.ep = expert_parallel self.mp = model_parallel self.save_checkpoint_steps = save_checkpoint_steps self.rank_id = int(os.getenv("RANK_ID")) self.local_expert_num = self.expert_num // self.ep start_index = (self.rank_id // self.mp) * self.local_expert_num end_index = start_index + self.local_expert_num self.local_expert_index = list(range(start_index, end_index)) self.rank_size = int(os.getenv("RANK_SIZE")) def on_train_step_end(self, run_context): """ Switch popular expert copies when there is a change in popular experts at the step. Args: run_context (RunContext): Context of the train running. 
""" if self.update_step <= 0: return callback_params = run_context.original_args() cur_step_num = callback_params.cur_step_num if ((cur_step_num < self.update_step and cur_step_num & (cur_step_num - 1) == 0) or (cur_step_num == self.save_checkpoint_steps) or (cur_step_num % self.update_step == 0)): total_start = time.time() train_network = callback_params.train_network if train_network is None: return blocks = self.get_attribute_by_path(train_network, self.moe_module_name) for block in blocks: if cur_step_num > 1: self.return_back_hot_expert(block) self.switch_hot_expert(block, cur_step_num) total_end = time.time() logger.info("switch hot experts spent time is %f s.", total_end - total_start) def on_train_end(self, run_context): """ Switch popular expert copies when there is a change in popular experts at the step. Args: run_context (RunContext): Context of the train running. """ callback_params = run_context.original_args() cur_step_num = callback_params.cur_step_num train_network = callback_params.train_network if train_network is None: return blocks = self.get_attribute_by_path(train_network, self.moe_module_name) for block in blocks: if cur_step_num > 1: self.return_back_hot_expert(block) def get_attribute_by_path(self, obj, path): """ Obtains MoE blocks modules in obj by path. Args: obj : Model. path(str) : Path of the MoE layer in the model """ for attr in path.split('.'): obj = getattr(obj, attr) return obj def return_back_hot_expert(self, block): """ When the popular experts change, return the replica parameters to the old popular experts. Args: block : MoE layer. 
""" old_hot_expert_index = block.output.hot_expert_index.value()[0] if self.hot_expert_num == 1: if old_hot_expert_index[0] in self.local_expert_index: ffn_index = old_hot_expert_index[0] - (self.rank_id // self.mp) * self.local_expert_num block.output.ffn.mapping.weight[ffn_index] = block.output.mlp.mapping.weight block.output.ffn.mapping.bias[0][ffn_index][0] = block.output.mlp.mapping.bias block.output.ffn.projection.weight[ffn_index] = block.output.mlp.projection.weight block.output.ffn.projection.bias[0][ffn_index][0] = block.output.mlp.projection.bias elif self.hot_expert_num > 1: for i in range(self.hot_expert_num): if old_hot_expert_index[i] in self.local_expert_index: ffn_index = old_hot_expert_index[i] - (self.rank_id // self.mp) * self.local_expert_num block.output.ffn.mapping.weight[ffn_index] = block.output.mlp.mapping.weight[i] block.output.ffn.mapping.bias[0][ffn_index][0] = block.output.mlp.mapping.bias[0][i][0] block.output.ffn.projection.weight[ffn_index] = block.output.mlp.projection.weight[i] block.output.ffn.projection.bias[0][ffn_index][0] = block.output.mlp.projection.bias[0][i][0] def switch_hot_expert(self, block, cur_step_num): """ Switch popular expert copies when there is a change in popular experts at the step. Args: block : MoE layer. 
cur_step_num : Current training step """ old_hot_expert_index = block.output.hot_expert_index.value()[0] cumsum_tensor = block.output.router.router.cumsum_value.value() _, new_expert_index = cumsum_tensor.topk(self.expert_num, largest=True) new_hot_expert_index = new_expert_index[0:self.hot_expert_num] new_cold_expert_index = new_expert_index[self.hot_expert_num:self.expert_num] broadcasts = [self.BroadcastCell(i) for i in range(self.rank_size)] if self.hot_expert_num == 1: if cur_step_num > 1 and old_hot_expert_index[0] == new_hot_expert_index[0]: return # Broadcast new hot expert and copy the weights of new hot experts to mlp for i in range(self.mp): ffn_index = new_hot_expert_index[0] % self.local_expert_num rank_id = new_hot_expert_index[0] // self.local_expert_num * self.mp + i expert_part = broadcasts[rank_id]((block.output.ffn.mapping.weight[ffn_index], block.output.ffn.mapping.bias[0][ffn_index][0], block.output.ffn.projection.weight[ffn_index], block.output.ffn.projection.bias[0][ffn_index][0])) if self.rank_id % self.mp == i: block.output.mlp.mapping.weight = expert_part[0] block.output.mlp.mapping.bias = expert_part[1] block.output.mlp.projection.weight = expert_part[2] block.output.mlp.projection.bias = expert_part[3] elif self.hot_expert_num > 1: new_hot_expert_index, _ = new_hot_expert_index.topk(self.hot_expert_num, largest=False) if cur_step_num > 1 and old_hot_expert_index.equal(new_hot_expert_index).all(): return # Broadcast new hot expert and copy the weights of new hot experts to mlp for index in range(self.hot_expert_num): for i in range(self.mp): ffn_index = new_hot_expert_index[index] % self.local_expert_num rank_id = new_hot_expert_index[index] // self.local_expert_num * self.mp + i expert_part = broadcasts[rank_id]((block.output.ffn.mapping.weight[ffn_index], block.output.ffn.mapping.bias[0][ffn_index][0], block.output.ffn.projection.weight[ffn_index], block.output.ffn.projection.bias[0][ffn_index][0])) if self.rank_id % self.mp == i: 
block.output.mlp.mapping.weight[index] = expert_part[0] block.output.mlp.mapping.bias[0][index][0] = expert_part[1] block.output.mlp.projection.weight[index] = expert_part[2] block.output.mlp.projection.bias[0][index][0] = expert_part[3] block.output.hot_expert_index = new_hot_expert_index.reshape((1, -1)) block.output.cold_expert_index = new_cold_expert_index.reshape((1, -1)) del broadcasts class BroadcastCell(Cell): def __init__(self, rank_id): super().__init__(auto_prefix=False) self.broadcast = Broadcast(rank_id) self.add_flags(skip_auto_parallel_compile=True) @jit() def construct(self, x): x = self.broadcast(x) return x @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class TrainCallBack(Callback): """ Train Callback used in training progress. Args: stop_step (int): The function stop train process at the step. Default None, set in yaml. Examples: >>> from mindformers.core.callback import TrainCallBack >>> stop_step = TrainCallBack(stop_step=10) <class 'mindformers.core.callback.callback.TrainCallBack'> """ def __init__(self, stop_step: int = None): self.stop_step = stop_step def on_train_step_end(self, run_context): """ Print training info at the end of epoch. Args: run_context (RunContext): Context of the process running. """ cb_params = run_context.original_args() if self.stop_step is not None and cb_params.cur_step_num >= self.stop_step: run_context.request_stop() logger.info("set train process early stop at %s steps in yaml", self.stop_step) @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class StressDetectCallBack(Callback): """ Stress Detect Callback used in training progress. Args: detection_interval (int): (int, optional): The number of steps between each hardware precision stress detection. Default: ``None``. num_detections (int, optional): The number of consecutive hardware precision stress detections for each round. Default: ``None``. dataset_size (int, optional): Training dataset size. Default: ``None``. 
Examples: >>> from mindformers.core.callback import StressDetectCallBack >>> stress_detect_callback = StressDetectCallBack(detection_interval=10, num_detections=3, dataset_size=1024) >>> type(stress_detect_callback) """ def __init__(self, detection_interval: int = None, num_detections: int = None, dataset_size: int = None): logger.warning('StressDetectCallBack serves as an experimental interface and its functionality is ' 'not yet stable.') self.detection_interval = detection_interval self.num_detections = num_detections self.steps_per_epoch = dataset_size if self.detection_interval > self.steps_per_epoch: logger.warning(f"detection_interval = {self.detection_interval} is bigger than " f"steps_per_epoch = {self.steps_per_epoch}") def on_train_step_end(self, run_context): """ Stress detect at the end of step. Args: run_context (RunContext): Context of the train running. """ callback_params = run_context.original_args() cur_step_num = callback_params.cur_step_num # stress detect detect_ret_list = [] if cur_step_num % self.detection_interval == 0: logger.info("Start to stress detect") for _ in range(self.num_detections): ret = stress_detect() detect_ret_list.append(ret) self.log_stress_detect_result(detect_ret_list) @staticmethod def log_stress_detect_result(detect_ret_list): """print output information.""" for ret in detect_ret_list: if ret == 0: logger.info("Stress detection passed") elif ret == VOLTAGE_ERROR_CODE: raise RuntimeError(f"Voltage recovery failed with error code: {ret}") else: logger.warning(f"Stress detection failed with error code: {ret}") @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class MaxLogitsMonitor(Callback): """ Callback to reset max attention logits during training. This callback resets the maximum attention logit values at the end of each training step. """ def _reset_max_attention_logit(self, network): """Reset max attention logit in the network. Args: network: The network to reset max attention logit. 
Raises: RuntimeError: If the network does not have reset_max_attention_logit method. """ while hasattr(network, "network"): network = network.network if hasattr(network, "reset_max_attention_logit"): network.reset_max_attention_logit() else: raise RuntimeError(f"network {type(network).__name__} should have reset_max_attention_logit") def on_train_step_end(self, run_context): """update expert bias at the end of step.""" cb_params = run_context.original_args() self.cur_step = cb_params.cur_step_num # pylint: disable=W0212 network = cb_params.train_network while hasattr(network, 'network'): network = network.network parallel_mode = get_auto_parallel_context("parallel_mode") if parallel_mode in ["semi_auto_parallel", "auto_parallel"] and get_context('mode') == ms.GRAPH_MODE: network = network._backbone self._reset_max_attention_logit(network) @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class TopkBiasBalanceCallback(Callback): """ Callback for topk bias balance feature in moe module. Arguments below, except `gradient_accumulation_steps`, take effects only when use legacy models. Args: balance_via_topk_bias (bool, optional): Whether to use topk bias update, should be consistent with moe config. Defaults to False. topk_bias_update_rate (float, optional): How fast is the bias updated. Defaults to 0.0. expert_num (int, optional): How many experts in the moe module. Defaults to 1. micro_batch_num (int, optional): Micro batch number in pipeline parallel. Default to 1. gradient_accumulation_steps (int, optional): Gradient accumulation steps for training. Default to 1. 
""" def __init__(self, balance_via_topk_bias: bool = False, topk_bias_update_rate: float = 0.0, expert_num: int = 1, micro_batch_num: int = 1, gradient_accumulation_steps: int = 1): # for aux loss free # this process is to update the expert load self.update_topk_bias_flag = balance_via_topk_bias self.gradient_accumulation_steps = gradient_accumulation_steps self.write_expert_load_to_tensorboard = get_tensorboard_args()['log_expert_load_to_tensorboard'] self.tensor_writer = get_tensorboard_writer() if self.update_topk_bias_flag and self.tensor_writer is not None and self.write_expert_load_to_tensorboard: logger.info('The expert loads will be written to tensorboard.') self.cur_step = 0 if self.update_topk_bias_flag: self.assign = P.Assign() self.assign.recompute(False) self.afb_sub = P.Sub() self.afb_add = P.Add() self.sign = P.Sign() self.afb_mul = P.Mul() self.afb_div = P.Div() self.pipeline_stages = ms.context.get_auto_parallel_context("pipeline_stages") self.micro_batch_num = micro_batch_num if self.pipeline_stages > 1 else 1 self.acc_step_over_expert_num = \ Tensor([micro_batch_num * gradient_accumulation_steps / expert_num], ms.float32) self.topk_bias_update_rate = topk_bias_update_rate self.zeros_tensor = ms.Tensor(np.zeros([expert_num]), ms.float32) def _update_topk_bias(self, network): """update topk bias tensor during training.""" while hasattr(network, "network"): network = network.network if hasattr(network, "update_topk_bias"): expert_loads = network.update_topk_bias(self.gradient_accumulation_steps) if self.tensor_writer is not None and self.write_expert_load_to_tensorboard: for layer, expert_load in expert_loads: if expert_load.sum() > 0: expert_load_dict = {f"ep_{i}": load_i.asnumpy() for i, load_i in enumerate(expert_load)} self.tensor_writer.add_scalars( f"expert_load/{layer}", expert_load_dict, global_step=self.cur_step ) return if self.update_topk_bias_flag: for layer in network.model.layers: if hasattr(layer.feed_forward, "routed_experts"): if 
hasattr(layer.feed_forward.routed_experts, "router"): expert_load_data = \ layer.feed_forward.routed_experts.router.router.expert_load.value() if expert_load_data.sum() > 0: err = self.afb_sub(self.acc_step_over_expert_num, expert_load_data) topk_bias_new = self.afb_add( layer.feed_forward.routed_experts.router.router.topk_bias.value(), self.afb_mul(self.sign(err), self.topk_bias_update_rate) ) self.assign(layer.feed_forward.routed_experts.router.router.topk_bias, topk_bias_new) self.assign(layer.feed_forward.routed_experts.router.router.expert_load, self.zeros_tensor) else: expert_load_data = layer.feed_forward.routed_experts.expert_load.value() if expert_load_data.sum() > 0: err = self.afb_sub(self.acc_step_over_expert_num, expert_load_data) topk_bias_new = self.afb_add( layer.feed_forward.routed_experts.topk_bias.value(), self.afb_mul(self.sign(err), self.topk_bias_update_rate) ) self.assign(layer.feed_forward.routed_experts.topk_bias, topk_bias_new) self.assign(layer.feed_forward.routed_experts.expert_load, self.zeros_tensor) def on_train_step_end(self, run_context): """update expert bias at the end of step.""" cb_params = run_context.original_args() self.cur_step = cb_params.cur_step_num # pylint: disable=W0212 network = cb_params.train_network while hasattr(network, 'network'): network = network.network parallel_mode = get_auto_parallel_context("parallel_mode") if parallel_mode in ["semi_auto_parallel", "auto_parallel"] and get_context('mode') == ms.GRAPH_MODE: network = network._backbone self._update_topk_bias(network) @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class MoEDropRateCallback(Callback): """Callback drop rate in moe module. Args: expert_num (int): How many experts in the moe module. capacity_factor (float): Capcity factor in the moe module. num_layers (int): How many layers in the model. mtp_depth (int): How many layers in the mtp module. 
Examples: >>> from mindformers.core.callback import MoEDropRateCallback >>> stop_step = MoEDropRateCallback(expert_num=8, capacity_factor=1.5, num_layers=4, mtp_depth=1) <class 'mindformers.core.callback.callback.MoEDropRateCallback'> """ def __init__(self, expert_num: int, capacity_factor: float, num_layers: int, mtp_depth: int): self.capacity_factor_over_expert_num = capacity_factor / expert_num self.num_layers = num_layers + mtp_depth def _callback_droprate(self, network): """callback drop rate.""" for i in range(self.num_layers): while hasattr(network, "network"): network = network.network if hasattr(network.model.layers[i].feed_forward, "routed_experts"): if hasattr(network.model.layers[i].feed_forward.routed_experts, "router"): fi = network.model.layers[i].feed_forward.routed_experts.router.router.fi_parameter.value() if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") else: if hasattr(network.model.layers[i].feed_forward, "router"): fi = network.model.layers[i].feed_forward.router.router.fi_parameter.value() if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") def on_train_step_end(self, run_context): """get expert drop rate at the end of step.""" cb_params = run_context.original_args() # pylint: disable=W0212 network = cb_params.train_network while hasattr(network, 'network'): network = network.network parallel_mode = get_auto_parallel_context("parallel_mode") if parallel_mode in ["semi_auto_parallel", "auto_parallel"] and get_context('mode') == ms.GRAPH_MODE: network = network._backbone self._callback_droprate(network) def get_embedding_info(cb_params, embedding_size): """print embedding info and get the health of checkpoint.""" if len(cb_params.net_outputs) < 7: raise ValueError("You should turn on the local norm while using the skip 
data by global norm function.") embedding_local_norm = 0 pipeline_stages = ms.context.get_auto_parallel_context("pipeline_stages") device_nums = get_group_size() rank = get_rank() if rank < device_nums // pipeline_stages: for local_norm, local_norm_size in zip(cb_params.net_outputs[5], cb_params.net_outputs[6]): if local_norm_size == embedding_size: embedding_local_norm = local_norm break return embedding_local_norm @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class StressTestModelMonitor(Callback): """Initialize the StressTestModelMonitor. Args: interval_steps (int, optional): Number of steps after which to check the model. stress_model_dir (str, optional): The directory where the model ymal file is stored. stress_dataset_dir (str): The directory where the stress test dataset is stored. compare_interval_steps (int, optional): Number of interval steps where the stress test result is compared. stress_master_port (int, optional): The master port of stress test. stress_test_log_dir (str optional): The directory where the stress test training log is stored. check_stresslog_interval_time (int, optional): Time interval where the stress test log is checked. 
""" def __init__(self, interval_steps=10, stress_model_dir=None, stress_dataset_dir=None, compare_interval_steps=None, stress_master_port=8338, stress_test_log_dir="test_output/stress_test_output1/msrun_log", check_stresslog_interval_time=60): logger.warning('StressTestModelMonitor serves as an experimental interface and its functionality is ' 'not yet stable.') super().__init__() self.interval_steps = interval_steps self.last_checked_step = 0 self.model_dir = stress_model_dir if not self.model_dir or not os.path.exists(self.model_dir): raise ValueError(f"model_dir {self.model_dir} was not found for StressTestModelMonitor.") self.dataset_dir = stress_dataset_dir self.stress_master_port = stress_master_port self.main_master_port = int(os.getenv("MS_SCHED_PORT")) logger.info(f"The main model is using master port {self.main_master_port}") if not isinstance(self.stress_master_port, int) or self.stress_master_port < 1: logger.warning("For StressTestMonitor, stress_master_port must be an integer greater than or equal " f"to 1, but got {self.stress_master_port}. Setting to default value 8338") self.stress_master_port = 8338 if self.stress_master_port == self.main_master_port: logger.warning("For StressTestMonitor, stress_master_port must be different from the main task " f"but both got {self.stress_master_port}. 
Setting to {self.stress_master_port + 1}") self.stress_master_port += 1 logger.warning(f"Make sure that the new port {self.stress_master_port} is unoccupied.") self.worker_num = ms.communication.get_local_rank_size() logger.info(f"Local worker number for each stress test is {self.worker_num}.") self.compare_interval_steps = compare_interval_steps if not isinstance(self.compare_interval_steps, int) or self.compare_interval_steps < 1: logger.warning("For StressTestMonitor, compare_interval_steps must be an integer greater than or equal" f" to 1, but got {self.compare_interval_steps}.") logger.warning("Skipping interval steps comparison, only the last step result will be compared." " compare_interval_steps is set to None") self.compare_interval_steps = None self.stress_test_log_dir = stress_test_log_dir self.check_stresslog_interval_time = check_stresslog_interval_time if not isinstance(self.check_stresslog_interval_time, int) or self.check_stresslog_interval_time < 1: logger.warning(f"For StressTestMonitor, check_stresslog_interval_time must be an integer greater than or " f"equal to 1, but got {self.check_stresslog_interval_time}. Setting to default value 60") self.check_stresslog_interval_time = 60 def on_train_step_end(self, run_context): """Perform actions after each training step and check the criteria.""" cb_params = run_context.original_args() current_step = cb_params.cur_step_num # Retrieve the current step number # Check if interval_steps is set and enough steps have passed if self.interval_steps and (current_step - self.last_checked_step >= self.interval_steps): self.check_stress_test_model(current_step) self.last_checked_step = current_step # Update the last checked step def check_stress_test_model(self, current_step): """Perform stress test on current step""" logger.info(f"On Step {current_step}, Main Process paused. 
Running the stress test models...") logger.info(f"Stress test model directory is: '{self.model_dir}'") logger.info(f"Check stress test logs at {self.stress_test_log_dir} for details.") if not self.dataset_dir or not os.path.exists(self.dataset_dir): logger.error(f"dataset_dir: {self.dataset_dir} was not found for StressTestModelMonitor, " "Exiting Stress test.") return num_cores = os.cpu_count() cpu_cores = f"0-{num_cores - 1}" logger.debug(f"CPU cores assigned to the stress test task: {cpu_cores}") rank_id = get_rank() if rank_id % self.worker_num == 0: node_num = rank_id // self.worker_num saved_dir = os.path.join(self.stress_test_log_dir, "node" + str(node_num)) command = f"""taskset -c {cpu_cores} bash scripts/msrun_launcher.sh "run_mindformer.py \ --config {self.model_dir} \ --use_parallel True\ --run_mode train \ --train_data {self.dataset_dir}" \ {self.worker_num} {self.stress_master_port} {saved_dir} True 7200""" logger.info(f"Running stress test on node {node_num}, RANK {rank_id} with logs in {saved_dir}") log_file_path = os.path.join(saved_dir, "worker_0.log") # Start the subprocess command = shlex.split(command) with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as result_1: # Monitor the log file while result_1.poll() is None: # While the subprocess is running time.sleep(self.check_stresslog_interval_time) log_msg = self.readlog(log_file_path) logger.info(f"Checking stress test log every {self.check_stresslog_interval_time} seconds") logger.info(f"Current state of stress_test: {log_msg}") # Once the subprocess has finished, check the result if result_1.returncode != 0: logger.warning(f"An error occurred while running the stress test model on rank {rank_id}: \ {result_1.stderr.read().decode('utf-8')}") logger.warning(f"Check the sub task workers log for rank {rank_id} for more details.") barrier() logger.info("Stress tests ended, now starting to collect and compare results") # If compare_interval_steps is None, only compare 
the last step result, and check for its validity. if not self.compare_interval_steps: logger.warning("For StressTestMonitor, compare_interval_steps is set to None, " "so only the last step result is compared.") else: interval_results = None subtask_global_step_num = None logger.info(f"Test results are compared every {self.compare_interval_steps} steps") for i in range(self.worker_num): if get_rank() % self.worker_num == i: node_num = get_rank() // self.worker_num log_dir = os.path.join(self.stress_test_log_dir, "node" + str(node_num)) log_file_path = os.path.join(log_dir, f"worker_{i}.log") logger.info(f"log_file_path created with {log_file_path}") interval_results, subtask_global_step_num = self.extract_interval_step_results(log_file_path) barrier() # Check if the compare_interval_steps is larger than the total steps in the stress test task if interval_results is None: logger.warning(f"compare_interval_steps {self.compare_interval_steps} is larger than the total number" f" of steps {subtask_global_step_num}, so only the last step result is compared.") else: gathered_interval_results, _ = all_gather_into_tensor(interval_results) gathered_interval_results = gathered_interval_results.asnumpy() logger.info("Stress tests interval results collected, now starting to compare interval results") logger.debug(f"Collected interval results are {gathered_interval_results}") _ = self.compare_gathered_results(gathered_interval_results) # Now compare the results from the last step, this is executed regardless of compare_interval_steps setting last_step_results = None for i in range(self.worker_num): if get_rank() % self.worker_num == i: node_num = get_rank() // self.worker_num log_dir = os.path.join(self.stress_test_log_dir, "node" + str(node_num)) log_file_path = os.path.join(log_dir, f"worker_{i}.log") logger.info(f"log_file_path created with {log_file_path}") last_step_results = self.extract_last_step_result(log_file_path) barrier() gathered_results, _ = 
all_gather_into_tensor(last_step_results) # <class 'mindspore.common.tensor.Tensor'> gathered_results = gathered_results.asnumpy() # <class 'numpy.ndarray'> logger.debug("Collected last step results are gathered_results.") logger.info("Last step results are collected from each rank, now starting to compare last step results") rank0_result = gathered_results[0] comparison = np.all(gathered_results == rank0_result, axis=1) if np.all(comparison): logger.info(f"STRESS TEST PASSED. ALL Results aligned at step {current_step}: " f"[loss, global_norm] = {rank0_result}") else: unmatched_rank = np.where(~comparison)[0] # Indices of rows that do not match discrepancies = gathered_results[unmatched_rank] # Get the discrepancies logger.warning(f"STRESS TEST FAILED at step {current_step}. Discrepancies found at rank: " f"{unmatched_rank}, values: {discrepancies}.") logger.info(f"On Step {current_step}: Stress test ended! Resume training of the main model.") return def extract_last_step_result(self, log_file): """Extract the last step's results from the log file.""" loss_value = None global_norm_value = None with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in reversed(lines): if "INFO - {" in line: # Extract loss and global_norm values loss_value = self.get_value_from_line(line, r"loss: (\d+\.\d+)") global_norm_value = self.get_value_from_line(line, r"global_norm: \[(\d+\.\d+)\]") if loss_value is not None and global_norm_value is not None: break return Tensor([[loss_value, global_norm_value]], ms.float32) def extract_interval_step_results(self, log_file): """Extract results from specific steps in the middle of log file""" last_recorded_step = 0 results = [] steps_per_epoch = None with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in lines: if "INFO - {" in line: # Get the number of steps per epoch to calculate global step number if not steps_per_epoch: step_info = re.search(r"step:\[\s*\d+/\s*(\d+)\]", line) 
steps_per_epoch = int(step_info.group(1)) # Calculate the global step number epoch_match = re.search(r"Epoch:\[\s*(\d+)", line) step_match = re.search(r"step:\[\s*(\d+)", line) epoch_number = int(epoch_match.group(1)) step_number = int(step_match.group(1)) global_step_number = (epoch_number - 1) * steps_per_epoch + step_number # Consider logging only if it matches the interval if global_step_number >= (self.compare_interval_steps + last_recorded_step): loss_value = self.get_value_from_line(line, r"loss: (\d+\.\d+)") global_norm_value = self.get_value_from_line(line, r"global_norm: \[(\d+\.\d+)\]") results.append(Tensor([[epoch_number, step_number, loss_value, global_norm_value]], ms.float32)) last_recorded_step = global_step_number # if results is empty, it means that compare_interval_steps is larger than the total step number if not results: return None, global_step_number results = Tensor(results) return results, global_step_number def compare_gathered_results(self, gathered_interval_results): """Compares results from different ranks at the same epoch and step number.""" results_dict: Dict[Tuple[int, int], List[Tuple[float, float]]] = {} for result in gathered_interval_results: epoch_number, step_number, loss_value, global_norm_value = result[0] epoch_step_key = (int(epoch_number), int(step_number)) # Organize results by epoch and step if epoch_step_key not in results_dict: results_dict[epoch_step_key] = [] results_dict[epoch_step_key].append((loss_value, global_norm_value)) consistent = True discrepancies = {} # Now iterate through the results_dict to check for consistency for epoch_step, values in results_dict.items(): # Retrieve loss-global pairs for comparison loss_global_pairs = [(val[0], val[1]) for val in values] # Check if all pairs are consistent if all(pair == loss_global_pairs[0] for pair in loss_global_pairs): logger.info(f"Results consistent for epoch {epoch_step[0]}, step {epoch_step[1]}: " f"(loss, global_norm) = {loss_global_pairs[0]}") else: 
consistent = False discrepancies[epoch_step] = [] # Collect the ranks associated with discrepancies for idx, val in enumerate(loss_global_pairs): if val != loss_global_pairs[0]: # Store the index in the original gathered_interval_results discrepancies[epoch_step].append((idx, val)) if consistent: logger.info("ALL INTERVAL TESTS PASSED. All results aligned across all intervals.") return True for epoch_step, disc_values in discrepancies.items(): indices = [val[0] for val in disc_values] value_pairs = [val[1] for val in disc_values] logger.warning(f"STRESS TEST FAILED. DISCREPANCIES found in epoch {epoch_step[0]}, " f"step {epoch_step[1]}: ranks {indices}, (loss, global_norm) = {value_pairs}") logger.warning("Check the workers log of the problematic rank for detailed results") return False def get_value_from_line(self, line, pattern): """Extracts a numerical value from a line based on a regex pattern.""" match = re.search(pattern, line) if match: return float(match.group(1)) return None def readlog(self, file_path): """ Search for the latest line indicating training has started, based on key identifiers. """ with open(file_path, 'r', errors='ignore', encoding="utf-8") as f: lines = f.readlines() # Define the keywords indicating training has started keywords = ['Epoch', 'step', 'loss', 'global_norm'] # Search backwards for the latest line containing all keywords for line in reversed(lines): if all(keyword in line for keyword in keywords): parts = line.split("- INFO -") if len(parts) > 1: return parts[1].strip() # Return the part after '- INFO -' return line.strip() # If no such line, training hasn't started return "Training has not started yet." @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class SDCMonitor(Callback): """Monitor SDC (Silent Data Corruption) by SilentCheck and CheckSum. Args: initial_step (int, optional): The beginning step. Default: ``0``. step_interval (int, optional): The interval of steps to monitor SilentCheck errors in device logs. 
def __init__(self,
             initial_step: int = 0,
             step_interval: int = 10,
             strike_window_time: int = 480,
             strike_num: int = 3,
             checksum_time: int = 5,
             checksum_cooldown_time: int = 180):
    """Check prerequisites and initialise SilentCheck / CheckSum bookkeeping.

    Args:
        initial_step (int): Offset added to the callback step number. Default: ``0``.
        step_interval (int): Device logs are parsed every ``step_interval`` steps; must be
            positive. Default: ``10``.
        strike_window_time (int): Window (minutes) in which SilentCheck errors are counted.
            Default: ``480``.
        strike_num (int): Number of SilentCheck errors that triggers CheckSum. Default: ``3``.
        checksum_time (int): Duration (minutes) of a CheckSum run. Default: ``5``.
        checksum_cooldown_time (int): Cooldown (minutes) after CheckSum stops. Default: ``180``.

    Raises:
        ValueError: If ``NPU_ASD_ENABLE`` or ``MS_SDC_DETECT_ENABLE`` is not ``1``, if the
            installed mindspore is older than 2.7.0, or if ``step_interval`` is not positive.
    """
    super().__init__()
    logger.warning('SDCMonitor serves as an experimental interface and its functionality is not yet stable.')
    npu_asd_enable = int(os.getenv('NPU_ASD_ENABLE', '0'))
    ms_sdc_detect_enable = int(os.getenv('MS_SDC_DETECT_ENABLE', '0'))
    if npu_asd_enable != 1 or ms_sdc_detect_enable != 1 or not is_version_ge(ms.__version__, '2.7.0'):
        raise ValueError("SDCMonitor needs mindspore >= 2.7.0, and only works when environment variable "
                         "'NPU_ASD_ENABLE' and 'MS_SDC_DETECT_ENABLE' are set to 1.")
    # on_train_step_end computes `cur_step_num % step_interval`; reject a non-positive interval
    # here instead of failing with ZeroDivisionError in the middle of training.
    if step_interval <= 0:
        raise ValueError(f"'step_interval' must be a positive integer, but got {step_interval}.")

    self.initial_step = initial_step
    self.step_interval = step_interval
    self.step_times = {datetime.now(): initial_step}  # {timestamp: step}
    self.silent_check_error_times = {}  # {timestamp: step}
    self.strike_window_time = timedelta(minutes=strike_window_time)
    self.strike_num = strike_num

    self.checksum_enable = False
    # datetime.min makes the first cooldown check pass immediately.
    self.prev_checksum_time = datetime.min  # start/stop time
    self.checksum_time = timedelta(minutes=checksum_time)
    self.checksum_cooldown_time = timedelta(minutes=checksum_cooldown_time)

    self.device_log_path = os.path.join(get_ascend_log_path(), 'debug', f'device-{get_real_local_rank()}')
    # device log file: device-<pid>_<timestamp>.log, e.g. device-311523_20250225184632284.log
    # "0" sorts before any 17-digit timestamp, so the first scan accepts every log file.
    self.prev_log_file_time = "0"
    pid = os.getpid()
    self.log_file_pattern = re.compile(rf'device-{pid}_(\d{{17}})\.log$')
    self.silent_check_error_pattern = re.compile(r'^\[ERROR\].*silent_check_v3\.cc:.*SilentCheckV3', re.MULTILINE)
    # device log time: YYYY-MM-DD-HH:MM:SS.SSS.SSS
    self.log_time_pattern = re.compile(r'(\d{4}-\d{2}-\d{2}-\d{2}:\d{2}:\d{2}\.\d{3}\.\d{3})')
    logger.info(f"Device log path: {self.device_log_path}, pid: {pid}")

    self.all_reduce_net = AllReduceNet(GlobalComm.WORLD_COMM_GROUP)  # AllReduce status and result of CheckSum
index > 0 and log_time <= step_time_list[index - 1]: index -= 1 if index == 0 or len(error_times) == self.strike_num: break left, right = step_time_list[index - 1], step_time_list[index] # all SilentCheck errors in a step is treated as one error if left < log_time <= right: step = self.step_times[right] logger.warning(f"SilentCheck detect SDC at step: {step}") error_times[log_time] = step index -= 1 return dict(reversed(list(error_times.items()))) # order from earliest to latest def _update_silent_check_error_times(self, new_silent_check_error_times, now): """Add new SilentCheck error times and remove expired ones.""" self.silent_check_error_times.update(new_silent_check_error_times) expired_error_times = [] for error_time, _ in self.silent_check_error_times.items(): if now - error_time > self.strike_window_time: expired_error_times.append(error_time) for error_time in expired_error_times: self.silent_check_error_times.pop(error_time) def _start_checksum(self, step): """Sync CheckSum enable status on all ranks and start CheckSum.""" # set context to skip pp validation during global AllReduce parallel_mode = ms.get_auto_parallel_context("parallel_mode") ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE) self.checksum_enable = bool( self.all_reduce_net(ms.Tensor([self.checksum_enable], ms.int32)).asnumpy()[0]) ms.set_auto_parallel_context(parallel_mode=parallel_mode) if self.checksum_enable: logger.info(f"Start CheckSum at step: {step}") self.prev_checksum_time = datetime.now() ms.sdc_detect_start() def _stop_checksum(self, step): """Stop CheckSum and aggregate SDC detection result.""" logger.warning(f"Stop CheckSum at step: {step}") ms.sdc_detect_stop() self.checksum_enable = False now = datetime.now() self.prev_checksum_time = now has_sdc = ms.get_sdc_detect_result() if has_sdc: logger.warning(f"CheckSum detects SDC on rank {get_real_rank()}") # set context to skip pp validation during global AllReduce parallel_mode = 
def on_train_step_end(self, run_context) -> None:
    """Monitor SilentCheck errors and manage CheckSum if strike out.

    While CheckSum is running, the only work done is to stop it once ``checksum_time`` has
    elapsed. Otherwise the step timestamp is recorded and, every ``step_interval`` steps, the
    device logs are parsed for SilentCheck errors; after the cooldown has elapsed,
    ``_start_checksum`` is called to sync the strike-out status.

    Args:
        run_context (RunContext): Context of the process running; ``original_args()`` must
            provide ``cur_step_num``.
    """
    cb_params = run_context.original_args()
    # Step number used in logs and bookkeeping is offset by the configured initial step.
    cur_step_num = cb_params.cur_step_num + self.initial_step
    now = datetime.now()

    # stop CheckSum and clear previous SilentCheck errors
    if self.checksum_enable:
        if now - self.prev_checksum_time >= self.checksum_time:
            self._stop_checksum(cur_step_num)
        # No step time is recorded and no log is parsed while CheckSum is enabled.
        return

    self.step_times[now] = cur_step_num
    # parse device logs and start CheckSum if strike out
    # The interval test uses the raw callback step, not the offset `cur_step_num`.
    if cb_params.cur_step_num % self.step_interval == 0:
        logger.info(f"Checking device logs at step: {cur_step_num}...")
        log_files = self._get_log_files_to_check()
        new_silent_check_error_times = self._parse_silent_check_error_times(log_files)
        self._update_silent_check_error_times(new_silent_check_error_times, now)
        if now - self.prev_checksum_time >= self.checksum_cooldown_time:
            if len(self.silent_check_error_times) >= self.strike_num:
                self.checksum_enable = True
                logger.warning(f"SDC {self.strike_num} strikes and out on rank: {get_real_rank()}, "
                               f"SilentCheck error steps: {list(self.silent_check_error_times.values())}")
            # any rank strikes out will enable CheckSum in all ranks by AllReduce
            # NOTE(review): the cooldown test above uses this rank's local wall clock while
            # _start_checksum performs an AllReduce — confirm all ranks reach it on the same step.
            self._start_checksum(cur_step_num)
        # Only the current step is kept as the left boundary for the next parsing round.
        self.step_times = {now: cur_step_num}
        if log_files:
            # Remember the newest file's timestamp; _get_log_files_to_check keeps files whose
            # timestamp is >= this value, so the newest file is scanned again next time.
            self.prev_log_file_time = re.search(r'_(\d{17})\.log$', log_files[-1]).group(1)
@args_type_check(config=TransformerConfig)
def __init__(self, config=None, print_expert_load=False, manager=None, enable_expert_relocation=False,
             expert_relocation_initial_iteration=20, expert_relocation_freq=50, save_checkpoint_steps=None):
    """Store the relocation settings and derive this rank's parallel-group indices.

    Args:
        config (TransformerConfig): Training configuration; required although the parameter
            defaults to ``None``.
        print_expert_load (bool): Whether to print expert load statistics. Default: ``False``.
        manager (ExpertParallelManager, optional): Expert parallel manager instance.
        enable_expert_relocation (bool): Whether to enable expert relocation. Default: ``False``.
        expert_relocation_initial_iteration (int): First iteration at which relocation may run.
        expert_relocation_freq (int): Number of iterations between relocations.
        save_checkpoint_steps (int, optional): Checkpoint saving step.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    super().__init__()
    # Every derived field below reads from the config, so fail with a clear message
    # instead of an AttributeError on NoneType.
    if config is None:
        raise ValueError("ExpertMigrateCallback requires a TransformerConfig, but 'config' is None.")

    self.config = config
    self.print_expert_load = print_expert_load
    self.manager = manager
    self.enable_expert_relocation = enable_expert_relocation
    self.expert_relocation_initial_iteration = expert_relocation_initial_iteration
    self.expert_relocation_freq = expert_relocation_freq
    self.save_checkpoint_steps = save_checkpoint_steps

    self.rank_id = get_rank()
    self.pp = self.config.pipeline_model_parallel_size
    self.dp = self.config.data_parallel_size
    self.tp = self.config.tensor_model_parallel_size
    self.ep = self.config.expert_model_parallel_size

    self.pp_stage_id = self.rank_id // (self.dp * self.tp)
    self.dp_group_id = (self.rank_id // self.tp) % self.dp
    self.tp_group_id = self.rank_id % self.tp
    self.ep_group_id = self.rank_id % self.ep
    # _expert_weight_and_optimizer_state_relocation reads `self.ep_local_rank` on the restore
    # path, but it was never assigned, which raised AttributeError there.
    # NOTE(review): assumed to be this rank's index inside its expert-parallel group, i.e. the
    # same value as `ep_group_id` — confirm against the EP group layout.
    self.ep_local_rank = self.rank_id % self.ep

    self.num_layers = self.config.num_layers
    self.mtp_num_layers = self.config.mtp_num_layers
    # One slot per decoder layer followed by one per MTP layer; filled during relocation.
    self.relative_expert_mapping = [None] * (self.num_layers + self.mtp_num_layers)
def _expert_weight_and_optimizer_state_relocation(self, is_triggered_restore: bool = False) -> None:
    """
    Relocate expert weights and optimizer states based on load balancing.

    For every MoE layer of the current pipeline stage, the expert weights and their
    ``adam_m`` / ``adam_v`` optimizer states are reordered locally and exchanged inside the
    expert-parallel group with ``all_to_all_single``. The execution mode is switched to
    PyNative for the exchange and then set back to the mode read at entry.

    Args:
        is_triggered_restore (bool): Whether expert relocation restore is triggered in this step.
    """
    # Collect optimizer parameters
    # Only Adam moments (adam_m.* / adam_v.*) of expert weights are kept, keyed by full name.
    optimizer_param_dict = {}
    for param_name, param in self.optimizer.parameters_and_names():
        if param_name.startswith("adam_m.") or param_name.startswith("adam_v."):
            if "mlp.experts.weight" in param_name:
                optimizer_param_dict[param_name] = param

    original_mode = get_context("mode")

    for layer_index, layer in self._get_moe_layers_in_current_stage():
        # Initialize relocation
        num_local_experts = layer.mlp.num_local_experts
        ep_group = layer.mlp.experts.token_dispatcher.ep_group
        # NOTE(review): presumably q_mapping describes the new expert placement and
        # local_expert_sorted_indices the local send order — confirm in the MoE layer.
        q_mapping, local_expert_sorted_indices = layer.mlp.initialize_expert_relocation_dispatcher(
            is_triggered_restore)
        # Flattened mapping values; _print_expert_load uses them to index the load history
        # when estimating the load after relocation.
        self.relative_expert_mapping[layer_index] = np.array(list(q_mapping.values())).flatten()

        context.set_context(mode=context.PYNATIVE_MODE)

        # Update communication info
        send_size_list, recv_size_list = self.manager.update_communication_info(q_mapping)

        # Collect expert parameters and optimizer states
        expert_param_list = []
        for _, param in layer.mlp.parameters_and_names():
            # Parameters that are not sliced, or that still carry an initializer, are skipped.
            if not param.sliced or param.has_init:
                continue
            if "mlp.experts.weight" in param.name and "accu_grad" not in param.name:
                expert_param_list.append(param)
                m_name = f"adam_m.{param.name}"
                v_name = f"adam_v.{param.name}"
                # The moments are moved together with the weight only when both are found.
                if m_name in optimizer_param_dict and v_name in optimizer_param_dict:
                    expert_param_list.extend([optimizer_param_dict[m_name], optimizer_param_dict[v_name]])
                else:
                    logger.warning(f"Missing optimizer params: {m_name}, {v_name}")

        # Prepare restore indices (if needed)
        local_expert_restore_indices = None
        if is_triggered_restore:
            # Keep a copy of the current mapping, then reset the layer's mapping to 0..n-1.
            expert_mapping = layer.mlp.expert_mapping.copy()
            # pylint: disable=R1721
            expert_mapping_restored = Tensor([idx for idx in range(expert_mapping.shape[0])], ms.int32)
            layer.mlp.expert_mapping.set_data(expert_mapping_restored)
            # Row `ep_local_rank` of the (ep, -1) view, shifted by ep_local_rank * row length.
            # NOTE(review): confirm `self.ep_local_rank` is initialised before the restore path
            # runs; presumably the shift turns global expert ids into local indices.
            local_expert_restore_indices = expert_mapping.reshape(self.ep, -1)[self.ep_local_rank]
            local_expert_restore_indices = (local_expert_restore_indices
                                            - self.ep_local_rank * local_expert_restore_indices.shape[0])
            self.manager.restore_communication_info()

        # Relocate all parameters
        for param in expert_param_list:
            # View as (num_local_experts, -1) so that whole experts are reordered and exchanged.
            param_collected = Tensor(param).reshape(num_local_experts, -1)
            param_sorted = ms.mint.index_select(param_collected, 0, Tensor(local_expert_sorted_indices))
            recv_tensor = ms.mint.zeros_like(param_collected)
            all_to_all_single(recv_tensor, param_sorted, recv_size_list, send_size_list, group=ep_group)
            if local_expert_restore_indices is not None:
                recv_tensor = ms.mint.index_select(recv_tensor, 0, local_expert_restore_indices)
            # Write the received data back in the parameter's original shape.
            recv_tensor = recv_tensor.reshape(param.shape)
            param.set_data(recv_tensor, param.sliced)

        # NOTE(review): the mode is not restored if anything above raises.
        context.set_context(mode=original_mode)
def on_train_step_end(self, run_context):
    """
    Update expert load statistics and run expert relocation when it is due.

    Args:
        run_context (RunContext): Context of the process running.
    """
    cb_params = run_context.original_args()
    step = cb_params.cur_step_num
    self.cur_step = step

    # Strip the wrapper cells around the training network.
    # pylint: disable=W0212
    net = cb_params.train_network
    while hasattr(net, 'network'):
        net = net.network
    if (get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]
            and get_context('mode') == ms.GRAPH_MODE):
        net = net._backbone
        while hasattr(net, "network"):
            net = net.network
    self.network = net

    # Use the optimizer from the callback params, else the one attached to the network.
    optim = cb_params.optimizer
    self.optimizer = optim if optim is not None else getattr(cb_params.network, "optimizer", None)

    # Regular schedule: relocate every `expert_relocation_freq` steps from the initial iteration.
    start_iter = self.expert_relocation_initial_iteration
    relocate = (self.enable_expert_relocation
                and step >= start_iter
                and (step - start_iter) % (self.expert_relocation_freq) == 0)

    # On the checkpoint step, relocation runs only as a restore, and only after the start.
    restore = False
    if step == self.save_checkpoint_steps and self.enable_expert_relocation:
        restore = step > start_iter
        relocate = restore

    self._update_expert_load_history()

    if relocate:
        self._expert_weight_and_optimizer_state_relocation(restore)

    if self.print_expert_load and not restore:
        self._print_expert_load(relocate)

    # reset expert load history
    if relocate:
        self._reset_expert_load()