Rollout Interface
This section describes the key APIs of the Rollout classes in the RLinf framework, including implementations based on the SGLang and Hugging Face backends.
SGLang
- class rlinf.workers.rollout.sglang.sglang_worker.SGLangWorker
- __init__(config, placement, weight_reload='sync', config_rollout=None)
Initialize the Worker with the given parent address and world size.
Only non-Ray workers should provide parent_address, world_size and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size and rank should be provided.
- Parameters:
parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.
world_size (Optional[int]) -- The total number of workers in the group. If not provided, it is read from the environment variable WORLD_SIZE.
rank (Optional[int]) -- The rank of this worker in the group. If not provided, it is read from the environment variable RANK.
config (DictConfig)
placement (ModelParallelComponentPlacement)
weight_reload (Literal['sync', 'cpu', None])
config_rollout (DictConfig)
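The WORLD_SIZE / RANK fallback described for world_size and rank can be sketched as follows. resolve_world_size_and_rank and its env parameter are hypothetical names used only for illustration; RLinf's actual Worker may resolve these values differently (for example, with different handling when a variable is missing).

```python
import os
from typing import Mapping, Optional, Tuple


def resolve_world_size_and_rank(
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
    env: Mapping[str, str] = os.environ,
) -> Tuple[int, int]:
    """Hypothetical helper: prefer explicit values, else read WORLD_SIZE / RANK."""
    if world_size is None:
        world_size = int(env["WORLD_SIZE"])  # raises KeyError if unset
    if rank is None:
        rank = int(env["RANK"])
    return world_size, rank


# Explicit arguments win (e.g. a Worker spawned via multiprocessing).
print(resolve_world_size_and_rank(8, 3, env={}))
# Otherwise fall back to the environment (a plain dict stands in here).
print(resolve_world_size_and_rank(env={"WORLD_SIZE": "4", "RANK": "1"}))
```

Passing the environment in as a mapping keeps the sketch testable; a real worker would simply read os.environ.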
- get_sampling_param_from_config()
Get sampling parameters from the configuration.
- Parameters:
cfg_sampling_params (DictConfig)
- Return type:
dict
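A minimal sketch of turning a sampling config into an SGLang-style sampling-parameter dict is shown below. It assumes the config has already been converted to a plain dict (for an OmegaConf DictConfig, OmegaConf.to_container does this) and uses hypothetical config key names. The parameter names themselves (temperature, top_p, top_k, max_new_tokens, repetition_penalty) are standard SGLang sampling parameters, but which ones RLinf actually passes through is an assumption here.

```python
from typing import Any, Dict

# Assumed subset of SGLang sampling parameters to pass through.
SGLANG_KEYS = ("temperature", "top_p", "top_k", "max_new_tokens", "repetition_penalty")


def sampling_params_from_config(cfg_sampling_params: Dict[str, Any]) -> Dict[str, Any]:
    """Copy only the keys SGLang understands, skipping missing or None values."""
    return {
        key: cfg_sampling_params[key]
        for key in SGLANG_KEYS
        if cfg_sampling_params.get(key) is not None
    }


# Hypothetical config: top_k is unset and seed is not a sampling parameter here.
cfg = {"temperature": 1.0, "top_p": 0.95, "top_k": None, "max_new_tokens": 512, "seed": 0}
print(sampling_params_from_config(cfg))
# {'temperature': 1.0, 'top_p': 0.95, 'max_new_tokens': 512}
```

Filtering on a fixed key list keeps unrelated config entries from leaking into the engine call.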
- shutdown()
Shut down the SGLang task.
- async async_generate(prompt=None, sampling_params=None, input_ids=None, image_data=None, return_logprob=False, request_info=None)
Asynchronously generate text using the underlying SGLang engine and return the engine result together with the caller-supplied request_info.
This wrapper calls self._engine.async_generate(...) with the provided arguments (except request_info). Because the SGLang engine's response does not include the original request data such as input_ids, anything needed downstream (for example, input_ids, answers, or idx) can be carried in request_info, which is returned alongside the engine result.
- Parameters:
prompt (List[str] | str | None) -- Same as the SGLang engine's prompt argument.
sampling_params (List[Dict] | Dict | None) -- Same as the SGLang engine's sampling_params argument.
input_ids (List[List[int]] | List[int] | None) -- Same as the SGLang engine's input_ids argument.
return_logprob (List[bool] | bool | None) -- Same as the SGLang engine's return_logprob argument.
request_info (Any | None) -- Any additional information to associate with this generation request. It is not passed to the SGLang engine and is returned unchanged.
image_data (list | None)
- Returns:
A tuple containing the engine result and the original request_info.
- Return type:
Tuple[Dict, Any | None]
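The request_info pattern can be illustrated without a GPU. The sketch below uses a hypothetical FakeEngine in place of the real SGLang engine and a hypothetical RolloutWrapperSketch in place of SGLangWorker; it mirrors only the documented behavior (arguments forwarded, request_info returned unchanged), not RLinf's actual implementation.

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple


class FakeEngine:
    """Hypothetical stand-in for an SGLang engine (returns a toy dict)."""

    async def async_generate(self, prompt=None, sampling_params=None,
                             input_ids=None, image_data=None,
                             return_logprob=False) -> Dict:
        await asyncio.sleep(0)  # yield control, as real async I/O would
        return {"text": f"echo: {prompt}", "meta_info": {}}


class RolloutWrapperSketch:
    """Forwards generation arguments and hands request_info back untouched."""

    def __init__(self, engine: FakeEngine) -> None:
        self._engine = engine

    async def async_generate(self, prompt=None, sampling_params=None,
                             input_ids=None, image_data=None,
                             return_logprob=False,
                             request_info: Optional[Any] = None
                             ) -> Tuple[Dict, Optional[Any]]:
        # request_info is deliberately NOT forwarded to the engine.
        result = await self._engine.async_generate(
            prompt=prompt, sampling_params=sampling_params,
            input_ids=input_ids, image_data=image_data,
            return_logprob=return_logprob,
        )
        return result, request_info


async def main() -> List[Tuple[int, str, str]]:
    worker = RolloutWrapperSketch(FakeEngine())
    batch = [("1+1=?", "2"), ("2+2=?", "4")]
    tasks = [
        asyncio.create_task(worker.async_generate(
            prompt=prompt,
            sampling_params={"max_new_tokens": 8},
            request_info={"idx": idx, "answer": answer},
        ))
        for idx, (prompt, answer) in enumerate(batch)
    ]
    collected = []
    # as_completed yields in completion order, so request_info is what
    # ties each result back to its prompt index and reference answer.
    for next_done in asyncio.as_completed(tasks):
        result, info = await next_done
        collected.append((info["idx"], info["answer"], result["text"]))
    return sorted(collected)


results = asyncio.run(main())
print(results)
```

Carrying bookkeeping in request_info keeps it out of the engine call while still letting results that finish out of order be matched to their inputs.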
- async offload_engine()
Release the model weights from the SGLang engine.
- async onload_engine()
Load the model weights from CPU back into the SGLang engine.
- async abort_generation()
Abort the ongoing generation.
- async sync_model_from_actor()
Update the weights of the SGLang engine from the actor.
Hugging Face
- class rlinf.workers.rollout.hf.huggingface_worker.MultiStepRolloutWorker
- __init__(cfg)
Initialize the Worker with the given parent address and world size.
Only non-Ray workers should provide parent_address, world_size and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size and rank should be provided.
- Parameters:
parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.
world_size (Optional[int]) -- The total number of workers in the group. If not provided, it is read from the environment variable WORLD_SIZE.
rank (Optional[int]) -- The rank of this worker in the group. If not provided, it is read from the environment variable RANK.
cfg (DictConfig)
- async sync_model_from_actor()
Sync model parameters from the actor worker.
- async recv_env_output(input_channel, mode='train')
Receive env outputs from the mapped env ranks and merge them if needed.
- Parameters:
input_channel (Channel) -- Channel carrying env->rollout outputs.
mode (Literal['train', 'eval']) -- Rollout mode, either "train" or "eval".
- Returns:
A single env output dict. When multiple env ranks are mapped to this rollout worker, their outputs are merged along the batch dimension.
- Return type:
dict[str, Any]
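Merging "along the batch dimension" can be pictured as concatenating each field of the per-rank dicts along axis 0. The sketch below uses NumPy arrays and hypothetical keys (obs, dones); RLinf's env outputs may hold torch tensors (where torch.cat plays the same role) and other field types, so the fallback branch in particular is an assumption.

```python
from typing import Any, Dict, List

import numpy as np


def merge_env_outputs(outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-rank env output dicts along the batch (first) dimension."""
    if len(outputs) == 1:
        return outputs[0]  # a single mapped env rank needs no merging
    merged: Dict[str, Any] = {}
    for key in outputs[0]:
        values = [out[key] for out in outputs]
        if isinstance(values[0], np.ndarray):
            merged[key] = np.concatenate(values, axis=0)
        elif isinstance(values[0], list):
            merged[key] = [item for value in values for item in value]
        else:
            # Assumed fallback for non-batched fields: keep one entry per rank.
            merged[key] = values
    return merged


rank0 = {"obs": np.zeros((2, 3)), "dones": np.array([False, False])}
rank1 = {"obs": np.ones((3, 3)), "dones": np.array([False, True, False])}
merged = merge_env_outputs([rank0, rank1])
print(merged["obs"].shape)        # (5, 3)
print(merged["dones"].tolist())   # [False, False, False, True, False]
```

Concatenating in rank order means rows 0-1 come from the first env rank and rows 2-4 from the second, which is the ordering a later split back to env ranks would need to respect.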
- send_chunk_actions(output_channel, chunk_actions, mode='train')
Send action shards to the mapped env ranks.
- Parameters:
output_channel (Channel) -- Channel carrying rollout->env action chunks.
chunk_actions (Tensor | ndarray) -- Batch of predicted action chunks (tensor or ndarray).
mode (Literal['train', 'eval']) -- Rollout mode, either "train" or "eval".
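Sending action shards can be thought of as cutting the predicted batch along its first dimension so that each mapped env rank receives its own rows. This sketch assumes the per-rank batch sizes are known and that rows are ordered by rank; split_chunk_actions, rank_batch_sizes, and the (batch, chunk_len, action_dim) shape are hypothetical.

```python
from typing import List, Sequence

import numpy as np


def split_chunk_actions(chunk_actions: np.ndarray,
                        rank_batch_sizes: Sequence[int]) -> List[np.ndarray]:
    """Split a batch of action chunks into contiguous per-rank shards."""
    if sum(rank_batch_sizes) != chunk_actions.shape[0]:
        raise ValueError("rank batch sizes must add up to the batch size")
    # Split points are the running totals, excluding the final one.
    boundaries = np.cumsum(rank_batch_sizes)[:-1]
    return np.split(chunk_actions, boundaries, axis=0)


# Hypothetical batch: 5 envs, chunks of 2 steps, 7-dim actions.
actions = np.arange(5 * 2 * 7, dtype=np.float32).reshape(5, 2, 7)
shards = split_chunk_actions(actions, rank_batch_sizes=[2, 3])
print([s.shape for s in shards])  # [(2, 2, 7), (3, 2, 7)]
```

Using explicit per-rank sizes rather than an equal split keeps the shards aligned with env ranks that contribute different batch sizes; concatenating the shards in rank order restores the original batch.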