检查点恢复#
意外情况 —— 网络错误、断电、节点被抢占 —— 都可能中断一个长时间运行的分布式任务。
为了解决这一问题,RLinf 会在每隔 runner.save_interval 步时保存一个完整的检查点,
并允许你从最近的快照恢复,最大限度减少工作损失。
检查点布局#
假设有如下 YAML 片段:
runner:
task_type: math
logger:
log_path: ${runner.output_dir}/${runner.experiment_name}
project_name: rlinf
experiment_name: ${runner.experiment_name}
save_interval: 50
experiment_name: grpo-1.5b
output_dir: ./logs
如果使用 Megatron 作为训练后端,其检查点会出现在 output_dir/experiment_name/checkpoints/ 下,
而如果使用 FSDP/FSDP2 作为训练后端,其检查点会出现在 log_path/experiment_name/checkpoints/ 下。
Megatron 检查点#
Megatron检查点文件结构如下:
logs/grpo-1.5b/checkpoints/
├── global_step_50/
│ ├── actor/
│ │ ├── iter_0000050/
│ │ │ ├── mp_rank_00/
│ │ │ │ ├── distrib_optim.pt
│ │ │ │ └── model_optim_rng.pt
│ │ │ └── mp_rank_01/
│ │ │ ├── distrib_optim.pt
│ │ │ └── model_optim_rng.pt
│ │ └── latest_checkpointed_iteration.txt
│ └── data/
│ └── data.pt
└── global_step_100/
└── …
关键点#
分片权重 ——
mp_rank_*中的文件遵循 Megatron 的张量并行布局;每个 GPU 只会重新加载属于自己的分片。优化器 / RNG 状态 —— 同时 保存了 Adam 参数(
distrib_optim.pt)和随机数生成器,确保恢复后可以比特级复现。数据采样器 ——
data.pt存储了 dataloader,保证不会遗漏或重复样本。
FSDP/FSDP2 检查点#
FSDP/FSDP2 检查点文件结构如下:
-- global_step_2
-- actor
|-- __0_0.distcp
|-- __1_0.distcp
|-- __2_0.distcp
-- __3_0.distcp
FSDP/FSDP2 通过 DCP (torch.distributed.checkpoint) 保存和加载检查点,其结果为一组分布式检查点文件(.distcp)。 每个文件包含模型参数、优化器状态和 RNG 状态的分片。
检查点文件向Pytorch State Dict文件转换#
如果你需要将 FSDP/FSDP2 检查点转换为标准的 Pytorch State Dict 文件用于模型评估或是其他用途, 可以在
检查点保存文件夹下的`model`子文件夹找到以safetensors格式保存的模型 state dict文件。例如,
对于检查点目录 global_step_2,模型 state dict 文件位于 global_step_2/model/。我们会递归地复制模型的配置文件(.json)和代码文件(.py,如果存在的话),
以便你之后可以轻松地重新加载模型。
convert_dcp_to_state_dict.py [-h] --dcp_path DCP_PATH --output_path OUTPUT_PATH
其中 DCP_PATH 是包含 DCP 文件的目录,OUTPUT_PATH 是保存转换后模型 State Dict 文件的路径。
恢复训练#
选择最新的检查点
如果
global_step_150/是编号最高的目录,它就是最新的快照。修改 YAML
runner: resume_dir: ${runner.output_dir}/${runner.experiment_name}/checkpoints/global_step_150
完全按原方式重新启动
启动 Ray,然后运行相同的
run_main_*.sh启动脚本。 RLinf 会自动检测到resume_dir并:在每个节点/rank 上恢复模型分片、优化器、RNG 和 dataloader 状态。
从
global_step_150继续计数 —— 下一个保存的检查点将是global_step_200(因为save_interval为 50)。
小技巧
想验证恢复是否成功,可以查看日志行。 如果下一次训练从 step 150 开始,就说明恢复正常!