RL with Wan World Model

https://raw.githubusercontent.com/RLinf/misc/main/pic/wan.png

Wan as an action-conditioned world model.

Train a VLA policy closed-loop without real robots or a physics simulator by using the action-conditioned Wan world model as the environment backend. Wan generates future video frames from the current observation and an action sequence, so the policy can be optimized on “imagined” rollouts with RL (GRPO/PPO).
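Here is a toy sketch of the closed loop described above, under stated assumptions. Every class and function in it (ToyWorldModel, ToyRewardModel, toy_policy, imagined_rollout) is hypothetical and is not RLinf's API. The world model returns random frames and the reward model returns a random score, so only the control flow matters: the policy emits an action chunk, the world model generates frames, and the reward classifier scores the result. The 256×256 RGB frames, 7-D actions, and chunk length of 8 follow the values given elsewhere on this page. The early-stop threshold is an assumption.

```python
# Toy sketch of RL on imagined rollouts. ToyWorldModel, ToyRewardModel,
# toy_policy and imagined_rollout are hypothetical stand-ins, not RLinf APIs.
from dataclasses import dataclass

import numpy as np

FRAME_SHAPE = (256, 256, 3)  # RGB frame size from the observation table
CHUNK_LEN = 8                # actions per chunk (fixed for Wan)
ACTION_DIM = 7               # 7-D continuous actions


@dataclass
class ToyWorldModel:
    rng: np.random.Generator

    def generate(self, frame, prompt, actions):
        # A real action-conditioned video model would denoise CHUNK_LEN future
        # frames conditioned on (frame, prompt, actions); here we emit noise.
        assert actions.shape == (CHUNK_LEN, ACTION_DIM)
        return self.rng.integers(0, 256, size=(CHUNK_LEN, *FRAME_SHAPE), dtype=np.uint8)


@dataclass
class ToyRewardModel:
    rng: np.random.Generator

    def score(self, frame):
        # Stands in for a success classifier whose output lies in [0, 1].
        return float(self.rng.random())


def toy_policy(frame, prompt, rng):
    # A VLA policy maps (image, instruction) to one chunk of actions.
    return rng.uniform(-1.0, 1.0, size=(CHUNK_LEN, ACTION_DIM))


def imagined_rollout(init_frame, prompt, wm, rm, rng, max_chunks=4, threshold=0.5):
    """Roll the policy forward inside the world model; return per-chunk rewards."""
    frame, rewards = init_frame, []
    for _ in range(max_chunks):
        actions = toy_policy(frame, prompt, rng)
        frames = wm.generate(frame, prompt, actions)
        frame = frames[-1]        # last generated frame is the next observation
        reward = rm.score(frame)
        rewards.append(reward)
        if reward >= threshold:   # hypothetical success criterion ends the episode
            break
    return rewards


rng = np.random.default_rng(0)
init_frame = np.zeros(FRAME_SHAPE, dtype=np.uint8)  # would come from the init dataset
rewards = imagined_rollout(init_frame, "put the bowl on the plate",
                           ToyWorldModel(rng), ToyRewardModel(rng), rng)
print(len(rewards), rewards)
```

No physics engine appears anywhere in the loop. This is why the quality of the world model and of the reward classifier bounds what the policy can learn.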

Overview

Train OpenVLA-OFT with GRPO on LIBERO suites simulated by the Wan world model.

Environments: LIBERO
Algorithms: GRPO
Tasks: Spatial · Object · Goal
Hardware: 1 node · GPUs

You’ll do: install → download the VLA model → download the Wan world-model weights + init data → launch run_embodiment.sh → watch env/success_once.
Prerequisites: Installation · an OpenVLA-OFT SFT checkpoint · Wan world-model weights and init dataset (steps below).
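GRPO (Group Relative Policy Optimization) works without a learned value function. It runs several rollouts of the same task and scores each one relative to its group. The sketch below shows the generic formulation: subtract the group mean and divide by the group standard deviation. It is a minimal sketch and is not RLinf's implementation, which may differ in details such as the epsilon, the standard-deviation estimator, or whether it normalizes by the standard deviation at all. The function name grpo_advantages and the grouping layout [num_groups, group_size] are assumptions.

```python
import numpy as np


def grpo_advantages(group_rewards, eps=1e-6):
    """Group-relative advantages for an array of shape [num_groups, group_size].

    Each row holds the returns of several rollouts that share the same task and
    initial state; each return is normalized by its own group's mean and std.
    """
    r = np.asarray(group_rewards, dtype=np.float64)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    return (r - mean) / (std + eps)


# Two groups of four rollouts with binary success rewards.
adv = grpo_advantages([[1.0, 0.0, 0.0, 1.0],
                       [1.0, 1.0, 1.0, 1.0]])
print(adv.round(3))
```

In the mixed group, successful rollouts receive an advantage of roughly +1 and failures roughly -1. The all-success group receives zero everywhere, so it contributes no policy gradient. This is one reason a starting policy that already succeeds on some, but not all, trials is useful.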

Tasks

As a world model, Wan can in principle fit many task settings behind a consistent interface. RLinf currently ships weights and init data for three LIBERO suites:

Environment | Task / Suite   | Config / Weights               | Focus
Wan         | LIBERO-Spatial | RLinf/RLinf-Wan-LIBERO-Spatial | Use Wan as a learned simulator for LIBERO spatial tasks.
Wan         | LIBERO-Object  | RLinf/RLinf-Wan-LIBERO-Object  | Roll out object-manipulation dynamics in the video world model.
Wan         | LIBERO-Goal    | RLinf/RLinf-Wan-LIBERO-Goal    | Evaluate goal-conditioned rollouts through Wan.

Observation and Action

Field       | Description
Observation | RGB frames generated by the world model, shape [B, 256, 256, 3], seeded from initialization frames.
Action      | 7-D continuous actions, normalized and tokenized to condition generation.
Reward      | Output of the world-model reward classifier, in [0, 1].
Prompt      | Natural-language task description used to condition the video world model.

Unlike a traditional simulator, Wan cannot reset to a scene on its own. Every reset must be seeded with initialization frames and a task description, so you download an initialization dataset and point the config at it.
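The table above and the reset requirement amount to a small data contract. The sketch below spells it out as validation helpers, under stated assumptions. check_reset_inputs and check_step_io are hypothetical names, not RLinf functions. The action shape [B, 8, 7] combines the 7-D action with the fixed chunk length of 8 used in the actor config, and one reward per environment per step is an assumption.

```python
import numpy as np

FRAME_SHAPE = (256, 256, 3)  # per-frame observation shape from the table above
ACTION_DIM = 7
CHUNK_LEN = 8                # assumption: matches num_action_chunks in the actor config


def check_reset_inputs(init_frames, prompts):
    """Wan cannot invent a start scene: each env needs a seed frame and a prompt."""
    if init_frames.ndim != 4 or init_frames.shape[1:] != FRAME_SHAPE:
        raise ValueError(f"expected [B, 256, 256, 3] frames, got {init_frames.shape}")
    if len(prompts) != init_frames.shape[0] or not all(prompts):
        raise ValueError("need one non-empty task description per initialization frame")


def check_step_io(obs, actions, rewards):
    """Validate one batched step against the observation/action/reward contract."""
    batch = obs.shape[0]
    if obs.shape != (batch, *FRAME_SHAPE):
        raise ValueError(f"bad observation shape {obs.shape}")
    if actions.shape != (batch, CHUNK_LEN, ACTION_DIM):
        raise ValueError(f"bad action shape {actions.shape}")
    if rewards.shape != (batch,) or np.any((rewards < 0.0) | (rewards > 1.0)):
        raise ValueError("rewards must be one value in [0, 1] per env")


frames = np.zeros((2, *FRAME_SHAPE), dtype=np.uint8)
check_reset_inputs(frames, ["pick up the black bowl", "open the top drawer"])
check_step_io(frames, np.zeros((2, CHUNK_LEN, ACTION_DIM)), np.array([0.2, 0.9]))
```

Checks like these are cheap to run once on the first batch and catch mismatched resolutions or missing prompts before a long training run starts.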

Installation

First, clone the RLinf repository:

# Mainland China users can use a mirror for faster cloning:
# git clone https://ghfast.top/github.com/RLinf/RLinf.git
git clone https://github.com/RLinf/RLinf.git
cd RLinf

Then set up the dependencies with one of the two methods below: a prebuilt Docker image (recommended) or a custom environment. The general setup (prerequisites, GPU drivers, the in-image switch_env helper, mirrors, and troubleshooting) is documented once in Installation; the commands in this recipe differ only in the Docker image tag and the --env value.

Option 1: Docker image (image tag agentic-rlinf0.3-wan):

docker run -it --rm --gpus all \
   --shm-size 20g \
   --network host \
   --name rlinf \
   -v .:/workspace/RLinf \
   rlinf/rlinf:agentic-rlinf0.3-wan
   # Mainland China mirror: docker.1ms.run/rlinf/rlinf:agentic-rlinf0.3-wan

# Inside the container, switch to the OpenVLA-OFT virtual environment:
source switch_env openvla-oft

Option 2: Custom environment (install bundle --env wan):

# Add --use-mirror for faster downloads in mainland China.
bash requirements/install.sh embodied --model openvla-oft --env wan
source .venv/bin/activate

Download the VLA Model

Download the OpenVLA-OFT SFT checkpoints:

# Method 1: git clone
git lfs install
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-spatial-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-object-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-goal-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero10-traj1

# Method 2: huggingface-hub (set HF_ENDPOINT=https://hf-mirror.com in mainland China)
pip install huggingface-hub
hf download Haozhan72/Openvla-oft-SFT-libero-spatial-traj1 --local-dir Openvla-oft-SFT-libero-spatial-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-object-traj1 --local-dir Openvla-oft-SFT-libero-object-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-goal-traj1 --local-dir Openvla-oft-SFT-libero-goal-traj1
hf download Haozhan72/Openvla-oft-SFT-libero10-traj1 --local-dir Openvla-oft-SFT-libero10-traj1

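After downloading, set model_path (for both rollout and actor) and unnorm_key in the config to match the checkpoint you use. The example below uses the RLinf-OpenVLAOFT-LIBERO-90-Base-Lora checkpoint: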

rollout:
   model:
      model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
actor:
   model:
      model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
      unnorm_key: libero_90_no_noops_trajall # or libero_130_no_noops_trajall for the RLinf-OpenVLAOFT-LIBERO-130-Base-Lora model

Download the World Model

Besides the VLA model, download the Wan weights and initialization data. RLinf currently provides data and checkpoints for three suites; each Wan checkpoint is built from 1500 trajectories generated by VLA rollouts:

# Method 1: git clone
git lfs install
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Spatial
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Object
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Goal

# Method 2: huggingface-hub (set HF_ENDPOINT=https://hf-mirror.com in mainland China)
pip install huggingface-hub
hf download RLinf/RLinf-Wan-LIBERO-Spatial --local-dir RLinf-Wan-LIBERO-Spatial
hf download RLinf/RLinf-Wan-LIBERO-Object --local-dir RLinf-Wan-LIBERO-Object
hf download RLinf/RLinf-Wan-LIBERO-Goal --local-dir RLinf-Wan-LIBERO-Goal

The directory structure of RLinf-Wan-LIBERO-Spatial is:

RLinf-Wan-LIBERO-Spatial/
    ├── dataset/                            # Initialization dataset for simulation
    │   ├── traj0.npy                       # Trajectories containing initial frame only
    │   ├── traj1.npy
    │   ├── ...
    │   ├── trajN.npy
    │   ├── traj0_kir.npy                   # Trajectories with pre-keyframe context
    │   ├── traj1_kir.npy
    │   ├── ...
    │   └── trajN_kir.npy
    ├── model-00001.safetensors             # World model checkpoint
    ├── resnet_rm.pth                       # Reward model checkpoint
    └── Wan2.2_VAE.pth                      # VAE checkpoint

After downloading, set the world-model paths in the config:

env:
    train:
        wan_wm_hf_ckpt_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/
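Before you point the config at a downloaded directory, you can check that it matches the layout listed for RLinf-Wan-LIBERO-Spatial. The sketch below assumes the Object and Goal repositories use the same file names. check_wan_checkpoint_dir is a hypothetical helper, not part of RLinf.

```python
from pathlib import Path

# File names as listed in the RLinf-Wan-LIBERO-Spatial tree; other suites are
# assumed to follow the same layout.
REQUIRED_FILES = ("model-00001.safetensors", "resnet_rm.pth", "Wan2.2_VAE.pth")


def check_wan_checkpoint_dir(root):
    """Return a list of problems in a downloaded RLinf-Wan-LIBERO-* directory."""
    root = Path(root)
    problems = [f"missing {name}" for name in REQUIRED_FILES if not (root / name).is_file()]
    dataset = root / "dataset"
    if not dataset.is_dir():
        problems.append("missing dataset/ directory")
    elif not any(dataset.glob("*.npy")):
        problems.append("dataset/ has no .npy initialization files")
    return problems
```

For example, `print(check_wan_checkpoint_dir("/Pathto/model/RLinf-Wan-LIBERO-Spatial"))` should print an empty list for a complete download. A missing file usually means a git clone ran without git lfs or an interrupted hf download.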

Run It

1. Model parameters

Configure actor.model (OpenVLA-OFT example):

actor:
  model:
    model_path: "/path/to/model/Openvla-oft-SFT-libero-spatial-traj1/"    # SFT model path
    model_type: "openvla_oft"                                             # model type
    use_proprio: False                                                    # whether to use proprioception
    num_images_in_input: 1                                                # number of image inputs
    num_action_chunks: 8                                                  # action chunk length (actions per policy call)
    unnorm_key: "libero_spatial_no_noops"                                 # normalization key (aligned with SFT)

Because the world model does not provide proprioception, does not render wrist views, and uses a fixed chunk length, use_proprio defaults to False, num_images_in_input to 1, and num_action_chunks to 8.
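The three constraints above are easy to get wrong when you reuse an SFT config written for the physics simulator. The sketch below checks an actor.model mapping against them, under stated assumptions. The dict mirrors the YAML block above. In practice you would obtain it from RLinf's composed config, or with yaml.safe_load on a standalone file. check_actor_model_for_wan is a hypothetical helper.

```python
# Values the Wan world model constrains, per the paragraph above.
WAN_ACTOR_CONSTRAINTS = {
    "use_proprio": False,        # no proprioception from the world model
    "num_images_in_input": 1,    # no wrist view, only the generated main view
    "num_action_chunks": 8,      # fixed chunk length
}


def check_actor_model_for_wan(model_cfg):
    """Return human-readable mismatches between actor.model and Wan's constraints."""
    return [
        f"{key}={model_cfg.get(key)!r}, expected {expected!r}"
        for key, expected in WAN_ACTOR_CONSTRAINTS.items()
        if model_cfg.get(key) != expected
    ]


actor_model = {
    "model_path": "/path/to/model/Openvla-oft-SFT-libero-spatial-traj1/",
    "model_type": "openvla_oft",
    "use_proprio": True,  # deliberately wrong to show a mismatch
    "num_images_in_input": 1,
    "num_action_chunks": 8,
    "unnorm_key": "libero_spatial_no_noops",
}
print(check_actor_model_for_wan(actor_model))  # ['use_proprio=True, expected False']
```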

2. Environment configuration

# Recommended: wan_libero_spatial for train, libero_spatial for eval
env/train: wan_libero_spatial
env/eval: libero_spatial

# In env/train/wan_libero_spatial.yaml:
wm_env_type: libero
task_suite_name: libero_spatial
reset_gripper_open: True
# Whether to enable KeyFrame-Init Rollout
enable_kir: True
# Number of World Model denoising inference steps
num_inference_steps: 5
# Initialization dataset path for world model reset
initial_image_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/dataset
# VAE weights
VAE_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/Wan2.2_VAE.pth
# Pretrained world model weights
model_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/model-00001.safetensors
# Reward model
reward_model:
  type: ResnetRewModel
  from_pretrained: /Pathto/model/RLinf-Wan-LIBERO-Spatial/resnet_rm.pth

Key environment parameters:

  • enable_kir: enable KIR (KeyFrame-Init Rollout). If disabled, reset samples only the .npy files whose names do not include _kir; if enabled, reset samples from all initialization files in dataset/.

  • num_inference_steps: number of World Model denoising steps per generation (default 5). Fewer steps are faster but may reduce visual quality; even single-step generation can still improve policy performance.

  • reward_model.type: reward model class, either ResnetRewModel or TaskEmbedResnetRewModel.

  • reset_gripper_open: initialize with an open gripper. Default True for train and eval; changing it is not recommended.
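The enable_kir rule can be sketched as a file filter over dataset/. In this sketch, candidate_init_files and sample_init_file are hypothetical names. Uniform random sampling is an assumption: RLinf may weight, order, or shard the files differently across workers.

```python
import random
from pathlib import Path


def candidate_init_files(dataset_dir, enable_kir):
    """Initialization files reset may draw from, following the enable_kir rule."""
    files = sorted(Path(dataset_dir).glob("*.npy"))
    if enable_kir:
        return files                                   # all files, including *_kir.npy
    return [f for f in files if "_kir" not in f.name]  # only initial-frame-only files


def sample_init_file(dataset_dir, enable_kir, rng):
    # Uniform sampling is an assumption; RLinf may weight or order files differently.
    files = candidate_init_files(dataset_dir, enable_kir)
    if not files:
        raise FileNotFoundError(f"no initialization .npy files in {dataset_dir}")
    return rng.choice(files)
```

With enable_kir set to False, only traj0.npy … trajN.npy are candidates. With enable_kir set to True, the *_kir.npy files, which carry pre-keyframe context, join the pool.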

3. Launch

OpenVLA-OFT + GRPO uses examples/embodiment/config/wan_libero_spatial_grpo_openvlaoft.yaml:

bash examples/embodiment/run_embodiment.sh wan_libero_spatial_grpo_openvlaoft

Visualization and Results

Watch env/success_once for the unnormalized episodic success rate. For every logged metric, see Training metrics. Enable rollout videos for the eval environment with:

env:
   eval:
      video_cfg:
         save_video: True
         video_base_dir: ${runner.logger.log_path}/video/eval
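The metric name suggests that success_once counts an episode as successful if the success signal fires at least once during the episode, even if it later turns off. That reading is an assumption based on the name; the Training metrics page is authoritative. Under that assumption, the computation is a per-episode "any" followed by a mean. success_once_rate is a hypothetical helper.

```python
import numpy as np


def success_once_rate(success_flags):
    """Fraction of episodes whose success flag was True at any step.

    success_flags: bool array of shape [num_episodes, num_steps].
    """
    flags = np.asarray(success_flags, dtype=bool)
    return float(flags.any(axis=1).mean())


flags = [
    [False, False, True],   # succeeds at the last step
    [False, False, False],  # never succeeds
    [True, False, False],   # succeeds early, then the flag drops
    [False, False, False],
]
print(success_once_rate(flags))  # 0.5
```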

For each of the Object, Spatial, and Goal suites, we evaluate every task_id × trial_id combination: 10 tasks × 150 trials = 1500 environments per suite. For both SFT and RL-trained models we use do_sample = True and temperature_train = 1.6 in rollout.sampling_params, with reset_gripper_open = True.
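The evaluation grid for one suite is the Cartesian product of task IDs and trial IDs, which gives the 1500 environments (10 tasks × 150 trials). The sketch assumes zero-based IDs. How RLinf assigns IDs to environments internally is not shown here.

```python
from itertools import product

NUM_TASKS = 10    # tasks per LIBERO suite
NUM_TRIALS = 150  # trial_id values per task

eval_grid = list(product(range(NUM_TASKS), range(NUM_TRIALS)))  # (task_id, trial_id) pairs
print(len(eval_grid))  # 1500
```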

Note

Wan training and inference are built on DiffSynth-Studio. The results below use a frozen world model to serve RL training of the VLA policy, with no co-evolution between the world model and the VLA. Implementing co-evolution yourself may yield further gains.

Evaluation results on LIBERO suites with the Wan simulator

Model                                            | Spatial | Object | Goal
OpenVLA-OFT (LoRA-base)                          | 61.2%   | 36.7%  | 48.2%
OpenVLA-OFT (RLinf-GRPO with Wan as world model) | 77.5%   | 77.9%  | 60.1%
Improvement (percentage points)                  | +16.3   | +41.2  | +11.9
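The Improvement row is the absolute difference between the two success rates, in percentage points, not a relative gain. The snippet below recomputes it from the table values and, for comparison, also computes the relative gain over the base model. The numbers are copied from the table; nothing new is measured.

```python
# Success rates (%) copied from the table above: (LoRA-base, RLinf-GRPO with Wan).
results = {
    "Spatial": (61.2, 77.5),
    "Object": (36.7, 77.9),
    "Goal": (48.2, 60.1),
}

# Absolute gain in percentage points, rounded to match the table.
gain_pp = {suite: round(rl - base, 1) for suite, (base, rl) in results.items()}
# Relative gain over the base model, for comparison.
gain_rel = {suite: round(100 * (rl - base) / base, 1) for suite, (base, rl) in results.items()}

print(gain_pp)  # {'Spatial': 16.3, 'Object': 41.2, 'Goal': 11.9}
print(gain_rel)
```

On Object, the RL-trained policy more than doubles the base success rate, a relative gain above 100%. This is easy to misread if the +41.2 figure is taken as a relative percentage.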