Search-R1的强化学习训练#

结合工具调用的Multi-turn RL被证明能够将大语言模型(LLM)的交互边界扩展到真实世界。本文档介绍了如何在 RLinf 框架下复现论文Search-R1: Training LLMs to Reason and Leverage Search Engines with Reinforcement Learning中的实验,使用强化学习(RL)来训练大语言模型(LLM)通过调用搜索工具回答问题。

环境#

RLinf环境#

RLinf 环境配置参照 RLinf Installation

Local Wiki Server运行环境#

我们使用search-R1示例中的local retrieve server,通过conda安装faiss,详细文档见SearchR1,安装过程参考Search-R1 & veRL-SGLang,同样使用conda来配置环境

conda create -n retriever python=3.10 -y
conda activate retriever

conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y
pip install transformers datasets pyserini huggingface_hub

#  安装 GPU 版 faiss
conda install faiss-gpu=1.8.0 -c pytorch -c nvidia -y

pip install uvicorn fastapi

Wiki配置文件#

我们使用Asearcher提供的本地检索文件,下载文件大约 50~60GB

conda activate retriever

save_path=/the/path/to/save
python examples/agent/searchr1/download.py --save_path $save_path

从huggingface上下载e5-base-v2 embedding模型,并生成index

bash examples/agent/tools/search_local_server_faiss/build_index.sh

将之前下载好的wiki文件路径和index路径等写入examples/agent/searchr1/launch_local_server.sh

#!/bin/bash

set -ex

WIKI2018_WORK_DIR=$save_path

index_file=$WIKI2018_WORK_DIR/e5.index/e5_Flat.index
corpus_file=$WIKI2018_WORK_DIR/wiki_corpus.jsonl
pages_file=$WIKI2018_WORK_DIR/wiki_webpages.jsonl
retriever_name=e5
retriever_path=path/to/intfloat/e5-base-v2

python3  ./local_retrieval_server.py --index_path $index_file \
                                            --corpus_path $corpus_file \
                                            --pages_path $pages_file \
                                            --topk 3 \
                                            --retriever_name $retriever_name \
                                            --retriever_model $retriever_path \
                                            --faiss_gpu --port 8000

运行launch_local_server.sh启动Local Wiki Server,等待直至输出server ip等信息,代表server启动完成

(Optional) 使用Qdrant作为Wiki Server#

我们也支持使用 qdrant 作为 wiki 服务器。如果你不打算使用 qdrant,可以直接跳到 训练 部分。

使用上一部分中提到的方式准备 Asearcher 提供的本地 wiki corpus 检索文件。

从 huggingface 上下载e5-base-v2 embedding 模型。

下载 qdrant 并按照以下步骤构建 qdrant collection。首先,创建一个文件夹并把下载好的 qdrant 二进制文件放入该文件夹中,方便后续存储 qdrant 程序及其构建的 collection 文件。

examples/agent/tools/search_local_server_qdrant/build_index_qdrant.shexamples/agent/tools/search_local_server_qdrant/launch_local_server_qdrant.sh 中,根据之前下载的 wiki corpus, e5-base-v2 和 qdrant 路径更新 WIKI2018_DIRretriever_pathqdrant_path 的文件路径。

使用以下指令构建 qdrant wiki 服务器的 collection:

# 创建 qdrant 存放的文件夹
mkdir -p /path/to/qdrant
# 拷贝二进制执行文件
cp qdrant /path/to/qdrant

# 启动 qdrant server
/path/to/qdrant/qdrant &

# 构建 qdrant collection
bash examples/agent/tools/search_local_server_qdrant/build_index_qdrant.sh

运行 launch_local_server_qdrant.sh 启动 Local Qdrant Wiki Server ,等待直至输出 server ip 等信息,代表 server 启动完成

# 启动 qdrant server
/path/to/qdrant/qdrant &

# 启动基于 qdrant 的 wiki server
bash examples/agent/tools/search_local_server_qdrant/launch_local_server_qdrant.sh

Qdrant 默认使用 HNSW 图索引算法。关于 HNSW 图索引的优化,请参考 Qdrant 文档

在8*H100上训练#

从huggingface上下载训练集 ,并将路径写入 examples/agent/searchr1/config/train_qwen2.5.yaml:

data:
  ……
  train_data_paths: ["/path/to/train.jsonl"]

修改 train_qwen2.5.yamlrollout.model.model_path 的路径

rollout:
  group_name: "RolloutGroup"

  gpu_memory_utilization: 0.8
  model:
    model_path: /path/to/model/Qwen2.5-3B-Instruct
    model_type: qwen2.5

如果使用 sampling_params.stop 来控制模型停止节省训练时间,detokenize应当设置为True

rollout:
   ……
   distributed_executor_backend: mp   # ray or mp
   disable_log_stats: False
   detokenize: True

由于 Search-R1 会re-tokenize模型输出, recompute_logprobs 应当设置为True

algorithm:
   ……
   recompute_logprobs: True
   shuffle_rollout: False

运行 bash examples/agent/searchr1/run_train.sh 启动训练。

测试#

运行以下命令将 Megatron checkpoint 转换为 HuggingFace model

CKPT_PATH_MG={your_output_dir}/{exp_name}/checkpoints/global_step_xxx/actor
CKPT_PATH_HF={path/to/save/huggingface/model}
CKPT_PATH_ORIGINAL_HF={path/to/model/Qwen2.5-3B-Instruct}
CKPT_PATH_MF="${CKPT_PATH_HF}_middle_file"

python -m rlinf.utils.ckpt_convertor.megatron_convertor.convert_mg_to_middle_file \
    --load-path "${CKPT_PATH_MG}" \
    --save-path "${CKPT_PATH_MF}" \
    --model qwen_2.5_3b \
    --tp-size 1 --ep-size 1 --pp-size 1 \
    --te-ln-linear-qkv true --te-ln-linear-mlp_fc1 true \
    --te-extra-state-check-none true --use-gpu-num 0 --process-num 16

python -m rlinf.utils.ckpt_convertor.megatron_convertor.convert_middle_file_to_hf \
    --load-path "${CKPT_PATH_MF}" \
    --save-path "${CKPT_PATH_HF}" \
    --model qwen_2.5_3b \
    --use-gpu-num 0 --process-num 16

rm -rf "${CKPT_PATH_MF}"
rm -f "${CKPT_PATH_HF}"/*.done
shopt -s extglob
cp "${CKPT_PATH_ORIGINAL_HF}"/!(*model.safetensors.index.json) "${CKPT_PATH_HF}"

将转换得到的huggingface model路径填入 examples/agent/searchr1/config/eval_qwen2.5.yaml

rollout:
  group_name: "RolloutGroup"

  gpu_memory_utilization: 0.8
  model:
    model_path: /path/to/eval/model
    model_type: qwen2.5

修改测试数据集路径

data:
  ……
  val_data_paths: ["/path/to/eval.jsonl"]

运行 bash examples/agent/searchr1/run_eval.sh 启动测试。

训练曲线#

下面展示 reward 曲线和训练时间曲线。

Qwen2.5-3B-Instruct in RLinf

相较于原版性能( response length 稳定后,单 step 133s),我们加速了 55%,同时 reward 曲线和 eval 结果保持一致。

Qwen2.5-3B-Instruct in original implementation at PeterGriffinJin/Search-R1

References#

search-r1: PeterGriffinJin/Search-R1

Search-R1 & veRL-SGLang: zhaochenyang20/Awesome-ML-SYS-Tutorial

Asearcher: inclusionAI/ASearcher