基于 Wan 世界模型的强化学习#

本文档提供在 RLinf 框架中启动与管理 Vision-Language-Action Models (VLAs) 训练任务的完整指南, 使用 Action-conditioned Wan 世界模型 (下文简称 Wan)作为环境后端。

核心目标是在无需真实机器人或传统物理仿真器的情况下,通过视觉生成模型模拟环境随动作的动态变化, 为策略优化提供闭环训练。

与在 LIBERO 环境中微调 VLA 的流程类似,本指南重点介绍如何在基于 Wan 的仿真环境中运行强化学习训练任务, 并展示该框架下模型具备的关键能力。

Wan 主要希望赋予模型以下能力:

  1. 视觉理解:Wan 借助当前观测图像与给定动作序列生成未来视频帧,为策略提供连续视觉反馈,使模型能够处理来自真实机器人相机的 RGB 图像。

  2. 语言理解:理解自然语言任务描述。

  3. 动作生成:产生精确的机器人动作(位置、旋转、夹爪控制)。

  4. 策略提升:借助 Wan 生成的“想象”轨迹,使用 PPO 等强化学习方法优化 VLA 策略。

环境#

作为世界模型,Wan 理论上可以拟合任意环境的任意任务并保持接口一致。以 LIBERO 环境 为例,环境接口与定义如下:

Wan 模拟 LIBERO 环境

  • Environment:视觉生成模型

  • Task:指挥一台 7 自由度机械臂完成多种家居操作技能(抓取放置、叠放、开抽屉、空间重排等)

  • Observation:视觉生成模型返回的图像

  • Action Space:7 维连续动作 - 末端执行器三维位置控制(x, y, z) - 三维旋转控制(roll, pitch, yaw) - 夹爪控制(开 / 合)

Wan 模拟 LIBERO 环境重置

不同于传统仿真器可通过 reset() 直接重置,Wan 需要接收初始帧与任务描述进行初始化与重置。 因此需提前下载初始化数据集并在配置中指定路径。

数据结构

  • Images:RGB 张量 [batch_size, 256, 256, 3]

  • Task Descriptions:自然语言指令

  • Actions:归一化连续值,转换为离散 tokens

  • Rewards:由世界模型中的奖励判定器给出,范围为 0 到 1

算法#

核心算法组件

  1. PPO(Proximal Policy Optimization)

    • 使用 GAE(Generalized Advantage Estimation)进行优势估计

    • 基于比率的策略裁剪

    • 价值函数裁剪

    • 熵正则化

  2. GRPO(Group Relative Policy Optimization)

    • 对于每个状态 / 提示,策略生成 G 个独立动作

    • 以组内平均奖励为基线,计算每个动作的相对优势

  3. Vision-Language-Action 模型

    • OpenVLA 架构,多模态融合

    • 动作 token 化与反 token 化

    • 带 Value Head 的 Critic 功能

依赖安装#

1. 克隆 RLinf 仓库#

# 为提高国内下载速度,可以使用:
# git clone https://ghfast.top/github.com/RLinf/RLinf.git
git clone https://github.com/RLinf/RLinf.git
cd RLinf

2. 安装依赖#

选项 1:Docker 镜像

使用 Docker 镜像运行实验。

docker run -it --rm --gpus all \
   --shm-size 20g \
   --network host \
   --name rlinf \
   -v .:/workspace/RLinf \
   rlinf/rlinf:agentic-rlinf0.2-wan
   # 如果需要国内加速下载镜像,可以使用:
   # docker.1ms.run/rlinf/rlinf:agentic-rlinf0.2-wan

选项 2:自定义环境

直接在本地环境中安装依赖:

# 为提高国内依赖安装速度,可在 install.sh 中添加 --use-mirror
bash requirements/install.sh embodied --model openvla-oft --env wan
source .venv/bin/activate

VLA 模型下载#

在开始训练之前,需要下载相应预训练模型:

# 使用下面任一方法下载模型
# 方法 1:使用 git clone
git lfs install
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-spatial-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-object-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero-goal-traj1
git clone https://huggingface.co/Haozhan72/Openvla-oft-SFT-libero10-traj1

# 方法 2:使用 huggingface-hub
# 为提升国内下载速度,可以设置:
# export HF_ENDPOINT=https://hf-mirror.com
pip install huggingface-hub
hf download Haozhan72/Openvla-oft-SFT-libero-spatial-traj1 --local-dir Openvla-oft-SFT-libero-spatial-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-object-traj1 --local-dir Openvla-oft-SFT-libero-object-traj1
hf download Haozhan72/Openvla-oft-SFT-libero-goal-traj1 --local-dir Openvla-oft-SFT-libero-goal-traj1
hf download Haozhan72/Openvla-oft-SFT-libero10-traj1 --local-dir Openvla-oft-SFT-libero10-traj1

