Comparing the Model Precision with that of Megatron-LM
1. Overview
In the LLM training system, precision validation at the model level is a key step to ensure training stability and result reliability. As training tasks become increasingly complex and model structures become larger, it is particularly important to ensure alignment of the overall model behavior between different implementations.
Megatron-LM is a mature framework for large-scale training tasks. It is highly modular and scalable and is widely used in training scenarios with high performance requirements. MindSpore Transformers r1.6.0 upgrades the model architecture by using the ModuleSpec configuration mode to build models. This makes model structure definition more flexible and easier to reuse, greatly improving development efficiency. In addition, comprehensive training support is provided in the NPU environment, fully leveraging the advantages of the NPU architecture.
This document focuses on validating precision consistency at the model level. By building equivalent model structures and configurations and using unified inputs, it compares key training indicators such as the forward output, loss, and gradient behavior to validate the reliability and precision controllability of MindSpore Transformers in the NPU environment.
2. Environment
This section describes the recommended basic operating environment for the precision comparison experiment.
Driver Version
| GPU | Version | NPU | Version |
|---|---|---|---|
| CUDA | 12.1 | CANN | 8.1.RC1 |
Important Libraries and Dependency Versions
| GPU | Version | NPU | Version |
|---|---|---|---|
| Megatron-LM | core_r0.12.0 | MindSpore Transformers | master |
| Python | 3.10 or later | Python | 3.10 or later |
| PyTorch | 2.7.0 | MindSpore | 2.6.0 |
| NumPy | 1.26.4 | NumPy | 1.26.4 |
| Transformer Engine | 2.1.0 | | |
| Apex | 0.1 | | |
Image Links
The GPU/NPU dependency versions in the preceding tables are for reference only. The actual versions in official images prevail.
Megatron-LM: For details, see Megatron-LM documentation.
MindSpore Transformers: For details, see MindSpore Transformers documentation.
3. Precision Comparison Process
This section describes the model-level precision consistency validation process between MindSpore Transformers and the mainstream Megatron-LM framework in an NPU environment. The process guides users through the entire alignment, from model configuration and data input to forward output and gradient backpropagation, and finally evaluates the precision consistency of the two frameworks on the same task.
3.1 Configuration Alignment
The first step of the precision comparison process is to ensure that the two frameworks use the same model configuration. This section provides the configuration files of Megatron-LM and MindSpore Transformers, which define the model structure, parallel policy, and key training hyperparameters.
The configuration alignment aims to ensure that the two systems are as consistent as possible in the initial state, so that the forward output and gradient backpropagation can be compared.
The following tables describe the configuration comparison with Megatron-LM.
Model configurations
This document supports only the precision comparison of the mcore model. Therefore, `use-mcore-model` must be configured for Megatron-LM, and `use_legacy: False` must be configured for MindSpore Transformers.

| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `use-legacy-model` and `use-mcore-model` | Specifies whether to use the mcore model. | `use_legacy` | Specifies whether to use the mcore model. |
| `num-layers` | Number of network layers, that is, the number of transformer layers. | `num_layers` | Number of network layers, that is, the number of transformer layers. |
| `encoder-num-layers` | Number of encoder layers. | Not supported. | |
| `decoder-num-layers` | Number of decoder layers. | Not supported. | |
| `hidden-size` | Size of the hidden layer, that is, the dimension of the hidden state. | `hidden_size` | Size of the hidden layer, that is, the dimension of the hidden state. |
| `ffn-hidden-size` | Size of the hidden layer in the feedforward network. | `intermediate_size` | Size of the hidden layer in the feedforward network. |
| `num-attention-heads` | Number of attention heads. | `num_heads` | Number of attention heads. |
| `kv-channels` | Number of key/value tensor channels. | `head_dim` | Number of key/value tensor channels. |
| `group-query-attention` | Specifies whether to enable group query attention. | `use_gqa` | Specifies whether to enable group query attention. |
| `num-query-groups` | Number of query groups. | `n_kv_heads` | Number of query groups. |
| `max-position-embeddings` | Maximum position encoding length. | `max_position_embeddings` | Maximum position encoding length. |
| `position-embedding-type` | Position encoding type, such as learned_absolute and rope. | `position_embedding_type` | Position encoding type, such as learned_absolute and rope. |
| `use-rotary-position-embeddings` | Specifies whether to use rotary position embedding (RoPE). | Specified by `position_embedding_type==rope` | Specifies whether to use RoPE. |
| `rotary-base` | Rotary base used for RoPE. | `rotary_base` | Rotary base used for RoPE. |
| `rotary-percent` | RoPE usage ratio. | `rotary_percent` | RoPE usage ratio. |
| `rotary-interleaved` | Specifies whether to use interleaved RoPE. | `rotary_interleaved` | Specifies whether to use interleaved RoPE. |
| `rotary-seq-len-interpolation-factor` | Rotary sequence length interpolation factor. | `rotary_seq_len_interpolation_factor` | Rotary sequence length interpolation factor. |
| `use-rope-scaling` | Specifies whether to enable RoPE scaling. | `use_rope_scaling` | Specifies whether to enable RoPE scaling. |
| `rope-scaling-factor` | RoPE scaling factor. | `scaling_factor` | RoPE scaling factor. |
| `no-position-embedding` | Specifies whether to disable position encoding. | `no-position-embedding` | Specifies whether to disable position encoding. |
| `disable-bias-linear` | Disables bias in linear layers. | `add_bias_linear` | Enables bias in linear layers. |
| `mrope-section` | Information of multiple RoPE sections. | Not supported. | |
| `make-vocab-size-divisible-by` | Pads the vocabulary size so that it is divisible by the specified value. | Not supported. | The vocabulary size is not changed by default. |
| `init-method-std` | Standard deviation of the normal distribution used during model parameter initialization. | `init_method_std` | Standard deviation of the normal distribution used during model parameter initialization. |
| `attention-dropout` | Dropout probability applied in the multi-head self-attention mechanism. | `attention_dropout` | Dropout probability applied in the multi-head self-attention mechanism. |
| `hidden-dropout` | Dropout probability in the hidden layer. | `hidden_dropout` | Dropout probability in the hidden layer. |
| `normalization` | Normalization method, which can be LayerNorm or RMSNorm. | `normalization` | Normalization method, which can be LayerNorm or RMSNorm. |
| `norm-epsilon` | Normalization stability factor (epsilon). | `rms_norm_eps` | RMSNorm stability factor (epsilon). |
| `apply-layernorm-1p` | Specifies whether to add 1 after LayerNorm. | Not supported. | |
| `apply-residual-connection-post-layernorm` | Specifies whether the residual connection is applied after LayerNorm. | `apply_residual_connection_post_layernorm` | Specifies whether the residual connection is applied after LayerNorm. |
| `openai-gelu` | Specifies whether to use the OpenAI version of the GELU activation function. | Not supported. | |
| `squared-relu` | Specifies whether to use the squared ReLU activation function. | Not supported. | |
| Specified by `swiglu`, `openai-gelu`, and `squared-relu` | The default value is `torch.nn.functional.gelu`. | `hidden_act` | Activation function type. |
| `gated_linear_unit` | Specifies whether to use a gated linear unit in the multi-layer perceptron (MLP). | `gated_linear_unit` | Specifies whether to use a gated linear unit in the MLP. |
| `swiglu` | Specifies whether to use the SwiGLU activation function. | `hidden_act==silu` and `gated_linear_unit` | Specifies whether to use the SwiGLU activation function. |
| `no-persist-layer-norm` | Disables persistent layer normalization. | Not supported. | |
| `untie-embeddings-and-output-weights` | Specifies whether to decouple the weights of the input embedding layer and output layer. | `untie_embeddings_and_output_weights` | Specifies whether to decouple the weights of the input embedding layer and output layer. |
| Specified by `fp16` and `bf16` | Tensor compute precision during training. | `compute_dtype` | Tensor compute precision during training. |
| `grad-reduce-in-bf16` | Performs gradient reduction in BFloat16. | Not supported. | |
| Not supported. | The initialization tensor is generated in BFloat16 format by default. | `param_init_type` | Initial precision of the weight tensor. The default value is Float32, which ensures that the backward gradient is updated in Float32. |
| Not supported. | Layer normalization is calculated in Float32 by default. | `layernorm_compute_type` | Layer normalization tensor calculation precision. |
| `attention-softmax-in-fp32` | Executes attention softmax in Float32. | `softmax_compute_type` | Softmax tensor calculation precision. |
| Not supported. | | `rotary_dtype` | Position encoding tensor calculation precision. |
| `loss-scale` | Overall loss scaling factor. | `loss_scale_value` | Overall loss scaling factor, which is configured in `runner_wrapper`. If `compute_dtype` is set to BFloat16, the value is usually set to 1.0. |
| `initial-loss-scale` | Initial loss scaling factor. | Not supported. | |
| `min-loss-scale` | Minimum loss scaling factor. | Not supported. | |
| `loss-scale-window` | Window size for dynamic loss scaling. | `loss_scale_window` | Window size for dynamic loss scaling. |
| `hysteresis` | Loss scale hysteresis parameter. | Not supported. | |
| `fp32-residual-connection` | Uses Float32 for residual connections. | Not supported. | |
| `accumulate-allreduce-grads-in-fp32` | Accumulates and reduces gradients in Float32. | Not supported. | Gradients are accumulated and reduced in Float32 by default. |
| `fp16-lm-cross-entropy` | Computes the language model cross entropy in Float16. | Not supported. | The language model cross entropy is computed in Float32 by default. |
| `q-lora-rank` | LoRA rank of the query projection layer, which is used when Q-LoRA is enabled. | `q_lora_rank` | LoRA rank of the query projection layer, which is used when Q-LoRA is enabled. |
| `kv-lora-rank` | LoRA rank of the key/value projection layer, which is used when KV-LoRA is enabled. | `kv_lora_rank` | LoRA rank of the key/value projection layer, which is used when KV-LoRA is enabled. |
| `qk-head-dim` | Number of dimensions per Q/K head. | `qk_nope_head_dim` | Number of dimensions per Q/K head. |
| `qk-pos-emb-head-dim` | Number of relative position embedding dimensions per Q/K head. | `qk_rope_head_dim` | Number of relative position embedding dimensions per Q/K head. |
| `v-head-dim` | Number of dimensions per value projection (V head). | `v_head_dim` | Number of dimensions per value projection (V head). |
| `rotary-scaling-factor` | RoPE scaling coefficient. | `scaling_factor` | RoPE scaling coefficient. |
| `use-precision-aware-optimizer` | Enables the precision-aware optimizer to automatically manage parameter updates of different data types. | Not supported. | |
| `main-grads-dtype` | Data type of the main gradients. | Not supported. | Float32 is used as the data type of the main gradients by default. |
| `main-params-dtype` | Data type of the main parameters. | Not supported. | Float32 is used as the data type of the main parameters by default. |
| `exp-avg-dtype` | Data type of the exponential moving average (EMA). | Not supported. | |
| `exp-avg-sq-dtype` | Data type of the EMA squared term. | Not supported. | |
| `first-last-layers-bf16` | Specifies whether to forcibly use BFloat16 at the first and last layers. | Not supported. | |
| `num-layers-at-start-in-bf16` | Number of layers at the start that use BFloat16. | Not supported. | |
| `num-layers-at-end-in-bf16` | Number of layers at the end that use BFloat16. | Not supported. | |
| `multi-latent-attention` | Specifies whether to enable multi-latent attention (MLA). | `multi_latent_attention` | Specifies whether to enable multi-latent attention (MLA). |
| `qk-layernorm` | Enables query/key layer normalization. | `qk-layernorm` | Enables query/key layer normalization. |
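To make the mapping concrete, the following is a minimal, illustrative sketch of how a few of the Megatron-LM model flags above could be expressed on the MindSpore Transformers side. The nesting under `model.model_config` and the sample values are assumptions for illustration; always check the keys against the model's actual `example.yaml` template.

```yaml
# Illustrative fragment only; verify nesting and keys against the real example.yaml.
model:
  model_config:
    num_layers: 32                   # Megatron-LM: --num-layers 32
    hidden_size: 4096                # Megatron-LM: --hidden-size 4096
    intermediate_size: 11008         # Megatron-LM: --ffn-hidden-size 11008
    num_heads: 32                    # Megatron-LM: --num-attention-heads 32
    normalization: "RMSNorm"         # Megatron-LM: --normalization RMSNorm
    position_embedding_type: "rope"  # Megatron-LM: --position-embedding-type rope
    compute_dtype: "bfloat16"        # Megatron-LM: --bf16
    use_legacy: False                # Uses the mcore model, cf. use-mcore-model in Megatron-LM
```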
Optimizer and learning rate scheduling configurations
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `optimizer` | Optimizer type, such as Adam and SGD. | `type` | Optimizer type, such as Adam and SGD. |
| `adam-beta1` and `adam-beta2` | β parameters of the Adam optimizer. | `betas` | β parameters of the Adam optimizer. |
| `adam-eps` | ε in the Adam optimizer (to prevent division by zero). | `eps` | ε in the Adam optimizer (to prevent division by zero). |
| `weight-decay` | Weight decay coefficient. | `weight-decay` | Weight decay coefficient. |
| `start-weight-decay` | Initial weight decay. | Not supported. | |
| `end-weight-decay` | Final weight decay. | Not supported. | |
| `weight-decay-incr-style` | Weight decay adjustment policy, which can be constant, linear, or cosine. | Not supported. | |
| `clip-grad` | Gradient clipping threshold. | `clip_grad` | Gradient clipping threshold, which is configured in `runner_wrapper`. The value is usually 1.0. |
| `lr` | Learning rate. | `learning_rate` | Learning rate. |
| `lr-decay-style` | Learning rate decay mode. | `type` | Learning rate decay mode. |
| `lr-decay-iters` | Number of iterations over which the learning rate decays. | `total_steps` | The total number of iterations is used by default. |
| `lr-decay-samples` | Number of samples over which the learning rate decays. | Not supported. | |
| `lr-warmup-iters` | Number of warm-up iteration steps of the learning rate. | `warmup_steps` | Number of warm-up iteration steps of the learning rate. |
| `lr-warmup-fraction` | Proportion of the learning rate warm-up phase. | `warmup_ratio` | Proportion of the learning rate warm-up phase. |
| `lr-warmup-init` | Initial learning rate for warm-up. | `warmup_lr_init` | Initial learning rate for warm-up. |
| `min-lr` | Minimum learning rate. | `min-lr` | Minimum learning rate. |
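As an illustration of the optimizer and learning-rate mapping, the corresponding MindSpore Transformers YAML fragments might look as follows. The section names (`optimizer`, `lr_schedule`) and the scheduler class name are assumptions based on common templates; confirm them against the configuration file actually used.

```yaml
# Illustrative fragment only; verify section names and keys against the real example.yaml.
optimizer:
  type: AdamW                  # Megatron-LM: --optimizer adam
  betas: [0.9, 0.95]           # Megatron-LM: --adam-beta1 0.9 --adam-beta2 0.95
  eps: 1.0e-8                  # Megatron-LM: --adam-eps 1e-8

lr_schedule:
  type: CosineWithWarmUpLR     # Megatron-LM: --lr-decay-style cosine (class name is an assumption)
  learning_rate: 1.0e-4        # Megatron-LM: --lr 1e-4
  warmup_steps: 2000           # Megatron-LM: --lr-warmup-iters 2000
```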
Parallel and distributed configurations
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `tensor-model-parallel-size` | Degree of tensor model parallelism. | `model_parallel` | Degree of tensor model parallelism. |
| `pipeline-model-parallel-size` | Parallel size of the pipeline model. | `pipeline_stage` | Parallel size of the pipeline model. |
| `sequence-parallel` | Specifies whether to enable sequence parallelism. | `use_seq_parallel` | Specifies whether to enable sequence parallelism. |
| `context-parallel-size` | Context parallel size. | `context_parallel` | Context parallel size. |
| `use-distributed-optimizer` | Specifies whether to use a distributed optimizer. | `parallel_optimizer_config` | Specifies whether to use a distributed optimizer. |
| `expert-model-parallel-size` | Degree of model parallelism at the expert layer. | `expert_parallel` | Degree of model parallelism at the expert layer. |
| `expert-tensor-parallel-size` | Degree of tensor parallelism at the expert layer. | `expert_model_parallel` | Degree of tensor parallelism at the expert layer. |
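The parallel mapping can be illustrated in the same way. The fragment below assumes that the MindSpore Transformers keys from the table live under a `parallel_config` section, which is the usual layout; confirm this against the actual YAML template.

```yaml
# Illustrative fragment only; verify nesting and keys against the real example.yaml.
parallel_config:
  data_parallel: 2         # No single Megatron-LM flag; determined by world size and the other parallel degrees
  model_parallel: 2        # Megatron-LM: --tensor-model-parallel-size 2
  pipeline_stage: 1        # Megatron-LM: --pipeline-model-parallel-size 1
  context_parallel: 1      # Megatron-LM: --context-parallel-size 1
  expert_parallel: 1       # Megatron-LM: --expert-model-parallel-size 1
  use_seq_parallel: False  # Megatron-LM: --sequence-parallel
```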
FlashAttention/Fused Attention
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `attention-backend` | Attention implementation backend, which can be flash, fused, unfused, local, or auto. | Not supported. | |
| `use-flash-attn` | Specifies whether to enable FlashAttention. | `use_flash_attention` | Specifies whether to enable FlashAttention. FlashAttention is enabled by default. |
| `no-masked-softmax-fusion` | Disables masked softmax fusion. | Not supported. | |
| `no-bias-gelu-fusion` | Disables bias+GELU fusion. | Not supported. | |
| `no-bias-swiglu-fusion` | Disables bias+SwiGLU fusion. | Not supported. | |
| `no-bias-dropout-fusion` | Disables bias+Dropout fusion. | Not supported. | |
| `no-rope-fusion` | Disables RoPE fusion. | Not supported. | |
| `cross-entropy-loss-fusion` | Enables cross entropy loss fusion. | Not supported. | |
MoE
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `num-experts` | Number of experts at each layer. | `num-experts` | Number of experts at each layer. |
| `moe-layer-freq` | Number of layers between inserted MoE layers. | `moe-layer-freq` | Number of layers between inserted MoE layers. |
| `moe-ffn-hidden-size` | Number of dimensions in the hidden FFN layer in MoE. | `moe_intermediate_size` | Number of dimensions in the hidden FFN layer in MoE. |
| `moe-shared-expert-intermediate-size` | Intermediate dimension of the shared experts. | `moe_shared_expert_intermediate_size` | Intermediate dimension of the shared experts. |
| `moe-shared-expert-overlap` | Specifies whether to overlap the shared expert computation. | `moe_shared_expert_overlap` | Specifies whether to overlap the shared expert computation. |
| `moe-grouped-gemm` | Specifies whether to use the grouped GEMM optimization. | `use_gmm` | Specifies whether to use the grouped GEMM optimization. |
| `moe-router-load-balancing-type` | Router load balancing policy. | `moe_router_load_balancing_type` | Router load balancing policy. |
| `moe-router-dtype` | Router score data type. | `router_dense_type` | Router score data type. |
| `moe-router-score-function` | Router score calculation method (for example, softmax). | `use_gating_sigmoid` | Specifies whether to use the Sigmoid activation function. |
| `moe-router-topk` | Number of experts selected by the router for each token (top-k). | `num_experts_chosen` | Number of experts selected by the router for each token (top-k). |
| `moe-router-pre-softmax` | Specifies whether to apply softmax before the top-k selection. | `moe_router_pre_softmax` | Specifies whether to apply softmax before the top-k selection. |
| `moe-router-num-groups` | Number of expert groups for group-limited routing. | `n_groups` | Number of expert groups for group-limited routing. |
| `moe-router-group-topk` | Number of groups selected for each token in group-limited routing. | `topk_group` | Number of groups selected for each token in group-limited routing. |
| `moe-router-topk-scaling-factor` | Top-k score scaling factor. | `routed_scaling_factor` | Top-k score scaling factor. |
| `moe-router-enable-expert-bias` | Specifies whether to enable expert bias. | `balance_via_topk_bias` | Specifies whether to enable expert bias. |
| `moe-router-bias-update-rate` | Update rate of the expert bias. | `topk_bias_update_rate` | Update rate of the expert bias. |
| `moe-use-legacy-grouped-gemm` | Specifies whether to use the legacy version of grouped GEMM. | Not supported. | |
| `moe-aux-loss-coeff` | Auxiliary loss coefficient of MoE. | Not supported. | |
| `moe-z-loss-coeff` | MoE z-loss coefficient. | Not supported. | |
| `moe-input-jitter-eps` | Input jitter noise of MoE. | `moe_input_jitter_eps` | Input jitter noise of MoE. |
| `moe-token-dispatcher-type` | Token dispatching policy (for example, allgather). | Not supported. | |
| `moe-enable-deepep` | Specifies whether to enable the DeepEP mixture-of-experts optimization. | `moe_enable_deepep` | Specifies whether to enable the DeepEP mixture-of-experts optimization. |
| `moe-per-layer-logging` | Prints logs for each MoE layer. | `moe_per_layer_logging` | Prints logs for each MoE layer. |
| `moe-expert-capacity-factor` | Expansion ratio of the expert capacity. | `capacity_factor` | Expansion ratio of the expert capacity. |
| `moe-pad-expert-input-to-capacity` | Specifies whether to pad the expert input to the capacity upper limit. | `moe_pad_expert_input_to_capacity` | Specifies whether to pad the expert input to the capacity upper limit. |
| `moe-token-drop-policy` | Token dropping policy (for example, probs or position). | `enable_sdrop` | Token dropping policy (for example, probs or position). |
| `moe-extended-tp` | Enables extended tensor parallelism. | Not supported. | |
| `moe-use-upcycling` | Specifies whether to enable expert upcycling. | Not supported. | |
| `moe-permute-fusion` | Enables the fused expert permute optimization. | `moe_permute_fusion` | Enables the fused expert permute optimization. |
| `mtp-num-layers` | Number of multi-token prediction (MTP) layers. | `mtp_depth` | Number of multi-token prediction (MTP) layers. |
| `mtp-loss-scaling-factor` | Loss scaling factor of the MTP module. | `mtp_loss_factor` | Loss scaling factor of the MTP module. |
Data loading and tokenization
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `data-path` and `split` | General data path and split ratio. | `data_path` | Sampling ratio and path of the Megatron dataset. |
| `train-data-path` | Training data path. | Not supported. | |
| `valid-data-path` | Validation data path. | Not supported. | |
| `test-data-path` | Test data path. | Not supported. | |
| `vocab-size` | Vocabulary size. | `vocab_size` | Vocabulary size. |
| `vocab-file` | Vocabulary file path. | Not supported. | |
| `merge-file` | BPE merge rule file. | Not supported. | |
| `tokenizer-type` | Tokenizer type (for example, GPT2BPETokenizer). | Not supported. | The Hugging Face tokenizer is used by default. |
| `seq-length` | Input sequence length. | `seq_length` | Input sequence length. |
| `encoder-seq-length` | Encoder input length. | Not supported. | |
| `decoder-seq-length` | Decoder input length. | Not supported. | |
| `retriever-seq-length` | Retriever sequence length (if enabled). | Not supported. | |
| `num-workers` | Number of threads for loading data. | `num_parallel_workers` | Number of threads for loading data. |
| `num-dataset-builder-threads` | Number of threads for building datasets. | Not supported. | |
| `data-cache-path` | Data cache path. | Not supported. | |
Training control and save
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| Not supported. | Number of local samples processed in each iteration. | `batch_size` | Number of local samples processed in each iteration, which is configured in `runner_wrapper`. |
| Not supported. | Number of local samples processed in each iteration. | `micro_batch_interleave_num` | Number of interleaved micro-batches. When `micro_batch_interleave_num` is greater than 1, multiple model copies are enabled for parallel processing. |
| `global_batch_size` | Total number of global samples processed in each iteration. | `batch_size` and `data_parallel` | Total number of global samples processed in each iteration, which equals `batch_size` × `data_parallel` × `micro_batch_interleave_num`. |
| Not supported. | Number of training epochs. | `epochs` | Number of training epochs, which is configured in `runner_wrapper`. |
| `train-samples` | Total number of training samples. | `sizes` | Total number of training samples, which is configured in `train_dataset`. |
| `train-iters` | Total number of training iterations. | `epochs`, `sizes`, and `global_batch_size` | Total number of training iterations, which equals `sizes` divided by `global_batch_size` and multiplied by `epochs`. |
| `log-interval` | Logging interval (number of iteration steps). | `per_print_times` | Logging interval (number of iteration steps), which is configured in `MFLossMonitor` of `callbacks`. |
| `eval-iters` | Number of iterations used in each evaluation. | Not supported. | |
| `eval-interval` | Number of steps between evaluations. | Not supported. | |
| `save` | Model save path. | `output_dir` | Model save path. |
| `save-interval` | Model save interval (number of iteration steps). | `save_checkpoint_steps` | Model save interval (number of iteration steps), which is configured in `CheckpointMonitor` of `callbacks`. |
| `non-persistent-save-interval` | Non-persistent (temporary) save interval. | Not supported. | |
| `non-persistent-ckpt-type` | Temporary checkpoint type (for example, global or local). | Not supported. | |
| `pretrained-checkpoint` | Pretrained model path. | Not supported. | |
| `ckpt-step` | Loads the weights of a specified step. | `load_checkpoint` and `resume_training` | Loads the weights of a specified name in resumable training scenarios. |
| `load` | Loads a model from the path. | `load_checkpoint` | Loads a model from the path. |
| `exit-interval` | Iteration interval for exiting training. | `stop_step` | Number of iterations after which training stops, which is configured in `TrainCallMonitor` of `callbacks`. |
| `exit-duration-in-mins` | Duration after which training exits (in minutes). | Not supported. | |
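As a worked example of the relationships in the table above: assuming `batch_size: 1`, `data_parallel: 8`, and `micro_batch_interleave_num: 1`, the effective global batch size is 1 × 8 × 1 = 8, which should equal the `global_batch_size` configured for Megatron-LM. If the training dataset is further configured with `sizes: 8000` and `epochs: 1`, the total number of training iterations is 8000 / 8 × 1 = 1000, which should match `train-iters` on the Megatron-LM side.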
Recomputation configurations
The recomputation configuration logic of MindSpore Transformers differs significantly from that of Megatron-LM. For details, see Recomputation.
| Megatron-LM | Description | MindSpore Transformers | Description |
|---|---|---|---|
| `recompute-activations` | Specifies whether to enable activation recomputation to save memory. | `recompute` | Specifies whether to enable full activation recomputation to save memory (`bool`). |
| `recompute-granularity` | Recomputation granularity (for example, full or selective). | `select_recompute` | Specifies whether to enable selective recomputation. |
| `recompute-method` | Recomputation method (for example, uniform or block). | Not supported. | |
| `recompute-num-layers` | Number of recomputation layers. | `recompute` | Number of recomputation layers (configured as a `tuple` or `list`). |
| `distribute-saved-activations` | Distributes saved activations across devices. | Not supported. | |
| `checkpoint-activations` | Specifies whether to enable activation checkpointing to reduce memory usage. | Not supported. | |
| `moe-layer-recompute` | Enables recomputation at the MoE layer. | Not supported. | |
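For orientation only, a minimal recomputation fragment on the MindSpore Transformers side might look like the sketch below. Whether these keys sit under a `recompute_config` section depends on the version, so treat the nesting as an assumption and refer to the Recomputation documentation mentioned above.

```yaml
# Illustrative fragment only; verify against the Recomputation documentation.
recompute_config:
  recompute: True          # Roughly corresponds to --recompute-activations; can also be a per-stage list
  select_recompute: False  # Selective recomputation, cf. --recompute-granularity selective
```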
Note: The two frameworks have other configurations that are not closely related to training. For details about MindSpore Transformers, see Configuration Description. You can run the `torchrun --nproc_per_node=1 pretrain_gpt.py --help` command to view the Megatron-LM configuration options.
3.2 Dataset Alignment
In the precision comparison process, ensure that the two frameworks use the same data input. This section describes how to align the dataset creation and configuration of Megatron-LM and MindSpore Transformers to ensure the consistency of input samples, providing a basis for subsequent weight loading and precision validation.
3.2.1 Preparing a Dataset
Both frameworks support loading the Megatron dataset. The dataset is preprocessed and serialized into a binary format (`.bin` and `.idx` files) with a dedicated indexing mechanism, which facilitates efficient parallel loading and data partitioning in a distributed cluster environment.
Dataset download: wikitext-103 dataset
Tokenizer model download: tokenizer.json
3.2.2 Processing a Dataset
Generating Megatron BIN files
Place the dataset file `wiki.train.tokens` and the tokenization model file `tokenizer.json` in the `../dataset` directory, and create the `data.json` file by referring to Megatron Dataset > Data Preprocessing.

Run the following commands to convert the dataset file into a BIN file:

```shell
cd $MINDFORMERS_HOME
python mindformers/tools/dataset_preprocess/preprocess_indexed_dataset.py \
  --input /path/data.json \
  --output-prefix ../dataset/wiki_4096 \
  --vocab-file ../dataset/tokenizer.json \
  --seq-length 4096 \
  --workers 1
```
Building the Megatron BIN dataset module
Run the following commands to build the Megatron BIN dataset module.
```shell
pip install pybind11
cd $MINDFORMERS_HOME/mindformers/dataset/blended_datasets
make
```

`$MINDFORMERS_HOME` indicates the directory where the MindSpore Transformers source code is stored.
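After the conversion, the output directory should contain paired index and data files derived from the `--output-prefix` (for this example, typically `wiki_4096_text_document.bin` and `wiki_4096_text_document.idx`, although the exact suffix depends on the preprocessing script and the JSON field name). The `--data-path` and `data_path` values configured in the next section point to such a prefix without the `.bin` or `.idx` extension.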
3.2.3 Configuring a Dataset
This section compares and describes the dataset configuration items in the configuration files of the two frameworks.
Megatron-LM:
The dataset configuration items in the Megatron-LM sample are as follows:
```shell
TOKENIZER_MODEL="/path/to/tokenizer.json"
DATA_PATH="/path/to/wiki_text_document"

DATA_ARGS=(
    --tokenizer-type HuggingFaceTokenizer
    --tokenizer-model ${TOKENIZER_MODEL}
    --data-path $DATA_PATH
    --split 1,0,0
)
```
In the preceding information:
- `tokenizer-type`: type of the tokenization model file.
- `tokenizer-model`: location of the tokenization model file `tokenizer.json`, accurate to the full file name.
- `data-path`: location of the processed dataset, accurate to the prefix of the `.bin` or `.idx` file.
- `split`: sampling ratio of the dataset.
MindSpore Transformers:
The dataset configuration items in the MindSpore Transformers sample are as follows:
```yaml
config:                               # GPTDataset configuration items
  data_path:                          # Sampling ratio and path of the Megatron dataset
    - '1'
    - "/home/to/wiki_text_document"
```

Note that the first parameter of `data_path` is the dataset sampling ratio; the setting in this example is equivalent to `--split` in the Megatron-LM example. The second parameter is the location of the processed dataset, accurate to the prefix of the `.bin` or `.idx` file; it is equivalent to `--data-path` in the Megatron-LM example.
3.3 Weight Alignment
To ensure the consistency of model behavior between different frameworks, the weights obtained after training must be accurately mapped to the corresponding positions in MindSpore Transformers and Megatron-LM through proper weight conversion and segmentation.
Weight Conversion
The weight formats, parameter naming modes, and tensor arrangements of MindSpore Transformers and Megatron-LM are different. Directly loading the weights will result in incompatibility. Therefore, you need to use a dedicated conversion script to convert the model weights exported from the source framework to the format that can be identified by the target framework.
Generating initial weights of MindSpore Transformers
Modify the `example.yaml` file by referring to Callbacks Configuration and run the command provided in Viewing Results to obtain an initial weight in `checkpoints` of `output_dir` in `example.yaml` through pre-training. The modification is as follows:

```yaml
# Before (example.yaml)
load_checkpoint: '/path/to/checkpoints/'
```

```yaml
# After (example.yaml)
load_checkpoint: ''

callbacks:
  - type: CheckpointMonitor
    prefix: "deepseekv3"
    save_checkpoint_steps: 1
    keep_checkpoint_max: 2
    integrated_save: False
    async_save: False
    checkpoint_format: "safetensors"
  - type: TrainCallBack
    stop_step: 1
```

Note: After obtaining the weight, restore `example.yaml`.

MindSpore Transformers to Megatron-LM
To accurately map the weights of MindSpore Transformers to the equivalent weights that can be loaded by Megatron-LM, a weight conversion script is provided. You can obtain the equivalent weights by executing the weight conversion script.
3.4 Viewing Results
After the preceding steps are complete, you can start training and extract key data from the output result in the log to check the precision comparison result.
Megatron-LM
Save the `example.sh` file to the Megatron-LM code directory and run the following command:

```shell
bash example.sh
```

MindSpore Transformers
Run the following commands in the MindSpore Transformers code directory:
```shell
bash scripts/msrun_launcher.sh "run_mindformer.py \
 --config /path/to/example.yaml"
```

`config` is the model configuration file, which is stored in the `config` directory of the MindSpore Transformers code repository.

Result comparison
View the output logs of the two frameworks. The log path of Megatron-LM is `logs/${logtime}.log`, as defined in `example.sh`, and that of MindSpore Transformers is `msrun_log/worker_0.log` under `output_dir` in `example.yaml`. The following table lists the comparison items.

| Megatron-LM | MindSpore Transformers | Description |
|---|---|---|
| iteration | epoch and step | Number of global iterations during training. MindSpore Transformers uses `(epoch, step)` to indicate the current training position, while Megatron-LM uses a single `iteration`. The relationship between them is `iteration = (epoch - 1) × steps_per_epoch + step`. |
| lm loss | loss | Training loss, which is the core indicator in the precision comparison. The `loss` of MindSpore Transformers is the sum of `lm loss` and `aux loss`; the two values will be printed separately in the future. |
| learning rate | lr | Learning rate, which is a reference indicator in the precision comparison. |
| grad norm | global norm | Global gradient norm, which is a reference indicator in the precision comparison. |
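For example, if each epoch contains 1000 steps, the record `(epoch=2, step=5)` in the MindSpore Transformers log corresponds to Megatron-LM iteration (2 - 1) × 1000 + 5 = 1005, and the `loss`, `lr`, and `global norm` printed at that step should be compared with the `lm loss`, `learning rate`, and `grad norm` printed at that iteration.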