使用 FSDP+HuggingFace 添加新模型#

本文档重点介绍如何使用 HuggingFace Transformers 库与 PyTorch FSDP(Fully Sharded Data Parallel,全分片数据并行) 来训练和生成模型。它支持 HuggingFace 中实现的任意模型,只要兼容 PyTorch 即可。 作为示例,本节将提供一个逐步的操作流程,展示如何按照 OpenVLA 模式将一个新的 HuggingFace 模型集成到 RLinf 中。

前置条件#

  • 熟悉 HuggingFace Transformers 库

  • 理解 RLinf 框架架构

  • 掌握 PyTorch 与分布式训练知识

逐步实现#

1. 模型配置与注册#

编辑 rlinf/models/__init__.py,扩展 get_model_config_and_processor。 这会注册你模型的 ConfigImageProcessorProcessor,使 RLinf 可以按名称加载它们并自动完成预处理。

def get_model_config_and_processor(cfg: DictConfig):
    if cfg.model.model_type == "your_model_type":
        from your_package.configuration import YourModelConfig
        from your_package.processing import YourImageProcessor, YourProcessor

        AutoConfig.register("your_model", YourModelConfig)
        AutoImageProcessor.register(YourModelConfig, YourImageProcessor)
        AutoProcessor.register(YourModelConfig, YourProcessor)

        model_config = AutoConfig.from_pretrained(
            cfg.tokenizer.tokenizer_model
        )
        image_processor = YourImageProcessor.from_pretrained(
            cfg.tokenizer.tokenizer_model,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer.tokenizer_model,
            trust_remote_code=True,
            padding_side="left"
        )
        input_processor = YourProcessor.from_pretrained(
            cfg.tokenizer.tokenizer_model,
            tokenizer=tokenizer,
            image_processor=image_processor,
            trust_remote_code=True
        )

    return model_config, input_processor

2. 模型实现#

rlinf/models/embodiment/your_model_action_model.py 中创建你的类,并继承自 HuggingFace 基类。 实现 predict_action_batch,用于封装生成、解码和可选的数值计算,将 RL 逻辑保持在模型内部。

from transformers import YourBaseModel
from rlinf.models.embodiment.modules.value_head import ValueHead

class YourModelForRLActionPrediction(YourBaseModel):
    def __init__(self, config, hidden_size, unnorm_key, action_dim):
        super().__init__(config)
        self._init_logits_processor()
        action_norm_stats = self.get_action_stats(unnorm_key)
        self.min_action = np.array(action_norm_stats["q01"])
        self.max_action = np.array(action_norm_stats["q99"])
        self.value_head = ValueHead(hidden_size)
        self.action_dim = action_dim

    def _init_logits_processor(self):
        self.logits_processors = LogitsProcessorList()
        self.logits_processors.append(
            YourLogitsProcessor(self.config.n_action_bins)
        )

    @torch.no_grad()
    def predict_action_batch(
        self, input_ids=None, attention_mask=None, pixel_values=None,
        do_sample=True, **kwargs
    ):
        generated = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
            return_dict_in_generate=True,
            do_sample=do_sample,
            logits_processor=self.logits_processors,
            **kwargs
        )
        sequences = generated.sequences
        actions = sequences[:, -self.action_dim:]
        logits = torch.stack(generated.logits, dim=1)
        if hasattr(self, "value_head"):
            values = self.value_head(generated.hidden_states)
        else:
            values = torch.zeros_like(logits[..., :1])
        return actions, sequences, logits, values

3. 模型加载#

修改 rlinf/models/__init__.py 中的 get_model,当 cfg.model_type 匹配时调用 from_pretrained 加载你的类。 这能确保检查点加载时保持正确的 dtype、维度和 LoRA hooks。

def get_model(cfg: DictConfig, override_config_kwargs=None):
    torch_dtype = torch_dtype_from_precision(cfg.precision)
    model_path = cfg.model_path
    if cfg.model_type == "your_model_type":
        from .embodiment.your_model_action_model import (
            YourModelForRLActionPrediction
        )
        model = YourModelForRLActionPrediction.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            hidden_size=cfg.hidden_size,
            unnorm_key=cfg.unnorm_key,
            action_dim=cfg.action_token_len,
            attn_implementation=cfg.attn_implementation,
            low_cpu_mem_usage=cfg.low_cpu_mem_usage,
            trust_remote_code=cfg.trust_remote_code,
        )

    if cfg.is_lora:
        # 在此添加 LoRA 支持
        pass

    return model

4. 配置文件#

examples/embodiment/config/your_config.yaml 中创建配置文件, 包含 model_typeaction_token_lenprecision 等字段。 该模板会暴露你模型的超参数,方便实验设置。

model:
  model_type: "your_model_type"
  action_token_len: 7
  action_chunks_len: 1
  unnorm_key: your_action_key
  micro_batch_size: 1
  val_micro_batch_size: 8
  precision: "bf16"
  vocab_size: 32000
  hidden_size: 4096
  image_size: [224, 224]
  is_lora: False
  attn_implementation: "flash_attention_2"
  low_cpu_mem_usage: True
  trust_remote_code: True