下载完成后,请确保在配置 yaml 文件中正确指定模型路径与 unnorm_key。

rollout:
   model:
      model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
actor:
   model:
      model_path: Pathto/RLinf/RLinf-OpenVLAOFT-LIBERO-90-Base-Lora
      unnorm_key: libero_90_no_noops_trajall # 对于 RLinf-OpenVLAOFT-LIBERO-130-Base-Lora 模型,使用 libero_130_no_noops_trajall

WM (World Model) 模型下载#

除 VLA 模型之外,还需下载 Wan 权重与用于仿真初始化的数据集。 当前 RLinf 仅提供 libero-spatial、libero-object 和 libero-goal 三个 suite 的权重与数据。各 suite 的 Wan 权重均基于 VLA 模型 rollout 的 1500 条轨迹构建,下载方式如下:

# 下载权重与初始化数据
# 方法 1:使用 git clone
git lfs install
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Spatial
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Object
git clone https://huggingface.co/RLinf/RLinf-Wan-LIBERO-Goal

# 方法 2:使用 huggingface-hub
pip install huggingface-hub
hf download RLinf/RLinf-Wan-LIBERO-Spatial --local-dir RLinf-Wan-LIBERO-Spatial
hf download RLinf/RLinf-Wan-LIBERO-Object --local-dir RLinf-Wan-LIBERO-Object
hf download RLinf/RLinf-Wan-LIBERO-Goal --local-dir RLinf-Wan-LIBERO-Goal

RLinf-Wan-LIBERO-Spatial 的目录结构如下:

RLinf-Wan-LIBERO-Spatial/
    ├── dataset/                            # 用于仿真初始化数据集
    │   ├── traj0.npy                       # 仅包含初始帧的轨迹
    │   ├── traj1.npy
    │   ├── ...
    │   └── trajN.npy
    │   ├── traj0_kir.npy                   # 包含关键帧之前的轨迹
    │   ├── traj1_kir.npy
    │   ├── ...
    │   └── trajN_kir.npy
    ├── model-00001.safetensors              # 世界模型权重文件
    ├── resnet_rm.pth                        # 奖励模型权重文件
    └── Wan2.2_VAE.pth                       # VAE 模型权重文件

下载完成后,请确保在配置 yaml 文件中正确指定模型路径。

env:
    train:
        wan_wm_hf_ckpt_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/

运行脚本#

请确保在运行下面命令前已激活正确的 Python 虚拟环境(venv)。 如果使用官方 Docker 镜像,请通过 source switch_env openvla-oft 切换到 openvla-oft 环境。

1. 关键参数配置

以 OpenVLA-OFT 模型为例,在 actor.model 中需要配置以下关键参数:

actor:
  model:
    model_path: "/path/to/model/Openvla-oft-SFT-libero-spatial-traj1/"    # SFT 模型路径
    model_type: "openvla_oft"                                             # 模型类型设置为 openvla_oft
    use_proprio: False                                                    # 是否使用本体感觉信息
    num_images_in_input: 1                                                # 输入图像数量
    num_action_chunks: 8                                                  # 动作块数量
    unnorm_key: "libero_spatial_no_noops"                                 # 动作归一化键(与 SFT 一致)。RLinf-OpenVLAOFT-LIBERO-130-Base-Lora 使用 libero_130_no_noops_trajall;RLinf-OpenVLAOFT-LIBERO-90-Base-Lora 使用 libero_90_no_noops_trajall。

需要注意的是,world model 不提供本体信息、不生成腕部视角且 chunk 固定, 因此 use_proprio 默认 False,num_images_in_input 默认 1,num_action_chunks 默认 8。

2. 环境配置

在环境配置文件中设置以下关键参数:

# 在 CHOSEN_CONFIG 中覆写

# 推荐训练使用 wan_libero_spatial,评估使用 libero_spatial
env/train: wan_libero_spatial
env/eval: libero_spatial

# 在 env/train/wan_libero_spatial.yaml 中:
simulator_type: libero
task_suite_name: libero_spatial
# 是否使用 KeyFrame-Init Rollout
enable_kir: True
# world model 初始化的初始图像路径
initial_image_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/dataset
# VAE权重
VAE_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/Wan2.2_VAE.pth
# 预训练的世界模型权重
model_path: /Pathto/model/RLinf-Wan-LIBERO-Spatial/model-00001.safetensors
# 奖励模型权重
reward_model:
  type: ResnetRewModel
  from_pretrained: /Pathto/model/RLinf-Wan-LIBERO-Spatial/resnet_rm.pth

