使用 FSDP+HuggingFace 添加新模型 SFT 训练#
本文档重点介绍如何使用 HuggingFace Transformers 库与 PyTorch FSDP(Fully Sharded Data Parallel,全分片数据并行) 来训练和生成模型。它支持 HuggingFace 中实现的任意模型,只要兼容 PyTorch 即可。 作为示例,本节将提供一个逐步的操作流程,展示如何按照 sft 模式将一个新的 HuggingFace 模型集成到 RLinf 中。
前置条件#
熟悉 HuggingFace Transformers 库
理解 RLinf 框架架构
掌握 PyTorch 与分布式训练知识
本文目标#
你将学会:把一个“新模型”接入 RLinf 的 SFT 训练流程,并成功跑通训练 / 评估 / 断点续训。
本文基于 RLinf 当前 SFT 主流程:
rlinf/runners/sft_runner.py- 训练调度器rlinf/workers/sft/fsdp_sft_worker.py- SFT Worker 基类
当前 RLinf 中 SFT 的主要步骤分为如下几个部分:
启动 Runner
Runner 初始化 Worker(加载模型、优化器、数据)
每个 step 调用 Worker 的
run_training()到达条件后调用
run_eval()/save_checkpoint()重复直到训练结束
在代码里对应关系:
SFTRunner.run(): - 调self.actor.run_training()- 根据val_check_interval和save_interval决定 eval/saveFSDPSftWorker.run_training(): - 从 dataloader 拿 batch - 调用你实现的get_train_model_output(batch)- backward + optimizer step + lr scheduler stepFSDPSftWorker.run_eval(): - 逐 batch 调用你实现的get_eval_model_output(batch)- 汇总进行 sft 模型效果的评估eval_accuracy
所以你要适配新模型,核心需要实现在 rlinf/workers/sft/fsdp_sft_worker.py 中的三个抽象新方法,才能将新数据集以及新模型接入到 RLinf 的 SFT 训练流程中:
@abstractmethod
def build_dataloader(self):
raise NotImplementedError
@abstractmethod
def get_train_model_output(self, batch: dict[str, Any]):
raise NotImplementedError
@abstractmethod
def get_eval_model_output(self, batch: dict[str, Any]):
raise NotImplementedError
训练前置条件#
开始适配前请确认:
下载需要 sft 训练的新模型权重(HF 路径或本地路径)
下载需要 sft 训练的新数据集(文本 / 图文 / 多模态)
理解训练数据格式(文本 / 图文 / 多模态)以及如何进行预处理
你知道监督目标(如 next-token loss、分类准确率)
准备好 eval 数据集进行模型验证
识别模型类型#
让 RLinf 配置文件识别你的模型类型
RLinf 通过 SupportedModel 识别模型类型。对于自定义 SFT 模型,
你已经不需要再直接修改 SupportedModel 源码。
推荐做法是在 build_config(...) 或训练启动前先注册模型类型,
然后在 YAML 中把 actor.model.model_type 设为注册后的值。
示例:
from rlinf.config import SupportedModel
SupportedModel.register("my_new_model")
示例 YAML:
actor:
model:
model_type: "my_new_model"
model_path: "/path/to/your/model"
确保 get_model 可返回模型#
确保 FSDP已经支持你的模型, get_model(...) 能返回你的模型
FSDPSftWorker.model_provider_func() 会调用:
model = get_model(self.cfg.actor.model)
必须保证 FSDPModelManager.model_provider_func() 能返回你的模型:
get_model能识别my_new_model返回对象支持训练前向(通常是
model(..., labels=...)返回loss)
创建 Worker 子类#
新建一个 Worker 子类,实现 build_dataloader、get_train_model_output、get_eval_model_output
建议新建文件,例如:
rlinf/workers/sft/fsdp_my_model_sft_worker.py
继承 FSDPSftWorker,实现 3 个方法。
from typing import Any
import torch
from omegaconf import DictConfig
from rlinf.workers.sft.fsdp_sft_worker import FSDPSftWorker
class FSDPMyModelSftWorker(FSDPSftWorker):
def __init__(self, cfg: DictConfig):
super().__init__(cfg)
def build_dataloader(self, data_paths: list[str], eval_dataset: bool = False):
# 1) 构建 dataset
# 3) 返回 data_loader 和 data_config(dict)
...
return data_loader, {"num_samples": len(dataset)}
def get_train_model_output(self, batch: dict[str, Any]):
# 模型的核心训练过程
# 返回 loss(Tensor)
...
return loss
def get_eval_model_output(self, batch: dict[str, Any]):
# 模型的核心评估过程
# 返回当前 batch 的正确样本数(整数)
...
return correct_count
实现 build_dataloader#
build_dataloader 方法用于构建数据加载器,你需要确保返回的数据加载器能够正确地处理训练和评估数据。
你必须保证 batch 字段和后续训练函数一致。
run_training() 内部是:
batch = next(self.data_iter)losses = self.get_train_model_output(batch)
也就是说你在 get_train_model_output 里访问的 key,必须由 collate 产出。
建议 checklist:
训练时 batch 至少有: -
input_ids``(或你的同义字段) - ``attention_mask``(可选,但建议有) - ``labels或可构造 labels 的字段评估时 batch 至少有: - 推理输入 - 参考答案(用于算准确率)
常见错误:
collate_fn输出list[dict],但训练代码当成dict用某些样本缺多模态字段,导致 batch 拼接错位
eval 还在
drop_last=True,导致评估样本被丢弃
实现 get_train_model_output#
get_train_model_output 方法用于获取模型的训练输出,你需要确保返回的输出能够正确地进行训练。
FSDPSftWorker 会对你返回的 loss 做:
支持 list/tuple/tensor 自动归一
gradient accumulation
scaler.backward
所以你只要保证最后返回的是 loss(或可堆叠 loss 列表)。
标准 CausalLM 写法(推荐):
def get_train_model_output(self, batch):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
with self.amp_context:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
return outputs.loss
实现 get_eval_model_output#
get_eval_model_output 方法用于获取模型的评估输出,你需要确保返回的输出能够正确地进行评估。
run_eval() 里逻辑是:
累加每个 batch 的返回值到
correct再除以
total得eval_accuracy
所以你的 get_eval_model_output 应该返回当前 batch 的正确样本数。
示例:
def get_eval_model_output(self, batch):
# 1) 生成预测
# 2) 与结果进行正确性比较
# 3) 返回正确数量
return correct
YAML 配置#
建议先用保守参数跑通:
runner:
task_type: sft
max_epochs: 5
val_check_interval: -1
save_interval: -1
actor:
training_backend: fsdp
micro_batch_size: 2
global_batch_size: 32
model:
model_type: my_new_model
model_path: /path/to/model
data:
train_data_paths: /path/to/train_path
val_data_paths: /path/to/eval_path
跑通后再逐步加大 batch、打开 eval/save。
常见问题排查#
KeyError: xxx- collate 没有产出训练函数需要的字段Expected all tensors on same device- 某些 batch 字段没to(self.device)global_batch_size is not divisible ...- 调整global_batch_size / micro_batch_size / world_sizeeval_accuracy 异常偏低- 检查评估提取答案逻辑 - 检查drop_last是否导致评估样本丢失resume 后数据重复/跳过- 检查_data_epoch/_data_iter_offset保存与恢复流程