Adding New Models with FSDP+HuggingFace#
This document focus on using the HuggingFace Transformers library with PyTorch FSDP (Fully Sharded Data Parallel) to train and generate from models. It supports any model implemented in HuggingFace and compatible with PyTorch. As an example, this section provides a step-by-step recipe for integrating a new HuggingFace model into RLinf, following the OpenVLA pattern.
Prerequisites#
Familiarity with HuggingFace Transformers library
Understanding of the RLinf framework architecture
Knowledge of PyTorch and distributed training
Step-by-Step Implementation#
1. Model Configuration and Registration#
Edit rlinf/models/__init__.py to extend get_model_config_and_processor.
This registers your model’s Config, ImageProcessor, and Processor so RLinf can load them by name and wire up preprocessing automatically.
def get_model_config_and_processor(cfg: DictConfig):
    """Load the HuggingFace config and input processor for the configured model.

    Only the ``"your_model_type"`` branch is shown here. It registers the
    custom config/processor classes with the HuggingFace ``Auto*`` factories
    and then loads the config, image processor, tokenizer and combined
    processor from ``cfg.tokenizer.tokenizer_model``.

    Args:
        cfg: Configuration exposing ``model.model_type`` and
            ``tokenizer.tokenizer_model``.

    Returns:
        A ``(model_config, input_processor)`` tuple.
    """
    if cfg.model.model_type == "your_model_type":
        from your_package.configuration import YourModelConfig
        from your_package.processing import YourImageProcessor, YourProcessor

        # Make the custom classes resolvable through the Auto* factories.
        # The name passed to AutoConfig.register must equal
        # YourModelConfig.model_type, otherwise registration raises.
        AutoConfig.register("your_model", YourModelConfig)
        AutoImageProcessor.register(YourModelConfig, YourImageProcessor)
        AutoProcessor.register(YourModelConfig, YourProcessor)

        # Every artifact below is loaded from the same pretrained location.
        pretrained_source = cfg.tokenizer.tokenizer_model

        model_config = AutoConfig.from_pretrained(pretrained_source)
        vision_processor = YourImageProcessor.from_pretrained(
            pretrained_source,
            trust_remote_code=True,
        )
        # The tokenizer is configured to pad on the left.
        text_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_source,
            trust_remote_code=True,
            padding_side="left",
        )
        # Combine tokenizer and image processor into one multimodal processor.
        input_processor = YourProcessor.from_pretrained(
            pretrained_source,
            tokenizer=text_tokenizer,
            image_processor=vision_processor,
            trust_remote_code=True,
        )

    return model_config, input_processor
2. Model Implementation#
Create your class in rlinf/models/embodiment/your_model_action_model.py inheriting from a HuggingFace base.
Implement predict_action_batch to wrap generation, decoding, and optional value computation, keeping RL logic encapsulated.
from transformers import YourBaseModel
from rlinf.models.embodiment.modules.value_head import ValueHead
class YourModelForRLActionPrediction(YourBaseModel):
    """HuggingFace base model extended with RL action prediction.

    On top of the base model this class keeps a logits-processor list used
    during generation, the ``q01``/``q99`` action statistics looked up via
    ``unnorm_key``, and a ``ValueHead`` used to produce value estimates.
    """

    def __init__(self, config, hidden_size, unnorm_key, action_dim):
        """Initialise the base model and the RL-specific components.

        Args:
            config: Model configuration forwarded to the base class.
            hidden_size: Size passed to ``ValueHead``.
            unnorm_key: Key passed to ``get_action_stats`` to select the
                action normalisation statistics.
            action_dim: Number of trailing generated tokens that are
                returned as the action.
        """
        super().__init__(config)
        self._init_logits_processor()

        action_norm_stats = self.get_action_stats(unnorm_key)
        self.min_action = np.array(action_norm_stats["q01"])
        self.max_action = np.array(action_norm_stats["q99"])

        self.value_head = ValueHead(hidden_size)
        self.action_dim = action_dim

    def _init_logits_processor(self):
        """Create the logits-processor list applied during generation."""
        self.logits_processors = LogitsProcessorList()
        self.logits_processors.append(
            YourLogitsProcessor(self.config.n_action_bins)
        )

    @torch.no_grad()
    def predict_action_batch(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        do_sample=True,
        **kwargs,
    ):
        """Generate action tokens for a batch and collect RL by-products.

        Args:
            input_ids: Prompt token ids passed to ``generate``.
            attention_mask: Attention mask for ``input_ids``.
            pixel_values: Image inputs passed to ``generate``.
            do_sample: Whether generation samples instead of decoding greedily.
            **kwargs: Extra keyword arguments forwarded to ``generate``.

        Returns:
            A tuple ``(actions, sequences, logits, values)``:
            ``sequences`` is the full generated sequence, ``actions`` its last
            ``action_dim`` token ids, ``logits`` the per-step logits stacked
            along dim 1, and ``values`` the value-head output (or zeros with
            a trailing dimension of 1 when no value head is present).
        """
        generated = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
            return_dict_in_generate=True,
            do_sample=do_sample,
            logits_processor=self.logits_processors,
            **kwargs,
        )

        sequences = generated.sequences
        # NOTE(review): these are raw token ids; converting them to continuous
        # actions (presumably with min_action/max_action) is not shown here.
        actions = sequences[:, -self.action_dim:]
        # `generated.logits` holds one tensor per generated token.
        logits = torch.stack(generated.logits, dim=1)

        if hasattr(self, "value_head"):
            # `generated.hidden_states` is a tuple over generation steps, each
            # a tuple over layers, so it cannot be fed to the value head
            # directly. Take the last layer's final-position state of every
            # step and stack them into one tensor along dim 1.
            step_states = torch.stack(
                [step[-1][:, -1, :] for step in generated.hidden_states],
                dim=1,
            )
            # assumes ValueHead maps (..., hidden_size) -> (..., 1), matching
            # the shape of the fallback below -- TODO confirm
            values = self.value_head(step_states)
        else:
            values = torch.zeros_like(logits[..., :1])

        return actions, sequences, logits, values
