权重保存与加载

MindSpore Transformers 动态图(PyNative)训练统一以 Safetensors 格式保存与加载权重。框架在 checkpoint 段(CheckpointConfig)集中配置保存与加载行为:保存由 mindformers/pynative/callback/checkpoint_callback.pyCheckpointCallback 在训练过程中按步触发,加载由 Trainer 在训练启动阶段完成。两者底层分别调用 mindformers.checkpoint.save_checkpointmindformers.checkpoint.load_checkpoint,并通过权重目录内的 common.json 衔接断点续训。

动态图仅用 Safetensors

动态图路径全程以 Safetensors 为唯一权重格式,不涉及 ckpt 与 Safetensors 之间的格式转换,也不需要单独的转换脚本。

术语说明

本文涉及以下术语:

  • 权重加载(Weight Loading)

    仅加载模型参数,不恢复优化器状态和训练进度。常用于微调、蒸馏、迁移学习等场景。

  • 断点续训

    与整段续训含义相同,强调从某个 checkpoint 恢复训练过程。

  • Checkpoint

    训练过程中保存的模型状态文件,用于恢复训练或导出模型。

  • Global Step

    训练过程中已经执行的优化步数,用于学习率调度、训练恢复以及训练统计。

如何选择 Checkpoint 配置

请根据实际需求选择保存和加载方式:

是否需要从中断位置继续训练?
│
├── 是
│   └── 使用整段续训
│       ├── 保存 optimizer 状态
│       ├── 保存训练进度
│       └── 加载时使用 resume 配置
│
└── 否
    │
    ├── 仅用于模型推理
    │   └── 仅保存权重
    │
    └── 用于微调新任务
        └── 加载模型权重,不加载优化器状态

常见场景对应关系:

场景

推荐方式

训练中断恢复

整段续训

集群故障恢复

整段续训

模型微调

权重加载

模型推理部署

权重加载

模型转换导出

权重加载

按场景选哪些字段

下表帮助快速定位每个场景需要关注的字段,详细语义见后文对应小节。

场景

关键字段

概述

关闭保存(只显示 loss 不保存权重)

enable_save: False

不挂保存回调,仅保留 Loss/Monitor

基础定时保存

save_path / save_interleaved_steps / save_max

按步存、保留最近若干份

异步保存

async_save: True

落盘与计算重叠,降低保存阻塞

仅存权重(不保存优化器权重)

no_save_optim: True

体积更小,但无法保存优化器状态

去冗余保存

remove_redundancy: True

多卡分片去重,减小占用

多卡布局缓存

save_global_layout_cache

复用分片元信息,避免每次重算

全量断点续训

load_path + no_load_optim: False

恢复权重、优化器、step 与数据游标

微调仅加载权重

load_path + no_load_optim: True

只取权重,优化器从头开始

多卡均衡加载

load_balanced: True

shard 均衡 + 参数广播,消除冗余参数重复加载,仅多卡分片场景有意义

字段归属

保存字段被 CheckpointCallback 消费,加载字段被 Trainer._load_checkpoint 消费。两类字段都写在同一个 checkpoint 段下,互不影响。


保存

触发机制

保存逻辑全部在 CheckpointCallback 中,关键行为如下:

  • 是否挂载回调Trainer._create_built_in_callbacks 检查 enable_save。当 enable_save=False根本不会构造 CheckpointCallback,只保留 LossCallbackMonitorCallback;此时其余保存字段全部无效。

  • 按步保存on_step_endstate.global_step % save_interleaved_steps == 0 时触发一次保存。

  • 训练结束补存on_train_end 在训练结束时额外保存一份最终权重,确保末尾的训练成果不丢。

  • 去重:内部维护_last_triggered_step,若本次 step 与上次已保存的 step 重名则覆盖。

  • 超额清理save_max 控制最多保留份数,超出时按时间删除最旧的目录,但仅会删除本轮次训练保存的权重。

  • 路径校验save_path 为空会在构造回调时直接抛出 ValueError("save_path must be provided for CheckpointCallback."),因此启用保存时必须配置 save_path

输出目录结构

每次保存会在 save_path 下生成一个按 step 命名的子目录,内含:

  • Safetensors 权重分片(模型分片,按 no_save_optim 决定是否含优化器分片);

  • common.json:续训元信息(见加载一节);

  • 分片布局元数据 metadata.json

字段表

字段

类型

默认值

说明

enable_save

bool

True

是否启用权重保存;为 False 时不挂保存回调。

save_path

str

""

保存目录;启用保存时必填,为空报 ValueError

save_max

int

5

保存权重最大数,超出时按时间删最旧,每次仅对当前训练保存的权重进行删除(特殊场景:若当前 step 已保存过 checkpoint,则覆盖该 step 对应目录)

save_interleaved_steps

int

1000

保存间隔步数;global_step 为其整数倍时触发。

no_save_optim

bool

False

True 时仅保存模型权重,不保存优化器状态。

async_save

bool

