Safetensors Weights

Overview

Safetensors is a reliable, portable storage format for machine learning models from Hugging Face. It stores tensors safely and loads them quickly (zero-copy). This article describes how MindSpore Transformers saves and loads weights in this format, so that users can work with weights more conveniently and efficiently.

Safetensors Weights Samples

There are two main types of Safetensors files: complete weights files and distributed weights files. Below are examples of how they are obtained and the corresponding files.

Complete Weights

Safetensors complete weights can be obtained in two ways:

  1. Download directly from Hugging Face.

  2. Generate them with the merge script after MindSpore Transformers distributed training.

An example directory structure of Hugging Face Safetensors weights is as follows:

qwen2_7b
 └── hf_unified_safetensors
        ├── model-00001-of-00004.safetensors
        ├── model-00002-of-00004.safetensors
        ├── model-00003-of-00004.safetensors
        ├── model-00004-of-00004.safetensors
        └── model.safetensors.index.json        # JSON file mapping each Hugging Face weight parameter to the file that stores it
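
To make the role of model.safetensors.index.json more concrete, the following minimal sketch inverts its parameter-to-file mapping so that you can see which parameters each shard holds. It assumes the usual Hugging Face index layout with a top-level "weight_map" key; the sample content and the helper name group_params_by_shard are hypothetical and heavily shortened for illustration.

```python
import json
from collections import defaultdict


def group_params_by_shard(index: dict) -> dict:
    """Invert "weight_map" (parameter name -> shard file)
    into shard file -> sorted list of parameter names."""
    shards = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        shards[shard_file].append(param_name)
    return {shard: sorted(names) for shard, names in sorted(shards.items())}


# Hypothetical, shortened index content (a real file lists every parameter).
sample_index = json.loads("""
{
  "metadata": {"total_size": 123456},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "lm_head.weight": "model-00004-of-00004.safetensors"
  }
}
""")

for shard, names in group_params_by_shard(sample_index).items():
    print(shard, len(names))
```

For a real download, load the index with json.load on the file in the weights folder instead of using the inline sample.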

An example directory structure of MindSpore Safetensors weights is as follows:

qwen2_7b
 └── ms_unified_safetensors
        ├── model-00001-of-00004.safetensors
        ├── model-00002-of-00004.safetensors
        ├── model-00003-of-00004.safetensors
        ├── model-00004-of-00004.safetensors
        ├── hyper_param.safetensors            # Hyperparameters recorded by the training task
        └── param_name_map.json                # JSON file mapping each MindSpore weight parameter to the file that stores it

Distributed Weights

Safetensors distributed weights can be obtained in two ways:

  1. Generated by distributed training with MindSpore Transformers.

  2. Use the format conversion script to convert the original distributed ckpt weights to the Safetensors format.

An example directory structure of distributed Safetensors weights is as follows:

qwen2_7b
 └── distributed_safetensors
        ├── rank_0
            └── qwen2_7b_rank_0.safetensors
        ├── rank_1
            └── qwen2_7b_rank_1.safetensors
        ...
        └── rank_x
            └── qwen2_7b_rank_x.safetensors
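
Before loading distributed weights, it can help to confirm that every rank folder is present and contains a safetensors file. The sketch below is one way to check the rank_x layout shown above using only the Python standard library; the helper names list_rank_files and missing_ranks are hypothetical and are not part of MindSpore Transformers.

```python
import re
import tempfile
from pathlib import Path


def list_rank_files(model_dir: str) -> dict:
    """Return {rank_id: [safetensors file names]} for a model_dir/rank_x/ layout."""
    result = {}
    for rank_dir in Path(model_dir).iterdir():
        match = re.fullmatch(r"rank_(\d+)", rank_dir.name)
        if rank_dir.is_dir() and match:
            files = sorted(p.name for p in rank_dir.glob("*.safetensors"))
            result[int(match.group(1))] = files
    return dict(sorted(result.items()))


def missing_ranks(model_dir: str, expected_ranks: int) -> list:
    """Ranks in [0, expected_ranks) with no folder or no safetensors file."""
    found = list_rank_files(model_dir)
    return [r for r in range(expected_ranks) if not found.get(r)]


# Demo on a throwaway directory that mimics the layout above with two ranks.
with tempfile.TemporaryDirectory() as tmp:
    for rank in range(2):
        rank_dir = Path(tmp) / f"rank_{rank}"
        rank_dir.mkdir()
        (rank_dir / f"qwen2_7b_rank_{rank}.safetensors").touch()
    demo_listing = list_rank_files(tmp)
    demo_missing = missing_ranks(tmp, expected_ranks=4)

print(demo_listing)
print(demo_missing)
```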

Weight Saving

Overview

In the training process of deep learning models, saving the model weights is a crucial step. The weight saving function allows us to store the model parameters at any stage of training, so that users can restore, continue training, evaluate or deploy after training is interrupted or completed. At the same time, by saving weights, experimental results can be reproduced in different environments.

Currently, MindSpore Transformers supports reading and saving weight files in the safetensors format.

Directory Structure

During training, MindSpore Transformers generates a weight saving folder named checkpoint in the output directory (the same directory as the training log; default: ./output).

If save_network_params: True is set in the yaml file, an additional weight saving folder named checkpoint_network is generated.

| Folder | Description |
| --- | --- |
| checkpoint | Stores model weights, optimizer state, step, and epoch in safetensors files; can be used to resume training from a breakpoint. |
| checkpoint_network | Stores only the model weight parameters in safetensors files; suitable for subsequent fine-tuning, inference, and evaluation. Does not support resuming training from a breakpoint. |

checkpoint Directory Structure

Taking an 8-rank task as an example, the weight files in the output folder are saved in the following layout:

output
    ├── checkpoint
        ├── rank_0
            ├── meta.json
            └── {prefix}-{epoch}_{step}.safetensors
        ...
        └── rank_7
            ├── meta.json
            └── {prefix}-{epoch}_{step}.safetensors
    └── checkpoint_network
        ├── rank_0
            └── {prefix}-{epoch}_{step}.safetensors
        ...
        └── rank_7
            └── {prefix}-{epoch}_{step}.safetensors
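
The {prefix}-{epoch}_{step}.safetensors naming pattern shown above can be split back into its fields with a regular expression. This is a minimal sketch based only on that pattern; the example file name deepseekv3_rank_0-1_1000.safetensors and the helper parse_checkpoint_name are hypothetical.

```python
import re
from typing import NamedTuple, Optional

# Matches "{prefix}-{epoch}_{step}.safetensors"; the prefix may itself contain "_" or "-".
CHECKPOINT_NAME = re.compile(r"(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)\.safetensors")


class CheckpointName(NamedTuple):
    prefix: str
    epoch: int
    step: int


def parse_checkpoint_name(file_name: str) -> Optional[CheckpointName]:
    """Split a checkpoint file name into prefix, epoch, and step; None if it does not match."""
    match = CHECKPOINT_NAME.fullmatch(file_name)
    if match is None:
        return None
    return CheckpointName(match["prefix"], int(match["epoch"]), int(match["step"]))


print(parse_checkpoint_name("deepseekv3_rank_0-1_1000.safetensors"))
print(parse_checkpoint_name("meta.json"))
```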

Weight-related File Description

| File | Description |
| --- | --- |
| meta.json | Records the epoch, step, and weight file name of the last saved weights. Each rank process maintains its own meta.json file. |
| {prefix}-{epoch}_{step}.safetensors | The saved weight file. prefix contains the rank_id information. If a file with the same prefix already exists, the system automatically increments the suffix. |

When data sinking is enabled, the epoch field is calculated as \(\frac{CurrentTotalStepNumber}{SinkSize} = \frac{(CurrentEpoch-1) \times StepsPerEpoch + CurrentStepInEpoch}{SinkSize}\), and step is fixed to sink_size.
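
As a worked example of the data sinking formula above, the sketch below computes the epoch and step fields that would appear in the file name. It assumes that weights are saved on sink boundaries, so the total step count divides evenly by sink_size; the function name sink_mode_name_fields and the numbers are illustrative only.

```python
def sink_mode_name_fields(current_epoch: int, steps_per_epoch: int,
                          current_step_in_epoch: int, sink_size: int) -> tuple:
    """Return the (epoch, step) fields used in the file name when data sinking is on."""
    total_steps = (current_epoch - 1) * steps_per_epoch + current_step_in_epoch
    # Assumption: saving happens on sink boundaries, so this division is exact.
    epoch_field = total_steps // sink_size
    step_field = sink_size  # step is fixed to sink_size
    return epoch_field, step_field


# Epoch 2, 100 steps per epoch, step 50 within the epoch, sink_size 10:
# total steps = (2 - 1) * 100 + 50 = 150, so the epoch field is 150 / 10 = 15.
print(sink_mode_name_fields(2, 100, 50, 10))
```

With these values the file would be named {prefix}-15_10.safetensors.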

Configuration and Usage

YAML Parameter Configuration

Users can control the weight saving behavior by modifying the fields under CheckpointMonitor in the yaml configuration file.

Taking DeepSeek-V3 pre-training yaml as an example, the following configuration can be made:

# callbacks
callbacks:
  ...
  - type: CheckpointMonitor
    prefix: "deepseekv3"
    save_checkpoint_steps: 1000
    keep_checkpoint_max: 5
    save_network_params: False
    integrated_save: False
    async_save: False
    checkpoint_format: "safetensors"
  ...

The meaning of this configuration is: save the safetensors weights every 1000 steps, store up to 5 weights at the same time, do not merge and save the split Tensor in parallel scenarios, and do not use asynchronous method to save weight files.
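
To illustrate how save_checkpoint_steps and keep_checkpoint_max interact, the following sketch models the rotation as a fixed-size queue. It is a simplified model of the behavior described here, not the CheckpointMonitor implementation, and the function name retained_steps is hypothetical.

```python
from collections import deque


def retained_steps(total_steps: int, save_checkpoint_steps: int,
                   keep_checkpoint_max: int) -> list:
    """Steps whose weights remain on disk after training for total_steps."""
    kept = deque(maxlen=keep_checkpoint_max)  # the oldest entry is dropped when full
    for step in range(save_checkpoint_steps, total_steps + 1, save_checkpoint_steps):
        kept.append(step)
    return list(kept)


# With the configuration above (save every 1000 steps, keep at most 5),
# after 8000 steps only the five most recent saves remain.
print(retained_steps(8000, 1000, 5))
```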

The main parameters for configuring weight saving are listed in the following table:

| Parameter | Description | Value |
| --- | --- | --- |
| prefix | Prefix of the model weights file name; can be used to indicate the model name. | (str, optional) Default: "CKP". |
| save_checkpoint_steps | Interval, in training steps, at which the weights are saved. | (int, optional) Default: 1. Model weights are not saved when not set. |
| keep_checkpoint_max | Maximum number of weight files kept at the same time. When the limit is reached, the oldest weight file is deleted as new weights are saved. | (int, optional) Default: 5. When not set, the number of weights in the folder is not monitored and no files are deleted. |
| integrated_save | Whether to merge split Tensors before saving in parallel scenarios. Merged saving is supported only in automatic parallel scenarios, not in manual parallel scenarios. | (bool, optional) Default: False. |
| async_save | Whether to save safetensors files asynchronously. | (bool, optional) When True, asynchronous threads are used by default. Default: False. |
| checkpoint_format | Format of the output file; must be set to safetensors to save in the safetensors format. | (str, optional) Supports "ckpt" and "safetensors". Default: "ckpt". (Note: the ckpt format will be retired in a later release; the safetensors format is recommended.) |
| remove_redundancy | Whether to remove redundancy when saving model weights. | (bool, optional) Default: False. |
| save_network_params | Whether to additionally save only the network parameters. | (bool, optional) Default: False. |

If you want to know more about CheckpointMonitor, you can refer to CheckpointMonitor API documentation.

Weight Loading

Overview

MindSpore Transformers supports training, inference, and resumable training in both single-card and multi-card scenarios, with either full weights or distributed weights. Refer to the following instructions to adjust the configuration for your scenario.

Configuration Description

  • load_checkpoint - The path to the folder from which the weights are preloaded.

    • For full weights, fill in the path to the folder containing the sliced or single weight files. Note: Loading Hugging Face safetensors weights is supported (currently only for Llama series models). During online loading, a copy of the converted MindSpore safetensors weights is saved to /output/ms_safetensors.

    • For distributed weights, store them in the model_dir/rank_x/xxx.safetensors format and fill in the folder path as model_dir.

  • load_ckpt_format - The format of the loaded model weights: ckpt or safetensors. Default: ckpt. To load weights in the safetensors format, set this item to safetensors.

  • use_parallel - Whether to load in parallel.

  • auto_trans_ckpt - Whether to enable the online slicing function.

    • If the loaded weights are full weights:
      a. When use_parallel: True, loading is treated as distributed loading, and auto_trans_ckpt: True must also be set to turn on online slicing.
      b. When use_parallel: False, loading is treated as single-card loading, and auto_trans_ckpt: False must also be set to turn off online slicing.

    • If the loaded weights are distributed weights:
      a. To keep the original slicing strategy, set auto_trans_ckpt: False so that the weights are loaded directly according to the original strategy.
      b. To change the original slicing strategy, set auto_trans_ckpt: True and set src_strategy_path_or_dir to the path of the original slicing strategy file. When the task starts, the weights are merged online into full weights, which are then sliced and loaded according to the parallel strategy set in the configuration file. The merged weights are saved under /output/unified_checkpoint in the current directory.
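
The auto_trans_ckpt rules above can be summarized as a small decision helper. This sketch only encodes the four cases described in this section; the function name loading_settings and the returned dictionary keys are hypothetical and do not correspond to any MindSpore Transformers API.

```python
def loading_settings(weights: str, use_parallel: bool, change_strategy: bool = False) -> dict:
    """Suggest auto_trans_ckpt and whether src_strategy_path_or_dir is needed.

    weights: "full" or "distributed".
    change_strategy: only meaningful for distributed weights.
    """
    if weights not in ("full", "distributed"):
        raise ValueError("weights must be 'full' or 'distributed'")
    if weights == "full":
        # Full weights: online slicing is needed only for distributed (parallel) loading.
        auto_trans = use_parallel
        needs_src_strategy = False
    else:
        # Distributed weights: online slicing is needed only when the strategy changes.
        auto_trans = change_strategy
        needs_src_strategy = change_strategy
    return {"auto_trans_ckpt": auto_trans, "needs_src_strategy_path_or_dir": needs_src_strategy}


print(loading_settings("full", use_parallel=True))
print(loading_settings("distributed", use_parallel=True, change_strategy=True))
```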

Complete Weight Loading

Single-card Loading

# configuration file
load_checkpoint: '/qwen2_7b/unified_safetensors'    # Load full weights file path
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: False                              # Full weights + single card loading requires this configuration item to be turned off
use_parallel: False                                 # single card loading
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 1
  model_parallel: 1
  pipeline_stage: 1

Multi-card Loading

# configuration file
load_checkpoint: '/qwen2_7b/unified_safetensors'    # Load full weights file path
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: True                               # Full weights + distributed loading requires this item to be turned on to enable online slicing
use_parallel: True                                  # Multi-card loading
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 1
  model_parallel: 4
  pipeline_stage: 1

Distributed Weight Loading

Multi-card Loading - Original Slicing Strategy

# configuration file
load_checkpoint: '/output/distributed_safetensors'  # Load source distributed weights file paths
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: False                              # Disable the online slicing function
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 1

Multi-Card Loading - Changing the Slicing Strategy

# configuration file
load_checkpoint: '/output/distributed_safetensors'  # Load source distributed weights file paths
src_strategy_path_or_dir: '/output/src_strategy'    # Load source strategy file for merging source distributed weights into full weights
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: True                               # Enable the online slicing function
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 4
  model_parallel: 2
  pipeline_stage: 1

In large-scale cluster scenarios, online merging can take a long time and occupy training resources. To avoid this, it is recommended to merge the original distributed weights into complete weights offline first and then pass in the complete weights. In that case, the path of the source slicing strategy file is not needed.

Special Scenarios

Physical Machine Multi-machine Multi-card Training

Large-scale models usually need to be trained on clusters of multiple servers. Weight slicing conversion relies on the target slicing strategy file, which is available after compilation completes. In this multi-machine, multi-card scenario, if the servers share a disk and the generated strategy files are in the same directory, you can use the automatic conversion function; if the servers do not share a disk, you need to copy the strategy files manually before performing the conversion. The following example uses two servers with 16 cards in total.

Scenario 1: There are shared disks between servers

If the servers share a disk, you can use the MindSpore Transformers automatic weight conversion feature to convert the weights automatically before multi-machine, multi-card training. Assume that /data is a disk shared by the servers and that the MindSpore Transformers project code is located under /data/mindformers.

Parameter Configuration:

output_dir: './output'                              # The strategy file is generated under ./output/strategy, which is used to slice the weights online.
load_checkpoint: '/qwen2_7b/unified_safetensors'    # Load full weights file path
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: True                               # This configuration item needs to be turned on for full weights + distributed loading to turn on online slicing
train_dataset: &train_dataset
  data_loader:
    type: MindDataset
    dataset_dir: "/worker/dataset/wiki103/"
    shuffle: True
parallel_config:                                    # Configuring a 16-card distributed strategy (for reference only)
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 2
  micro_batch_num: 2
  vocab_emb_dp: True
  gradient_aggregation_group: 4
  micro_batch_interleave_num: 1
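
As a quick sanity check on the 16-card strategy above, the number of cards a configuration needs can be estimated as the product of its parallel dimensions. The sketch assumes that data_parallel, model_parallel, and pipeline_stage are the only dimensions in use (no other parallel dimension is configured); the helper names are hypothetical.

```python
def required_devices(data_parallel: int, model_parallel: int, pipeline_stage: int) -> int:
    """Cards needed, assuming these three are the only parallel dimensions in use."""
    return data_parallel * model_parallel * pipeline_stage


def check_cluster(parallel_config: dict, servers: int, cards_per_server: int) -> bool:
    """True if the parallel strategy exactly fills the available cards."""
    needed = required_devices(
        parallel_config["data_parallel"],
        parallel_config["model_parallel"],
        parallel_config["pipeline_stage"],
    )
    return needed == servers * cards_per_server


config = {"data_parallel": 2, "model_parallel": 4, "pipeline_stage": 2}
print(required_devices(**config))                            # 2 * 4 * 2 = 16
print(check_cluster(config, servers=2, cards_per_server=8))  # True
```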

Initiating tasks:

Use mindformers/scripts/msrun_launcher.sh to initiate tasks.

# The first server (master node)
bash scripts/msrun_launcher.sh "run_mindformer.py \
  --config {CONFIG_PATH} \
  --run_mode train" \
  16 8 ${ip} ${port} 0 output/msrun_log False 300
# The second server (sub-node)
bash scripts/msrun_launcher.sh "run_mindformer.py \
  --config {CONFIG_PATH} \
  --run_mode train" \
  16 8 ${ip} ${port} 1 output/msrun_log False 300

Scenario 2: No shared disks between servers

In the case where there is no shared disk between servers, you need to merge the generated strategy files offline and distribute the merged file before enabling the online slicing function. The following steps describe how to perform this operation and start a multi-machine, multi-card training task.

1. Getting Distributed Strategies

Before performing the offline weight conversion, first obtain the distributed strategy files of each node.

  # Set only_save_strategy to True to generate the distributed strategy files; the task exits automatically once they are generated
  only_save_strategy: True

  # Configure dataset paths
  train_dataset: &train_dataset
    data_loader:
      type: MindDataset
      dataset_dir: "/worker/dataset/wikitext_2048/"
      shuffle: True

  # Configure 16-card distributed strategy (for reference only)
  parallel_config:
    data_parallel: 2
    model_parallel: 4
    pipeline_stage: 2
    micro_batch_num: 2
    vocab_emb_dp: True
    gradient_aggregation_group: 4
    micro_batch_interleave_num: 1

The strategy files of each node are saved separately in that node's output/strategy directory. For example, node 0 saves only ckpt_strategy_rank_0.ckpt through ckpt_strategy_rank_7.ckpt, and node 1 saves only ckpt_strategy_rank_8.ckpt through ckpt_strategy_rank_15.ckpt. The strategy files of all nodes then need to be gathered on the same server for the subsequent operations. The directories and files after gathering are as follows.

output
    ├── strategy
        ├── ckpt_strategy_rank_0.ckpt
        ...
        ├── ckpt_strategy_rank_7.ckpt
        ├── ckpt_strategy_rank_8.ckpt
        ...
        └── ckpt_strategy_rank_15.ckpt
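
Before merging, it is worth confirming that the gathered folder really contains one strategy file per rank. The following sketch checks this with the standard library only; the helper name missing_strategy_ranks is hypothetical, and the demo deliberately simulates a folder into which only node 0's files have been copied.

```python
import re
import tempfile
from pathlib import Path

STRATEGY_FILE = re.compile(r"ckpt_strategy_rank_(\d+)\.ckpt")


def missing_strategy_ranks(strategy_dir: str, world_size: int) -> list:
    """Ranks in [0, world_size) that have no ckpt_strategy_rank_x.ckpt file."""
    present = set()
    for path in Path(strategy_dir).iterdir():
        match = STRATEGY_FILE.fullmatch(path.name)
        if match:
            present.add(int(match.group(1)))
    return [rank for rank in range(world_size) if rank not in present]


# Demo: only the files of node 0 (ranks 0 to 7) have been gathered so far.
with tempfile.TemporaryDirectory() as tmp:
    for rank in range(8):
        (Path(tmp) / f"ckpt_strategy_rank_{rank}.ckpt").touch()
    demo_missing = missing_strategy_ranks(tmp, world_size=16)

print(demo_missing)
```

An empty list means all strategy files are in place and the merge step can proceed.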

2. Merging Distributed Strategy

Call the strategy merge interface to merge all strategy files after centralization into one file for subsequent weight slicing.

import mindspore as ms
ms.parallel.merge_pipeline_strategys("/output/strategy", "/output/merged_strategy/dst_strategy.ckpt")

3. Weight Slice Loading

Distribute strategy files + online slicing (recommended):

Distribute the merged strategy file dst_strategy.ckpt to the ./output/merged_strategy/ directory on each node, turn on automatic slicing, and start the training task again. The configuration file on each node needs to be modified.

output_dir: './output'                              # Make sure that each node under ./output/merged_strategy/ has the merged strategy file
load_checkpoint: '/qwen2_7b/unified_safetensors'    # Load full weights file path
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: True                               # This configuration item needs to be turned on for full weights + distributed loading to turn on online slicing

Offline slicing + distributing distributed weights:

Following the weight slicing guide, first slice the full weights offline into distributed weights files, distribute them to each machine, turn off automatic slicing, and set load_checkpoint to the distributed weights path. The configuration file on each node needs to be modified.

Because distributed weight files are generally larger than strategy files and take longer to distribute, the first approach is recommended.

load_checkpoint: '/output/distributed_safetensors'  # Load distributed weights file path
load_ckpt_format: 'safetensors'                     # Load weight file format
auto_trans_ckpt: False                              # Distributed weight loading with online slicing turned off

4. Initiating tasks

Use mindformers/scripts/msrun_launcher.sh to initiate tasks.

# The first server (master node)
bash scripts/msrun_launcher.sh "run_mindformer.py \
  --config {CONFIG_PATH} \
  --run_mode train" \
  16 8 ${ip} ${port} 0 output/msrun_log False 300
# The second server (sub-node)
bash scripts/msrun_launcher.sh "run_mindformer.py \
  --config {CONFIG_PATH} \
  --run_mode train" \
  16 8 ${ip} ${port} 1 output/msrun_log False 300

Weight Features

De-redundant Saving and Loading

Currently, when MindSpore Transformers saves weights, identical weight files are duplicated across the dp/opt domain by default, which adds storage overhead. The following configuration enables de-redundant saving and loading in the dp/opt domain, effectively reducing the storage pressure on large-scale clusters of thousands of cards or more. This feature only takes effect for distributed weights; complete weights do not involve de-redundancy.

The following configuration is enabled when saved:

callbacks:
  - type: CheckpointMonitor
    checkpoint_format: safetensors                  # Save weights file format
    remove_redundancy: True                         # Turn on de-redundancy when saving weights

The saved distributed weights are of different sizes, and the total weight file is smaller than that before the de-redundancy feature is turned on:

output
    ├── checkpoint
        ├── rank_0
            └── example-1_1.safetensors  #file size:5.2G
        ├── rank_1
            └── example-1_1.safetensors  #file size:5.2G
        ...
        ├── rank_6
            └── example-1_1.safetensors  #file size:4.1G
        └── rank_7
            └── example-1_1.safetensors  #file size:4.1G

Turn on the following configuration when loading:

load_ckpt_format: 'safetensors'    # Load weight file format
remove_redundancy: True            # Turn on de-redundancy when loading weights

In MindSpore Transformers 1.5.0 and earlier, accuracy anomalies may occur when the de-redundancy settings used for saving and loading differ, so please make sure the configuration is consistent. Versions later than 1.5.0 automatically detect whether the weights are de-redundant and load them accordingly, so the loading configuration no longer needs attention.

Loading Hugging Face safetensors

By adding the pretrained_model_dir field to the configuration file, you can specify a folder that stores all model files downloaded from Hugging Face (including config.json, tokenizer files, weight files, etc.). The model configuration and tokenizer are then instantiated directly from that folder, and the Hugging Face weights are loaded.

Taking Qwen3 as an example, the meaning of the fields configured in the YAML configuration file is as follows: the folder directory specified in pretrained_model_dir stores the Qwen3 model configuration file, tokenizer file, and weight file on Hugging Face.

use_legacy: False
load_checkpoint: ''
pretrained_model_dir: "/path/qwen3"
model:
  model_config:
    compute_dtype: "bfloat16"
    layernorm_compute_dtype: "float32"
    softmax_compute_dtype: "float32"
    rotary_dtype: "bfloat16"
    params_dtype: "bfloat16"
generation:
  max_length: 30

Parameter Descriptions:

  • use_legacy - This parameter is set to False to enable Hugging Face loading

  • load_checkpoint - User defined weight loading path, high priority

  • pretrained_model_dir - Hugging Face weight, low priority

load_checkpoint has the higher priority when selecting the weight path. When it is configured, the weight files under pretrained_model_dir are not loaded.

When load_checkpoint is not configured, the safetensors weight files under pretrained_model_dir are loaded if they exist; otherwise the weights are randomly initialized.
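
The priority rule between load_checkpoint and pretrained_model_dir can be sketched as a small resolver. This is only an illustration of the rule described here, not the logic used inside MindSpore Transformers; the function name resolve_weight_source and its return labels are hypothetical.

```python
import tempfile
from pathlib import Path


def resolve_weight_source(load_checkpoint: str, pretrained_model_dir: str) -> str:
    """Return which source the weights would come from under the priority rule."""
    if load_checkpoint:
        return "load_checkpoint"  # high priority: pretrained_model_dir weights are ignored
    model_dir = Path(pretrained_model_dir) if pretrained_model_dir else None
    if model_dir is not None and model_dir.is_dir() and any(model_dir.glob("*.safetensors")):
        return "pretrained_model_dir"
    return "random_init"  # no weights found anywhere


# Demo with a throwaway folder standing in for a Hugging Face download.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model.safetensors").touch()
    demo_both = resolve_weight_source("/some/checkpoint", tmp)
    demo_hf_only = resolve_weight_source("", tmp)
demo_none = resolve_weight_source("", "")

print(demo_both, demo_hf_only, demo_none)
```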

This feature currently only supports Qwen3 series and DeepSeek V3 series models in fine-tuning/inference scenarios, and is being continuously updated.

Weight Slicing and Merging

Overview

In the current distributed training and inference environment, changing the distributed strategy requires merging the existing distributed weights into complete weights and then loading them through online or offline slicing. To meet the needs of weight conversion in different scenarios, you can use the following scripts and interfaces to merge multi-card weights into single-card weights and to slice single-card weights into multi-card weights.

Weight Merging

Usage Directions

Use the safetensors weights merging script provided by MindSpore Transformers to merge safetensors weights as follows. The merged weights are in the complete weights format.

python toolkit/safetensors/unified_safetensors.py \
  --src_strategy_dirs src_strategy_path_or_dir \
  --mindspore_ckpt_dir mindspore_ckpt_dir \
  --output_dir output_dir \
  --file_suffix "1_1" \
  --has_redundancy has_redundancy

Parameter Descriptions

  • src_strategy_dirs: The path to the distributed strategy file corresponding to the source weights, usually saved by default in the output/strategy/ directory after the training task starts. For distributed weights, fill it in as follows:

    • Source weights with pipeline parallelism enabled: The weight conversion is based on the merged strategy file, so fill in the path to the distributed strategy folder. The script automatically merges all ckpt_strategy_rank_x.ckpt files in the folder and generates merged_ckpt_strategy.ckpt there. If merged_ckpt_strategy.ckpt already exists, you can fill in the path to that file directly.

    • Source weights with pipeline parallelism disabled: The weight conversion can be based on any strategy file, so fill in the path to any one of the ckpt_strategy_rank_x.ckpt files.

    Note: If merged_ckpt_strategy.ckpt already exists in the strategy folder and the folder path is still passed in, the script first deletes the old merged_ckpt_strategy.ckpt and then merges the strategy files into a new merged_ckpt_strategy.ckpt for weight conversion. Therefore, make sure the folder has sufficient write permissions; otherwise the operation will report an error.

  • mindspore_ckpt_dir: The distributed weights path. Fill in the path of the folder where the source weights are located. The source weights should be stored in the model_dir/rank_x/xxx.safetensors format, with the folder path filled in as model_dir.

  • output_dir: The path where the target weights will be saved. The default value is "/new_llm_data/******/ckpt/nbg3_31b/tmp", i.e., the target weights will be placed in the /new_llm_data/******/ckpt/nbg3_31b/tmp directory.

  • file_suffix: The naming suffix of the target weights file. The default value is "1_1", i.e. the target weights will be looked up in the *1_1.safetensors format.

  • has_redundancy: Whether the merged source weights are redundant weights, defaults to True.

  • filter_out_param_prefix: You can customize the parameters to be filtered out when merging weights, and the filtering rules are based on prefix name matching. For example, optimizer parameter "adam_".

  • max_process_num: Maximum number of processes to merge. Default value: 64.
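
To preview which files a given file_suffix would select, the sketch below applies the *{file_suffix}.safetensors pattern described above to each rank folder. It is an approximation based only on that description and may differ from what unified_safetensors.py does internally; the helper name files_matching_suffix and the demo file names are hypothetical.

```python
import tempfile
from pathlib import Path


def files_matching_suffix(mindspore_ckpt_dir: str, file_suffix: str) -> dict:
    """Map each rank folder name to the files ending in '{file_suffix}.safetensors'."""
    matches = {}
    for rank_dir in sorted(Path(mindspore_ckpt_dir).glob("rank_*")):
        hits = sorted(p.name for p in rank_dir.glob(f"*{file_suffix}.safetensors"))
        matches[rank_dir.name] = hits
    return matches


# Demo: two ranks, each holding weights from two different save points.
with tempfile.TemporaryDirectory() as tmp:
    for rank in range(2):
        rank_dir = Path(tmp) / f"rank_{rank}"
        rank_dir.mkdir()
        (rank_dir / "example-1_1.safetensors").touch()
        (rank_dir / "example-1_2.safetensors").touch()
    demo_matches = files_matching_suffix(tmp, "1_1")

print(demo_matches)
```

A rank folder that maps to an empty list has no file with the requested suffix, which is worth resolving before starting the merge.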

Samples

Scenario one:

To merge safetensors weights that were saved with redundancy removed, fill in the parameters as follows:

python toolkit/safetensors/unified_safetensors.py \
  --src_strategy_dirs src_strategy_path_or_dir \
  --mindspore_ckpt_dir mindspore_ckpt_dir \
  --output_dir output_dir \
  --file_suffix "1_1" \
  --has_redundancy False

Scenario two:

To filter out the Adam optimizer parameters when merging safetensors weights, fill in the parameters as follows:

python toolkit/safetensors/unified_safetensors.py \
  --src_strategy_dirs src_strategy_path_or_dir \
  --mindspore_ckpt_dir mindspore_ckpt_dir \
  --output_dir output_dir \
  --file_suffix "1_1" \
  --filter_out_param_prefix "adam_"

Weight Slicing

Usage Directions

Use the strategy merging interface and the slicing and saving interface provided by MindSpore to slice the safetensors weights and save them offline as follows. The sliced weights are in the distributed weights format.

import mindspore as ms
# step1: Merge the target slicing strategy files
ms.parallel.merge_pipeline_strategys("output/strategy", "output/merged_strategy/dst_strategy.ckpt")
# step2: Slice the complete weights according to the merged target slicing strategy and save them as distributed weights
ms.load_distributed_checkpoint(
    network=None,
    predict_strategy='output/merged_strategy/dst_strategy.ckpt',
    unified_safetensors_dir='/path/unified_safetensors',
    dst_safetensors_dir='/path/distributed_safetensors',
    format='safetensors',
    max_process_num=64
)

Parameter Descriptions

  • network (Cell) - The distributed prediction network. When format is safetensors, network is passed as None, in which case the interface runs in save mode.

  • predict_strategy (Union[dict, str]) - The target slice strategy file. Default: None .

  • unified_safetensors_dir (str) - Directory of complete weights files. Default: None .

  • dst_safetensors_dir (str) - The save directory for the weights in the save mode scenario.

  • max_process_num (int) - Maximum number of processes. Default: 64.

Weights Format Conversion

Converting Ckpt to Safetensors

Existing MindSpore Transformers weights files are in the ckpt format. They can be converted to safetensors files in the following two ways.

Interface Calling

Call the MindSpore format conversion interface.

import mindspore as ms
ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./output/safetensors_path/")
# Parameter descriptions
# file_path (str) - Path to a directory containing checkpoint files, or path to a single checkpoint file (.ckpt)
# save_path (str, optional) - Path to the directory where the safetensors files are stored. Default: None

Training Tasks

Adjust the configuration file and start a MindSpore Transformers training task; the conversion is achieved by loading the weights in ckpt format and saving them in safetensors format during a trial training run.

load_checkpoint: 'output/checkpoint/'               # Load weights file path
load_ckpt_format: 'ckpt'                            # Load weights in ckpt format
callbacks:
  - type: CheckpointMonitor
    checkpoint_format: 'safetensors'                # Save weights in safetensors format

Usage Example

Examples of Training Tasks

To fine-tune online on multiple cards with complete weights, take the Qwen2.5-7B model as an example and modify the configuration file finetune_qwen2_5_7b_8k.yaml:

# Modified configuration
load_checkpoint: '/qwen2.5_7b/hf_unified_safetensors' # Load weights file path
load_ckpt_format: 'safetensors'                     # Load weights file format
auto_trans_ckpt: True                               # This configuration item needs to be turned on for complete weights to enable the online slicing feature
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 1
callbacks:
  - type: CheckpointMonitor
    checkpoint_format: safetensors                  # Save weights file format

To fine-tune online on multiple cards with distributed weights, take the Qwen2.5-7B model as an example and modify the configuration file finetune_qwen2_5_7b_8k.yaml:

# Modified configuration
load_checkpoint: '/qwen2.5_7b/distributed_safetensors' # Load weights file path
load_ckpt_format: 'safetensors'                      # Load weights file format
parallel_config:                                     # Configure the target distributed strategy
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 1
callbacks:
  - type: CheckpointMonitor
    checkpoint_format: safetensors                  # Save weights file format

Execute the command when completed:

bash scripts/msrun_launcher.sh "run_mindformer.py \
 --config research/qwen2_5/finetune_qwen2_5_7b_8k.yaml \
 --train_dataset_dir /{path}/alpaca-data.mindrecord \
 --register_path research/qwen2_5 \
 --use_parallel True \
 --run_mode finetune" 8

After the task is executed, a checkpoint folder is generated in the mindformers/output directory, while the model files are saved in that folder.

For more details, please refer to Introduction to SFT fine-tuning and Introduction to Pre-training.

Example of an Inference Task

To run online inference on multiple cards with complete weights, take the Qwen2.5-7B model as an example and modify the configuration file predict_qwen2_5_7b_instruct.yaml:

# Modified configuration
load_checkpoint: '/qwen2.5_7b/hf_unified_safetensors' # Load weights file path
load_ckpt_format: 'safetensors'                     # Load weights file format
auto_trans_ckpt: True                               # This configuration item needs to be turned on for complete weights to enable the online slicing function
parallel_config:
  data_parallel: 1
  model_parallel: 2
  pipeline_stage: 1

To run online inference on multiple cards with distributed weights, take the Qwen2.5-7B model as an example and modify the configuration file predict_qwen2_5_7b_instruct.yaml:

# Modified configuration
load_checkpoint: '/qwen2.5_7b/distributed_safetensors' # Load weights file path
load_ckpt_format: 'safetensors'                      # Load weights file format
parallel_config:
  data_parallel: 1
  model_parallel: 2
  pipeline_stage: 1

Execute the command when completed:

bash scripts/msrun_launcher.sh "run_mindformer.py \
--config research/qwen2_5/predict_qwen2_5_7b_instruct.yaml \
--run_mode predict \
--use_parallel True \
--register_path research/qwen2_5 \
--predict_data 'I love Beijing, because'" \
2

The result of executing the above inference command is as follows:

'text_generation_text': [I love Beijing, because it is a city with a long history and culture.......]

For more details, please refer to: Introduction to Inference

Examples of Resumable Training after Breakpoint Tasks

MindSpore Transformers supports step-level resumable training after breakpoint, which allows you to save a model's checkpoints during training and load the saved checkpoints to restore the previous state to continue training after a break in training.

If you use distributed weight multicard resumable training and do not change the slicing strategy, modify the configuration item and start the original training task:

# Modified configuration
load_checkpoint: '/output/checkpoint'                # Load source distributed weights file path
load_ckpt_format: 'safetensors'                      # Load weights file format
resume_training: True                                # Resumable training after breakpoint switch
callbacks:
  - type: CheckpointMonitor
    checkpoint_format: safetensors                   # Save weights file format

To resume multi-card training from distributed weights with a changed slicing strategy, pass in the path of the source slicing strategy file, modify the configuration items, and then start the original training task:

# Modified configuration
load_checkpoint: '/output/checkpoint'               # Load source distributed weights file path
src_strategy_path_or_dir: '/output/src_strategy'    # Load source strategy file for merging source distributed weights into full weights
load_ckpt_format: 'safetensors'                     # Load weights file format
auto_trans_ckpt: True                               # Enable online slicing
resume_training: True                               # Resumable training after breakpoint switch
parallel_config:                                    # Configure the target distributed strategy
  data_parallel: 2
  model_parallel: 4
  pipeline_stage: 1
callbacks:
  - type: CheckpointMonitor
    checkpoint_format: safetensors                  # Save weights file format

For more details, please refer to: Introduction to Breakpoints.