环境配置中的关键参数说明:

  • enable_kir:是否启用关键帧初始化 KIR (KeyFrame-Init) ,如果关闭,环境将会从 dataset/ 中名字中不含 kir 的 npy 文件进行初始化,如果开启,环境将会从 dataset/ 中的所有初始化文件中进行等可能的初始化

  • reward_model.type:奖励模型类型,支持多种选择,包括 ResnetRewModel``和 ``TaskEmbedResnetRewModel 等。

3. 配置文件

目前支持 OpenVLA-OFT 模型与 GRPO 算法,对应配置文件:

  • OpenVLA-OFT + GRPOexamples/embodiment/config/wan_libero_spatial_grpo_openvlaoft.yaml

4. 启动命令

选择配置后,运行以下命令开始训练:

bash examples/embodiment/run_embodiment.sh CHOSEN_CONFIG

例如,在 Wan 环境中使用 GRPO 训练 OpenVLA-OFT 模型:

bash examples/embodiment/run_embodiment.sh wan_libero_spatial_grpo_openvlaoft

可视化与结果#

1. TensorBoard 日志

# 启动 TensorBoard
tensorboard --logdir ./logs --port 6006

2. 关键监控指标

  • 训练指标

    • train/actor/approx_kl:近似 KL,用于监控策略更新幅度

    • train/actor/clip_fraction:触发 PPO 裁剪的样本比例

    • train/actor/clipped_ratio:裁剪后概率比的均值,用于衡量策略更新受裁剪影响程度

    • train/actor/grad_norm:梯度范数

    • train/actor/lr:学习率

    • train/actor/policy_loss:PPO/GRPO 的策略损失

    • train/critic/value_loss:价值函数损失

    • train/critic/value_clip_ratio:PPO-style value function clipping 中触发裁剪的比例

    • train/critic/explained_variance:衡量价值函数拟合程度,越接近 1 越好

    • train/entropy_loss:策略熵

    • train/loss:总训练损失(actor_loss + critic_loss + entropy_loss regularization)

  • Rollout 指标

    • rollout/advantages_max:优势函数最大值

    • rollout/advantages_mean:优势函数均值

    • rollout/advantages_min:优势函数最小值

    • rollout/rewards:一个 chunk 的奖励(参考 libero_env.py 的 L414)

  • 环境指标

    • env/episode_len:回合实际经历的环境步数(单位:step)

    • env/return:回合总回报。在 LIBERO 的稀疏奖励设置中该指标不具参考意义,因为回合中几乎始终为 0,仅在成功终止时为 1。

    • env/reward:step-level 奖励(任务未完成时为 0,仅成功终止时为 1)。 日志数值按回合步数归一化,难以直接反映真实任务表现。

    • env/success_once:推荐用于监控训练效果,直接反映未归一化的任务成功率。

3. 视频生成

env:
   eval:
      video_cfg:
         save_video: True
         video_base_dir: ${runner.logger.log_path}/video/eval

4. 训练日志工具集成

runner:
   task_type: embodied
   logger:
      log_path: "../results"
      project_name: rlinf
      experiment_name: "libero_10_grpo_openvlaoft"
      logger_backends: ["tensorboard"] # wandb, swanlab

LIBERO 部分结果#

目前仅测试使用 Wan 模拟 libero-spatial、libero-object 和 libero-goal 环境并训练 VLA 模型,更多环境仍在测试中。

对于每个 LIBERO 套件,我们评估所有 task_id 与 trial_id 的组合。Spatial、Object 和 Goal 套件共评估 1500 个环境(10 个任务 × 150 个试次)。

我们根据模型的训练配置设置评估超参: 对于 SFT 模型与 RL 训练模型,均设置 do_sample = Truetemperature = 1.6 以评估性能。

备注

我们基于 Diffsynth-Studio 框架进行Wan的训练与推理。 在下面的评测结果中,我们仅使用冻结的世界模型服务于 VLA 模型的强化学习训练,并未使用世界模型与 VLA 的协同进化。用户可通过手动实现协同进化,实现性能的继续增长。

使用 Wan 模拟器的 LIBERO 任务组评测结果#

模型

Spatial

Object

Goal

OpenVLA-OFT (LoRA-base)

61.2%

36.7%

48.2%

OpenVLA-OFT(Wan 作为世界模型的 RLinf-GRPO)

71.5%

77.9%

60.1%

效果提升

+10.3%

+41.2%

+11.9%