False

True 时启用异步保存,落盘与计算重叠。

prefix

str

"checkpoint"

保存文件名前缀。

remove_redundancy

bool

False

True 时去除多卡分片间的冗余数据。

save_global_layout_cache

bool

True

True 时缓存多卡全局分片布局,避免每次保存重算分片元信息。

场景化配置

场景一:基础定时保存

最常用的配置:每隔固定步数存一份,并保留最近若干份。适合大多数训练任务。

checkpoint:
  enable_save: True
  save_path: "./output/ckpt"
  save_interleaved_steps: 1000   # 每 1000 步保存一次
  save_max: 5                    # 最多保留 5 份,超出删最旧
  prefix: "checkpoint"           # 保存的权重名前缀
  no_save_optim: False           # 同时保存优化器状态,便于整段续训(恢复权重、优化器状态和训练进度)
  async_save: False
  remove_redundancy: False
  save_global_layout_cache: True

适用:单数据源、单/多卡常规训练。代价:每次保存阻塞训练直到落盘完成;若 step 数很大、保存频繁,可考虑下面的异步保存。

场景二:异步保存(降低保存阻塞)

当权重较大、同步落盘明显拖慢训练节奏时,开启 async_save。框架在保存前调用 AsyncSaveManager.prepare_before_save,将磁盘写入与后续训练计算重叠。

checkpoint:
  enable_save: True
  save_path: "./output/ckpt"
  save_interleaved_steps: 1000
  save_max: 5
  async_save: True               # 启用异步保存

适用:大模型、保存频繁、磁盘较慢的场景。代价:保存在后台进行,会占用额外内存/线程资源;异常退出时最近一次异步保存可能尚未完成。

场景三:仅存权重(不存优化器)

no_save_optim=True 时保存只包含模型权重,体积显著减小。适合只需要权重产物(如后续仅做微调)的场景。

checkpoint:
  enable_save: True
  save_path: "./output/ckpt"
  save_interleaved_steps: 1000
  save_max: 5
  no_save_optim: True            # 不保存优化器状态

⚠️ 续训影响

no_save_optim=True 保存的权重缺少优化器状态,无法用于严格还原优化器动量和二阶矩的整段续训。若后续要从该权重进行续训,需在加载侧配合 no_load_optim: True

场景四:去冗余保存(减小占用)

多卡训练时,不同 rank 间存在重复的权重分片。remove_redundancy=True 在保存时去除这些冗余数据,减小磁盘占用。

checkpoint:
  enable_save: True
  save_path: "./output/ckpt"
  save_interleaved_steps: 1000
  save_max: 5
  remove_redundancy: True        # 去除分片间冗余

适用:多卡分片、磁盘空间紧张的场景。注意去冗余依赖多卡分片元信息(sharded_tensor_metas,单卡恒为空),单卡下该开关不生效(被静默忽略)

场景五:多卡布局缓存

多卡场景下,保存需要先收集各 rank 的分片元信息(sharded_tensor_metas)。save_global_layout_cache=True(默认)会缓存这份元信息并在后续保存中复用;设为 False 则每次保存后清空缓存、下次重新计算。

checkpoint:
  enable_save: True
  save_path: "./output/ckpt"
  save_interleaved_steps: 1000
  save_max: 5
  save_global_layout_cache: True # 缓存分片布局,加速后续保存

适用:分片布局在训练过程中固定不变的常规多卡训练,保持 True 可省去每次重算开销。仅当分片布局可能在训练中变化、需要每次重新采集时才设为 False。单卡场景不涉及分片,该字段无实际作用。


加载

加载流程

加载在 Trainer.train 启动阶段执行,权重加载目录取自 load_path。核心流程在 Trainer._load_checkpoint

  1. 读取元信息:从 load_path/common.json 读出 CommonInfo(见下表)。

  2. 恢复 step 与数据游标:当 no_load_optim=False(整段续训)时,以 common.json 中的 global_step 作为续训起点;若当前 global_batch_size 与保存时不同,按比例缩放 global_step = global_step * (旧 global_batch_size / global_batch_size),再调用 train_dataset.set_init_step(global_step) 将数据集游标对齐到续训位置,并写回 state.global_step

  3. 加载权重(自动重分片):调用 load_checkpoint,内部用 ReshardLoader 处理分布式重分片,因此保存与加载的并行布局可以不同,框架会自动 reshard 到当前布局。

  4. 均衡加载(可选)load_balanced=True 时走 apply_balance_shard_strategy 求出各 rank 间的冗余参数映射,再用 single_parameter_broadcast 在 rank 间广播参数,从而消除冗余参数的重复读取

  5. 优化器主权重刷新(可选)no_load_optim=True 时不加载优化器,加载完成后调用 optimizer.reload_main_params_from_model(),用刚加载的模型参数刷新 fp32 主权重,使主权重与模型参数对齐。

未配置加载目录

load_path 为空且未传入 checkpoint_path 时不加载任何权重,训练从随机初始化开始。

