Rollout Interface
This section describes the key APIs of the Rollout classes in the RLinf framework. It covers the implementations based on the SGLang and Hugging Face backends.
SGLang
- class rlinf.workers.rollout.sglang.sglang_worker.SGLangWorker
- __init__(config, placement)
Initialize the Worker with the given parent address and world size.
Only non-Ray workers should provide parent_address, world_size, and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size, and rank should be provided.
- Parameters:
parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.
world_size (Optional[int]) -- The total number of workers in the group. If not provided, it is set to the value of the environment variable WORLD_SIZE.
rank (Optional[int]) -- The rank of this worker in the group. If not provided, it is set to the value of the environment variable RANK.
config (DictConfig)
placement (ModelParallelComponentPlacement)
- get_sampling_param_from_config()
Get sampling parameters from the configuration.
- Parameters:
cfg (DictConfig)
- Return type:
dict
- shutdown()
Shut down the SGLang task.
- async async_generate(prompt=None, sampling_params=None, input_ids=None, image_data=None, return_logprob=False, request_info=None)
Asynchronously generate text using the underlying SGLang engine and return the engine result together with the original input_ids, answers, and idx.
This wrapper calls self._engine.async_generate(...) and forwards the provided arguments. Because the SGLang engine does not include the original input_ids in its response, this method returns the input_ids alongside the engine result for downstream use.
- Parameters:
prompt (List[str] | str | None) -- Same as SGLang engine's prompt argument.
sampling_params (List[Dict] | Dict | None) -- Same as SGLang engine's sampling_params argument.
input_ids (List[List[int]] | List[int] | None) -- Same as SGLang engine's input_ids argument.
return_logprob (List[bool] | bool | None) -- Same as SGLang engine's return_logprob argument.
request_info (Any | None) -- Any additional request information to associate with this generation request. This argument is not passed to the SGLang engine; it is returned unchanged.
image_data (list | None)
- Returns:
A tuple containing the engine result and the original request_info.
- Return type:
Tuple[Dict, Any | None]
- async offload_engine()
Offload the model weights from the SGLang engine.
- async abort_generation()
Abort the generation.
- async sync_model_from_actor()
Update the weights of the SGLang engine.
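The pass-through behavior that async_generate describes for request_info can be illustrated with a small, self-contained sketch. It does not use RLinf or SGLang: FakeEngine and RolloutWrapper are hypothetical stand-ins, and the result keys ("text", "num_input_tokens") are made up for the illustration. Only the shape of the call (forward the generation arguments to the engine, return the engine result paired with request_info) follows the description above.

```python
import asyncio


class FakeEngine:
    """Hypothetical stand-in for an inference engine (not the SGLang API)."""

    async def async_generate(self, prompt=None, sampling_params=None,
                             input_ids=None, image_data=None,
                             return_logprob=False):
        # Yield control once, as a real asynchronous engine call would.
        await asyncio.sleep(0)
        # Made-up result layout, for illustration only.
        return {"text": f"echo:{prompt}",
                "num_input_tokens": len(input_ids or [])}


class RolloutWrapper:
    """Hypothetical wrapper mirroring the documented call shape."""

    def __init__(self, engine):
        self._engine = engine

    async def async_generate(self, prompt=None, sampling_params=None,
                             input_ids=None, image_data=None,
                             return_logprob=False, request_info=None):
        # request_info is deliberately NOT forwarded to the engine.
        result = await self._engine.async_generate(
            prompt=prompt,
            sampling_params=sampling_params,
            input_ids=input_ids,
            image_data=image_data,
            return_logprob=return_logprob,
        )
        # Hand request_info back unchanged so the caller can match
        # the result to the request that produced it.
        return result, request_info


async def main():
    wrapper = RolloutWrapper(FakeEngine())
    return await wrapper.async_generate(
        prompt="hi",
        sampling_params={"temperature": 0.0},
        input_ids=[1, 2, 3],
        request_info={"idx": 7},
    )


result, info = asyncio.run(main())
```

Returning request_info next to the result lets a caller that issues many concurrent requests attach its own bookkeeping (for example an index or a reference answer) without relying on the engine to echo it back.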
Hugging Face
- class rlinf.workers.rollout.hf.huggingface_worker.MultiStepRolloutWorker
- __init__(cfg)
Initialize the Worker with the given parent address and world size.
Only non-Ray workers should provide parent_address, world_size, and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size, and rank should be provided.
- Parameters:
parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.
world_size (Optional[int]) -- The total number of workers in the group. If not provided, it is set to the value of the environment variable WORLD_SIZE.
rank (Optional[int]) -- The rank of this worker in the group. If not provided, it is set to the value of the environment variable RANK.
cfg (DictConfig)
- get_dones_and_rewards(env_output)
Get dones and rewards from the environment batch, handling auto_reset if needed.
- Parameters:
env_output (dict[str, Tensor]) -- Environment batch containing dones, rewards, and optionally final_obs.
- Returns:
Tuple of (dones, rewards) tensors.
- Return type:
tuple[Tensor | None, Tensor | None]
- sync_model_from_actor()
Sync model parameters from the actor worker.