Adding New Models with FSDP+HuggingFace#
This document focuses on how to use the HuggingFace Transformers library together with PyTorch FSDP (Fully Sharded Data Parallel) to train and generate with models. It supports any model implemented in HuggingFace as long as it is compatible with PyTorch. As an example, this guide provides a step-by-step workflow showing how to integrate a new HuggingFace model into RLinf following the OpenVLA pattern.
Prerequisites#
Familiarity with HuggingFace Transformers
Understanding of the RLinf framework architecture
Knowledge of PyTorch and distributed training
Step-by-Step Implementation#
1. Model Registration#
The recommended way to integrate a model into RLinf is to register a model builder. This is also the preferred approach when another repository depends on RLinf.
from rlinf.models import register_model
def build_my_model(cfg, torch_dtype):
from your_package.your_model_action_model import (
YourModelForRLActionPrediction,
)
return YourModelForRLActionPrediction.from_pretrained(
cfg.model_path,
torch_dtype=torch_dtype,
hidden_size=cfg.hidden_size,
unnorm_key=cfg.unnorm_key,
action_dim=cfg.action_token_len,
attn_implementation=cfg.attn_implementation,
low_cpu_mem_usage=cfg.low_cpu_mem_usage,
trust_remote_code=cfg.trust_remote_code,
)
register_model("your_model_type", build_my_model, category="embodied")
This single registration does two things:
Makes
model_typepass RLinf config validation.Makes
rlinf.models.get_model(cfg)build your model correctly.
Distributed registration when RLinf is a dependency (RLINF_EXT_MODULE)#
When RLinf is pulled in as a third-party dependency and Ray launches
distributed workers, calling register_model(...) only in the main (driver)
process does not carry over to worker processes: workers are separate Python
processes and will not re-execute the registration code from your entry script.
As a result, code paths that actually run on workers (rollout, training, etc.)
may still see an unregistered model and fail to build or resolve it.
Set the environment variable RLINF_EXT_MODULE to the import path of an
extension module. On each worker initialization, RLinf imports that module
and, when defined, invokes its register() so workers perform the same model
(or other extension) registration as the driver, without patching RLinf.
The extension module should implement register() and call register_model
and any other needed hooks there. For the exact env var format and examples, see
the docstring of ClusterEnvVar.EXT_MODULE in
rlinf/scheduler/cluster/cluster.py, e.g.:
export RLINF_EXT_MODULE=rlinf_ext
# or a fully qualified package path, e.g.:
export RLINF_EXT_MODULE=workflows.scripts.rlinf_ext
Standalone checkpoint utilities (for example
python -m rlinf.utils.ckpt_convertor.fsdp_convertor.convert_pt_to_hf) also
load RLINF_EXT_MODULE before calling get_model, matching worker behavior.
If your model_type is registered only via an extension module, set this
environment variable when running those tools as well.
2. Model Implementation#
Create your class in rlinf/models/embodiment/your_model_action_model.py and
inherit from a HuggingFace base class.
For a custom VLA model integrated into RLinf, you should follow the interface
contract of rlinf.models.embodiment.base_policy.BasePolicy and provide at
least forward and predict_action_batch. If you inherit directly from
BasePolicy, you will usually need to implement default_forward and
predict_action_batch, while forward is dispatched by the base class.
predict_action_batch is used to encapsulate generation, decoding, and optional
value computation so the RL-specific logic stays inside the model.
from typing import Any
from rlinf.models.embodiment.base_policy import BasePolicy
class YourModelForRLActionPrediction(BasePolicy):
def default_forward(
self,
forward_inputs: dict[str, Any],
**kwargs,
) -> dict[str, Any]:
return {
"logprobs": ..., # [B, ...]
"values": ..., # [B] or [B, 1]
"entropy": ..., # [B] or [B, 1]
}
def predict_action_batch(
self,
env_obs: dict[str, Any],
mode: str = "train",
compute_values: bool = True,
**kwargs,
) -> tuple[Any, dict[str, Any]]:
actions = ... # [B, action_chunk, action_dim] or flattened env action
result = {
"prev_logprobs": ...,
"prev_values": ...,
"forward_inputs": {
"chains": ...,
"denoise_inds": ...,
"action": ...,
"model_action": ...,
},
}
return actions, result
Refer to rlinf/models/embodiment/openpi/openpi_action_model.py. The inputs and
outputs of these interfaces are best understood as follows:
forward(forward_type=..., **kwargs): A unified dispatch entry that routes todefault_forward,sft_forward,sac_forward, or other training branches based onforward_type. For a custom VLA model, it is recommended to keep this dispatch layer so rollout, actor, SFT, and other callers all go through the same entry point.default_forward(forward_inputs, **kwargs): Used during training to recompute logprob / value / entropy from the intermediate results cached during rollout. Following OpenPI,forward_inputsusually includes at least:chains: the action chain from diffusion or iterative sampling, used to recompute logprob during training.denoise_inds: the denoising step index for each sample, which determines which training signal to use.tokenized_prompt/tokenized_prompt_mask: the text prompt and its mask.action: the actual action sent to the environment, usually flattened as[B, ...].model_action: the raw action output of the model, usually preserved before output transformation.Observation fields: such as
observation/image,observation/state,observation/wrist_image, and so on, so the model can reconstruct the observation.
OpenPI’s
default_forwardreturns adictwhose key fields include:logprobs: recomputed log probabilities of the current action under the training branch.values: state values produced by the value head.entropy: policy entropy or an equivalent uncertainty statistic used for PPO / RL loss.
predict_action_batch(env_obs, mode="train"|"eval", compute_values=True, **kwargs): Used during rollout / inference. It takes environment observations as input and returns actions that can be executed directly. Following OpenPI,env_obsusually includes:main_images: the main-view image.wrist_images: wrist camera images, optional.extra_view_images: additional view images, optional.states: low-dimensional states.task_descriptions: text task descriptions.
OpenPI returns
(actions, result):actions: action tensors sent directly to the environment.result["prev_logprobs"]: the old-policy logprob recorded during rollout.result["prev_values"]: the old value recorded during rollout.result["forward_inputs"]: the cached fields that will be fed intodefault_forwardagain during training.
When implementing a custom model, the field names do not need to be exactly the
same as OpenPI, but it is recommended to guarantee two things: first,
predict_action_batch should return enough forward_inputs for
default_forward to reconstruct all information required for training;
second, default_forward should output at least the fields actually required
by the training algorithm, such as logprobs, values, and entropy.
3. Configuration File#
Create examples/embodiment/config/your_config.yaml.
It should include fields such as model_type, action_token_len, and
precision. This template exposes your model hyperparameters so experiments are
easy to configure.
model:
model_type: "your_model_type"
action_token_len: 7
action_chunks_len: 1
unnorm_key: your_action_key
precision: "bf16"
vocab_size: 32000
hidden_size: 4096
image_size: [224, 224]
is_lora: False
low_cpu_mem_usage: True
trust_remote_code: True
If your custom model needs explicit wrapping rules under FSDP, you can also add
fsdp_config.wrap_policy to the config. When this field contains custom
rules, RLinf will prioritize them. Otherwise, it will continue to follow the
model’s own _no_split_modules / _no_split_names together with the
built-in logic.
Common fields include:
transformer_layer_cls_to_wrap: specify modules by class name that should be wrapped as transformer blocks. This is suitable for the main structure, such as decoder layers or attention blocks.module_classes_to_wrap: specify extra modules by class name that should be wrapped, such as vision encoders, projectors, or adapters that are not standard transformer blocks.no_split_names: perform finer-grained wrapping based on a module’s_fsdp_wrap_name. This is useful when class names are not distinctive enough, but the module role is fixed.
fsdp_config:
wrap_policy:
transformer_layer_cls_to_wrap:
- "MyDecoderLayer"
module_classes_to_wrap:
- "MyVisionTower"
- "MyProjector"
no_split_names:
- "value_head"
If your custom module is essentially the main transformer block, configuring
transformer_layer_cls_to_wrap is usually enough. module_classes_to_wrap
is more suitable for modules outside the main block that should still
participate in FSDP wrapping separately.
4. Model Usage#
After calling register_model("your_model_type", build_my_model, category="embodied"),
you can directly use model_type: your_model_type in YAML configs. If your
model also depends on a corresponding processor class, it is recommended to
load it inside build_my_model before returning the model instance.
from omegaconf import OmegaConf
from rlinf.models import get_model
cfg = OmegaConf.create(
{
"model_type": "your_model_type",
"precision": "fp32",
"is_lora": False,
}
)
model = get_model(cfg)
5. Notes#
RLinf performs extra environment-side processing on actions produced by different models. For example,
prepare_actions(...)inrlinf/envs/action_utils.pyapplies different post-processing logic based onenv_typeandmodel_type. If you are using a brand new custommodel_type, it is recommended to complete this extra action post-processing directly insidepredict_action_batch(...), so theactionsreturned to the environment are already in their final executable format. This avoids adding extra environment-side branching logic for every new model type.When registering a custom model with
register_model(...), you can passforce=Trueto override an existing implementation. This is suitable when you want to replace an RLinf built-in model while keeping the originalmodel_type. In that case, you can continue to reuse RLinf’s existing action processing, training branches, and other compatibility logic for that built-inmodel_type, reducing the amount of additional adaptation code. The prerequisite is that your model interface remains fully aligned with the built-in one.Different VLA models use different wrap policies under FSDP in RLinf. You can refer to
get_fsdp_wrap_policy(...)inrlinf/hybrid_engines/fsdp/utils.pyfor the current logic. If you are integrating a custom model, it is recommended to explicitly add the corresponding wrap policy config so that key modules such as transformer layers, vision encoders, projectors, and value heads participate in FSDP wrapping as expected. Otherwise, you may encounter degraded training performance, unreasonable memory usage, or modules that are not sharded correctly.