common.json 记录的字段

common.json 示例

典型的 common.json 内容如下:

{
  "epoch_num": 1,
  "step_num": 100,
  "global_step": 100,
  "loss_scale": "1.0",
  "global_batch_size": 2,
  "ckpt_status": null
}

common.json 由保存侧的 CommonInfo 写出(mindformers/checkpoint/checkpoint.py),续训时用于恢复 step 与数据游标:

字段

说明

续训中的作用

epoch_num

当前训练所处 epoch

元信息记录

step_num

当前 epoch 内的步数

元信息记录

global_step

跨 epoch 的全局训练步数

续训起点;global_batch_size 变化时按比例缩放后驱动 set_init_step

loss_scale

梯度放大系数

元信息记录

global_batch_size

多卡训练的全局批大小

与当前配置比较,决定是否缩放 global_step

ckpt_status

权重健康状态标记

标识该权重是否健康,默认为 null,开启健康权重检测时记录权重健康状态。

字段表

字段

类型

默认值

说明

load_path

str

""

加载目录;为空时不加载,从随机初始化开始。

no_load_optim

bool

False

True 时仅加载模型权重,不加载优化器状态,并刷新 fp32 主权重。

load_balanced

bool

False

True 时通过 shard 均衡 + 参数广播消除冗余参数的重复加载(多卡分片场景)。

load_balanced 适用于以下情况:

  • TP(Tensor Parallel)规模较大;

  • PP(Pipeline Parallel)规模较大;

  • DP(Data Parallel)规模较大;

  • checkpoint 文件数量较多;

  • 加载阶段出现明显 I/O 瓶颈。

典型场景:

TP=8
PP=8
DP=16

在大规模分布式训练中,checkpoint 文件通常分布在多个存储节点上。启用 load_balanced 后,可以减少部分 rank 的加载压力,提高整体恢复效率。

⚠️ load_worker_number 在动态图中未生效

CheckpointConfig 中虽声明了 load_worker_number(默认 1),但动态图加载路径并未消费该字段Trainer._load_checkpoint 调用 load_checkpoint 时未传 reshard_worker_num,其始终使用默认值 1

因此当前配置 load_worker_number 不会改变加载并行度,请勿依赖它加速读取。

场景化配置

场景一:全量断点续训

恢复权重、优化器状态、global_step 与数据集游标,从中断处无缝继续。这是续训的默认形态。

checkpoint:
  load_path: "./output/ckpt/checkpoint_5000"   # 指向某次保存的 step 目录
  no_load_optim: False                         # 加载优化器状态
  load_balanced: False

适用:训练中断后原地恢复。前提:被加载的权重保存时 no_save_optim=False(含优化器状态)。续训整体行为见 断点续训;数据游标恢复见 数据集

场景二:微调仅加载权重

只加载模型权重,优化器从头初始化。适合从预训练权重出发做下游微调,不希望继承预训练阶段的优化器动量。

checkpoint:
  load_path: "./pretrained/ckpt"
  no_load_optim: True              # 不加载优化器状态

no_load_optim=True 时不会基于 common.json 缩放/续接 global_step,训练步从配置起点开始;加载完成后框架自动调用 reload_main_params_from_model() 刷新 fp32 主权重,避免主权重与模型参数错位。

场景三:多卡均衡加载

load_balanced=True 时,框架用 shard 均衡策略求出各 rank 间的冗余参数映射,再通过参数广播让每个冗余参数只被读取一次,从而消除冗余参数的重复加载、降低分布式加载的显存与 IO。

checkpoint:
  load_path: "./output/ckpt/checkpoint_5000"
  no_load_optim: False
  load_balanced: True              # 跨卡均衡加载

适用条件

load_balancedapply_balance_shard_strategy 进行分片重分布,仅在多卡分片场景下有意义:单卡或无分片时不会带来收益。

并行维度与分片关系见 分布式并行训练


完整示例:训练中断后恢复训练

第一步:配置保存

checkpoint:
  enable_save: True
  save_path: "./checkpoints"
  save_interleaved_steps: 1000
  no_save_optim: False

训练过程中生成:

checkpoints/
├── iteration_00000100/
├── iteration_00000200/
└── latest_checkpointed_iteration.txt

第二步:训练中断

假设训练在 step=2300 时因节点故障退出。

最近一次成功保存的 checkpoint 为:

iteration_00000200

第三步:配置加载

checkpoint:
  load_path: "./checkpoints"

第四步:恢复过程

重新启动训练,系统自动完成:

  1. 加载模型权重;

  2. 加载优化器状态;

  3. 恢复随机数状态;

  4. 恢复 global_step;

  5. 恢复数据集读取位置。

恢复完成后:

global_step = 2000

训练将从 checkpoint 对应位置继续执行,而不是从头开始训练。


第五步:继续训练

2000 -> 2001 -> 2002 -> ...

学习率调度、优化器动量以及数据集游标均保持连续。

相关文档