3. Model Loading#
Modify get_model in rlinf/models/__init__.py to call from_pretrained for your class when cfg.model_type matches. This ensures checkpoints load with the correct dtype, dimensions, and LoRA hooks.
def get_model(cfg: DictConfig, override_config_kwargs=None):
    """Instantiate the model selected by ``cfg.model_type``.

    Only the ``"your_model_type"`` branch is shown here: it loads
    ``YourModelForRLActionPrediction`` from ``cfg.model_path`` with the dtype
    derived from ``cfg.precision``.

    Args:
        cfg: Model configuration providing ``precision``, ``model_path``,
            ``model_type``, ``is_lora`` and the loader options read below.
        override_config_kwargs: Accepted but not used in this snippet.

    Returns:
        The loaded model.
    """
    dtype = torch_dtype_from_precision(cfg.precision)
    checkpoint = cfg.model_path

    if cfg.model_type == "your_model_type":
        from .embodiment.your_model_action_model import (
            YourModelForRLActionPrediction,
        )

        # Keyword arguments handed to from_pretrained; hidden_size,
        # unnorm_key and action_dim are the extra constructor arguments of
        # YourModelForRLActionPrediction.
        load_kwargs = {
            "torch_dtype": dtype,
            "hidden_size": cfg.hidden_size,
            "unnorm_key": cfg.unnorm_key,
            "action_dim": cfg.action_token_len,
            "attn_implementation": cfg.attn_implementation,
            "low_cpu_mem_usage": cfg.low_cpu_mem_usage,
            "trust_remote_code": cfg.trust_remote_code,
        }
        model = YourModelForRLActionPrediction.from_pretrained(
            checkpoint, **load_kwargs
        )

    if cfg.is_lora:
        # Add LoRA support here
        pass

    return model
4. Configuration File#
Create examples/embodiment/config/your_config.yaml with fields like model_type, action_token_len, and precision.
This template exposes your model’s hyperparameters for easy experiment setup.
# Model section of the experiment config. All keys are nested under `model`
# so that they are reachable as cfg.model.<key>.
model:
  # Selects the branch taken in get_model_config_and_processor / get_model.
  model_type: "your_model_type"
  # Passed to the model as `action_dim` in get_model.
  action_token_len: 7
  action_chunks_len: 1
  # Key used to look up the action normalisation statistics.
  unnorm_key: your_action_key
  micro_batch_size: 1
  val_micro_batch_size: 8
  # Converted to a torch dtype by torch_dtype_from_precision.
  precision: "bf16"
  vocab_size: 32000
  # Passed to the model constructor, which uses it to size the ValueHead.
  hidden_size: 4096
  image_size: [224, 224]
  # When true, get_model runs its LoRA branch.
  is_lora: False
  # The three options below are forwarded to from_pretrained.
  attn_implementation: "flash_attention_2"
  low_cpu_mem_usage: True
  trust_remote_code: True