# RL with Wan World Model
Figure: Wan as an action-conditioned world model.
Train a VLA policy closed-loop without real robots or a physics simulator by using the action-conditioned Wan world model as the environment backend. Wan generates future video frames from the current observation and an action sequence, so the policy can be optimized on "imagined" rollouts with RL (GRPO/PPO).
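To make the closed-loop idea concrete, here is a minimal toy sketch of one imagined rollout. It assumes a chunked interface: the policy emits 8 seven-dimensional actions, the world model generates one frame per action, and a classifier scores the newest frame. `ToyWorldModel`, `toy_policy`, and `imagined_rollout` are hypothetical stand-ins for Wan, OpenVLA-OFT, and the RLinf environment wrapper, not RLinf APIs, and each "frame" is reduced to a single brightness value.

```python
"""Toy sketch of closed-loop RL data collection against a learned world model.

All names are hypothetical stand-ins (ToyWorldModel for Wan, toy_policy for
OpenVLA-OFT); none of them are RLinf APIs.
"""
import random
from dataclasses import dataclass, field

CHUNK_LEN = 8   # actions per policy call, mirroring num_action_chunks: 8
ACTION_DIM = 7  # 7-D continuous actions


@dataclass
class ToyWorldModel:
    frame: float = 0.0  # a "frame" is reduced to one brightness value here
    prompt: str = ""
    history: list = field(default_factory=list)

    def reset(self, init_frame: float, prompt: str) -> float:
        # No physics state to reset: the episode is seeded with a real frame + task text.
        self.frame, self.prompt, self.history = init_frame, prompt, [init_frame]
        return self.frame

    def step(self, actions):
        # "Generate" one future frame per action in the chunk.
        assert len(actions) == CHUNK_LEN and all(len(a) == ACTION_DIM for a in actions)
        for a in actions:
            self.frame += 0.01 * sum(a) / ACTION_DIM
            self.history.append(self.frame)
        # Stand-in for the reward classifier applied to the newest frame.
        reward = 1.0 if self.frame > 0.6 else 0.0
        return self.frame, reward


def toy_policy(obs: float, prompt: str, rng: random.Random):
    # Stand-in for the VLA: one chunk of CHUNK_LEN actions, each of size ACTION_DIM.
    return [[rng.uniform(0.0, 1.0) for _ in range(ACTION_DIM)] for _ in range(CHUNK_LEN)]


def imagined_rollout(wm: ToyWorldModel, init_frame: float, prompt: str,
                     max_chunks: int = 10, seed: int = 0):
    rng = random.Random(seed)
    obs = wm.reset(init_frame, prompt)
    rewards = []
    for _ in range(max_chunks):
        # The policy only ever observes frames produced by the world model.
        obs, reward = wm.step(toy_policy(obs, prompt, rng))
        rewards.append(reward)
        if reward > 0:  # end the episode at the first predicted success
            break
    return rewards
```

After `reset`, the policy sees nothing but generated frames, so any error in the world model's dynamics or in its reward classifier feeds directly into the RL signal.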
## Overview
Train OpenVLA-OFT with GRPO on LIBERO suites simulated by the Wan world model.

- Benchmark: LIBERO (Spatial · Object · Goal)
- Algorithm: GRPO
- Resources: 1 node · GPUs
- Entry point: run_embodiment.sh; watch env/success_once.

## Tasks
As a world model, Wan can in principle fit many task settings behind a consistent interface. RLinf currently ships weights and init data for three LIBERO suites:
| Environment | Task / Suite | Config / Weights | Focus |
|---|---|---|---|
| Wan | LIBERO-Spatial | | Use Wan as a learned simulator for LIBERO spatial tasks. |
| Wan | LIBERO-Object | | Roll out object manipulation dynamics in the video world model. |
| Wan | LIBERO-Goal | | Evaluate goal-conditioned rollouts through Wan. |
## Observation and Action
| Field | Description |
|---|---|
| Observation | RGB frames generated by the world model, |
| Action | 7-D continuous actions, normalized and tokenized to condition generation. |
| Reward | World-model reward classifier output in |
| Prompt | Natural-language task description used to condition the video world model. |
Unlike a traditional simulator, Wan has no reset() that can construct a starting scene on its own: each episode must be seeded with initialization frames and a task description, so you download an initialization dataset and point the config at it.
## Installation
First, clone the RLinf repository:
```bash
# Mainland China users can use a mirror for faster cloning:
# git clone https://ghfast.top/github.com/RLinf/RLinf.git
git clone https://github.com/RLinf/RLinf.git
cd RLinf
```
Then set up the dependencies with one of the two methods below: a prebuilt
Docker image (recommended) or a custom environment. The general setup
(prerequisites, GPU drivers, the in-image switch_env helper, mirrors, and
troubleshooting) is documented once in Installation;
the commands in this recipe differ only in the Docker image tag and the
--env value.
Option 1: Docker image (image tag agentic-rlinf0.3-wan):
```bash
docker run -it --rm --gpus all \
  --shm-size 20g \
  --network host \
  --name rlinf \
  -v .:/workspace/RLinf \
  rlinf/rlinf:agentic-rlinf0.3-wan
# Mainland China mirror: docker.1ms.run/rlinf/rlinf:agentic-rlinf0.3-wan

# Inside the container, switch to the OpenVLA-OFT virtual environment:
source switch_env openvla-oft
```
Option 2: Custom environment (install bundle --env wan):
```bash
# Add --use-mirror for faster downloads in mainland China.
bash requirements/install.sh embodied --model openvla-oft --env wan
source .venv/bin/activate
```
## Download the VLA Model
Download the OpenVLA-OFT SFT checkpoints:
```bash
# Method 1: git clone
git lfs install
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-spatial-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-object-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-goal-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero10-traj1

# Method 2: huggingface-hub (set HF_ENDPOINT=https://hf-mirror.com in mainland China)
pip install huggingface-hub
hf download Haozhan72/Openvla-oft-SFT-libero-spatial-traj1 --local-dir Openvla-oft-SFT-libero-spatial-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-object-traj1 --local-dir Openvla-oft-SFT-libero-object-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-goal-traj1 --local-dir Openvla-oft-SFT-libero-goal-traj1
hf download Haozhan72/Openvla-oft-SFT-libero10-traj1 --local-dir Openvla-oft-SFT-libero10-traj1
```
After downloading, set model_path and unnorm_key in the config:
```yaml
rollout:
  model:
    model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
actor:
  model:
    model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
    unnorm_key: libero_90_no_noops_trajall  # or libero_130_no_noops_trajall for the RLinf-OpenVLAOFT-LIBERO-130-Base-Lora model
```
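The unnorm_key selects which dataset's action statistics are used to map the policy's normalized outputs back to robot actions, which is why it has to match the checkpoint. The sketch below assumes the q01/q99 bound scheme from the public OpenVLA code, with the gripper dimension left unnormalized; `NORM_STATS` and `unnormalize` are hypothetical, and the numbers are made up rather than the real LIBERO statistics.

```python
"""Sketch of how an unnorm_key could select per-dataset action statistics.

Assumption: q01/q99 bounds as in the public OpenVLA code. NORM_STATS holds
made-up numbers, not real LIBERO statistics.
"""
# Hypothetical statistics table keyed by unnorm_key.
NORM_STATS = {
    "libero_spatial_no_noops": {
        "q01": [-0.9, -0.9, -0.9, -0.1, -0.1, -0.1, 0.0],
        "q99": [0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 1.0],
        # False = leave this dimension as-is (here: the gripper).
        "mask": [True, True, True, True, True, True, False],
    },
}


def unnormalize(normalized, unnorm_key):
    """Map a normalized action in [-1, 1]^7 back to the dataset's action range."""
    stats = NORM_STATS[unnorm_key]  # a mismatched key fails loudly with KeyError
    if len(normalized) != len(stats["q01"]):
        raise ValueError("action dimension does not match the statistics")
    out = []
    for x, lo, hi, use in zip(normalized, stats["q01"], stats["q99"], stats["mask"]):
        # Linear map: -1 -> q01, +1 -> q99; masked dimensions pass through.
        out.append(0.5 * (x + 1.0) * (hi - lo) + lo if use else x)
    return out
```

Using statistics from a different dataset would still produce actions, just scaled wrongly, so a silent mismatch is worse than the KeyError above.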
## Download the World Model
Besides the VLA model, download the Wan weights and initialization data. RLinf currently provides data and checkpoints for three suites; each Wan checkpoint is trained on 1500 trajectories collected from VLA rollouts:
```bash
# Method 1: git clone
git lfs install
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Spatial
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Object
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Goal

# Method 2: huggingface-hub (set HF_ENDPOINT=https://hf-mirror.com in mainland China)
pip install huggingface-hub
hf download RLinf/RLinf-Wan-LIBERO-Spatial --local-dir RLinf-Wan-LIBERO-Spatial
hf download RLinf/RLinf-Wan-LIBERO-Object --local-dir RLinf-Wan-LIBERO-Object
hf download RLinf/RLinf-Wan-LIBERO-Goal --local-dir RLinf-Wan-LIBERO-Goal
```
The directory structure of RLinf-Wan-LIBERO-Spatial is:
```text
RLinf-Wan-LIBERO-Spatial/
├── dataset/                  # Initialization dataset for simulation
│   ├── traj0.npy             # Trajectories containing only the initial frame
│   ├── traj1.npy
│   ├── ...
│   ├── trajN.npy
│   ├── traj0_kir.npy         # Trajectories with pre-keyframe context
│   ├── traj1_kir.npy
│   ├── ...
│   └── trajN_kir.npy
├── model-00001.safetensors   # World model checkpoint
├── resnet_rm.pth             # Reward model checkpoint
└── Wan2.2_VAE.pth            # VAE checkpoint
```
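Before editing the config, a quick check that the downloaded bundle matches this layout can save a failed launch. A minimal sketch that uses only the file names listed in the tree; `check_wan_bundle` is a hypothetical helper, not part of RLinf:

```python
"""Sanity-check a downloaded Wan bundle against the expected layout.

check_wan_bundle is a hypothetical helper; the file names come from the
directory tree of RLinf-Wan-LIBERO-Spatial.
"""
from pathlib import Path

# World model, reward model, and VAE checkpoints expected at the bundle root.
REQUIRED_FILES = ("model-00001.safetensors", "resnet_rm.pth", "Wan2.2_VAE.pth")


def check_wan_bundle(root):
    """Return the expected entries that are missing under root (empty list = OK)."""
    root = Path(root)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    dataset = root / "dataset"
    # The initialization dataset needs at least one .npy trajectory file.
    if not (dataset.is_dir() and any(dataset.glob("*.npy"))):
        missing.append("dataset/*.npy")
    return missing
```

For example, `check_wan_bundle("/Pathto/model/RLinf-Wan-LIBERO-Spatial")` should return an empty list for a complete download.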
After downloading, set the world-model paths in the config:
```yaml
env:
  train:
    wan_wm_hf_ckpt_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/
```
## Run It
1. Model parameters
Configure actor.model (OpenVLA-OFT example):
```yaml
actor:
  model:
    model_path: "/path/to/model/Openvla-oft-SFT-libero-spatial-traj1/"  # SFT model path
    model_type: "openvla_oft"              # model type
    use_proprio: False                     # whether to use proprioception
    num_images_in_input: 1                 # number of image inputs
    num_action_chunks: 8                   # number of action chunks
    unnorm_key: "libero_spatial_no_noops"  # normalization key (aligned with SFT)
```
Because the world model does not provide proprioception, does not render wrist views, and
uses a fixed chunk length, use_proprio defaults to False, num_images_in_input to
1, and num_action_chunks to 8.
2. Environment configuration
```yaml
# Recommended: wan_libero_spatial for train, libero_spatial for eval
env/train: wan_libero_spatial
env/eval: libero_spatial

# In env/train/wan_libero_spatial.yaml:
wm_env_type: libero
task_suite_name: libero_spatial
reset_gripper_open: True
# Whether to enable KeyFrame-Init Rollout
enable_kir: True
# Number of world model denoising inference steps
num_inference_steps: 5
# Initialization dataset path for world model reset
initial_image_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/dataset
# VAE weights
VAE_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/Wan2.2_VAE.pth
# Pretrained world model weights
model_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/model-00001.safetensors
# Reward model
reward_model:
  type: ResnetRewModel
  from_pretrained: /Pathto/model/RLinf-Wan-LIBERO-Spatial/resnet_rm.pth
```
Key environment parameters:

- enable_kir: enable KIR (KeyFrame-Init Rollout). If disabled, reset samples only .npy files whose names do not include _kir; if enabled, reset samples from all initialization files in dataset/.
- num_inference_steps: number of world-model denoising steps per generation (default 5). Fewer steps are faster but may reduce visual quality; even single-step generation can still improve policy performance.
- reward_model.type: reward model class, either ResnetRewModel or TaskEmbedResnetRewModel.
- reset_gripper_open: initialize episodes with an open gripper. Default True for both train and eval; changing it is not recommended.
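The enable_kir rule amounts to a filter on file names in dataset/. A minimal sketch of that selection, assuming reset then draws uniformly at random from the remaining candidates (the sampling distribution is an assumption, and `select_init_files` and `sample_reset_file` are hypothetical helpers, not RLinf functions):

```python
"""Sketch of reset-time initialization file selection controlled by enable_kir.

Without KIR, only .npy files whose names lack "_kir" are candidates; with KIR,
every .npy file in dataset/ is. Uniform sampling is an assumption.
"""
import random
from pathlib import Path


def select_init_files(dataset_dir, enable_kir):
    files = sorted(Path(dataset_dir).glob("*.npy"))
    if not enable_kir:
        # Keep only the initial-frame trajectories (traj0.npy, traj1.npy, ...).
        files = [p for p in files if "_kir" not in p.name]
    return files


def sample_reset_file(dataset_dir, enable_kir, rng):
    candidates = select_init_files(dataset_dir, enable_kir)
    if not candidates:
        raise FileNotFoundError(f"no initialization files in {dataset_dir}")
    return rng.choice(candidates)
```

Passing an explicit `random.Random` keeps resets reproducible across runs, which helps when comparing KIR on and off.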
3. Launch
OpenVLA-OFT + GRPO uses examples/embodiment/config/wan_libero_spatial_grpo_openvlaoft.yaml:
```bash
bash examples/embodiment/run_embodiment.sh wan_libero_spatial_grpo_openvlaoft
```
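GRPO needs no learned critic: several rollouts that start from the same initial state form a group, and each rollout's return is normalized against the rest of its group. The sketch below is the textbook group-relative advantage, not RLinf's code; the `eps` term and the use of the population standard deviation are assumptions, and `group_relative_advantages` is a hypothetical name.

```python
"""Sketch of GRPO's group-relative advantage (textbook form, not RLinf's code)."""
import statistics


def group_relative_advantages(returns, eps=1e-6):
    """Normalize each rollout's return by its group's mean and std."""
    mean = statistics.fmean(returns)
    std = statistics.pstdev(returns)  # population std; eps avoids division by zero
    return [(r - mean) / (std + eps) for r in returns]
```

With a binary success reward, a group in which every rollout succeeds (or every rollout fails) gets zero advantage everywhere, so only groups with mixed outcomes produce a learning signal.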
## Visualization and Results
Watch env/success_once for the unnormalized episodic success rate. For the full list of logged metrics, see Training metrics. Enable videos of the generated rollouts with:
```yaml
env:
  eval:
    video_cfg:
      save_video: True
      video_base_dir: ${runner.logger.log_path}/video/eval
```
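As its name suggests, env/success_once counts an episode as successful if the success signal fires at any step, and the rate averages over episodes. A minimal sketch of that reading of the metric (our interpretation of the name, not RLinf's implementation; `success_once_rate` is hypothetical):

```python
"""Sketch of a "success once" episodic success rate (interpretation, not RLinf code)."""


def success_once_rate(success_flags):
    """success_flags[e][t] is True when episode e reports success at step t."""
    if not success_flags:
        return 0.0
    # An episode counts once if it succeeded at any step.
    return sum(any(episode) for episode in success_flags) / len(success_flags)
```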
For each of the Object, Spatial, and Goal suites, we evaluate every task_id × trial_id
combination: 1500 environments per suite (10 tasks × 150 trials). For both the SFT and the
RL-trained models we set do_sample = True and temperature_train = 1.6 in rollout.sampling_params,
with reset_gripper_open = True.
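With do_sample = True, actions are sampled rather than decoded greedily, and a temperature of 1.6 divides the logits before the softmax, which flattens the distribution. A minimal sketch, assuming the policy samples discrete action tokens from logits (consistent with the tokenized actions described above); the logits and the helper names are hypothetical:

```python
"""Sketch of temperature sampling over discrete action tokens (hypothetical helpers)."""
import math
import random


def softmax_with_temperature(logits, temperature):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature, rng):
    # T > 1 flattens the distribution; T -> 0 approaches greedy argmax decoding.
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
```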
Note: Wan training and inference are built on DiffSynth-Studio. In the results below, the world model is kept frozen while it serves the RL training of the VLA policy; the world model and the VLA do not co-evolve. Implementing co-evolution yourself may yield further gains.
| Model | Spatial | Object | Goal |
|---|---|---|---|
| OpenVLA-OFT (LoRA-base) | 61.2% | 36.7% | 48.2% |
| OpenVLA-OFT (RLinf-GRPO with Wan as world model) | 77.5% | 77.9% | 60.1% |
| Improvement (percentage points) | +16.3 | +41.2 | +11.9 |