Rollout Interface

This section describes the key APIs of the Rollout classes in the RLinf framework, including implementations based on the SGLang and Hugging Face backends.

SGLang

class rlinf.workers.rollout.sglang.sglang_worker.SGLangWorker
__init__(config, placement, weight_reload='sync', config_rollout=None)

Initialize the Worker with the given parent address and world size.

Only non-Ray workers should provide parent_address, world_size and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size and rank should be provided.

Parameters:
  • parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.

  • world_size (Optional[int]) -- The total number of workers in the group. If not provided, it will be set to the environment variable WORLD_SIZE.

  • rank (Optional[int]) -- The rank of this worker in the group. If not provided, it will be set to the environment variable RANK.

  • config (DictConfig)

  • placement (ModelParallelComponentPlacement)

  • weight_reload (Literal['sync', 'cpu', None])

  • config_rollout (DictConfig)

get_sampling_param_from_config()

Get sampling parameters from the configuration.

Parameters:

cfg_sampling_params (DictConfig)

Return type:

dict
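To illustrate the kind of dict get_sampling_param_from_config returns, the sketch below maps a sampling config section onto keys that SGLang accepts in its sampling_params argument (temperature, top_p, top_k, max_new_tokens). It is a minimal sketch under assumptions: the helper name sampling_params_from_config, the config keys, the fallback values, and the use_greedy flag are all hypothetical, and a plain dict stands in for DictConfig. The actual RLinf implementation may read different keys.

```python
from typing import Any, Dict, Mapping

# Hypothetical fallback values, used only when a key is missing from the config.
_DEFAULTS: Dict[str, Any] = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,  # -1 is commonly used to mean "no top-k filtering"
    "max_new_tokens": 128,
}


def sampling_params_from_config(cfg_sampling: Mapping[str, Any]) -> Dict[str, Any]:
    """Build an SGLang-style sampling_params dict from a (hypothetical) config section."""
    params = {key: cfg_sampling.get(key, default) for key, default in _DEFAULTS.items()}
    # Hypothetical flag: express greedy decoding as temperature 0.
    if cfg_sampling.get("use_greedy", False):
        params["temperature"] = 0.0
    return params


if __name__ == "__main__":
    cfg = {"temperature": 0.7, "max_new_tokens": 512}
    print(sampling_params_from_config(cfg))
```

Reading every key through a single defaults table keeps the set of forwarded parameters explicit, so a typo in the config cannot silently introduce an unknown key into sampling_params.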

shutdown()

Shut down the SGLang task.

async async_generate(prompt=None, sampling_params=None, input_ids=None, image_data=None, return_logprob=False, request_info=None)

Asynchronously generate text with the underlying SGLang engine and return the engine result together with the caller-supplied request_info.

This wrapper calls self._engine.async_generate(...) and forwards the provided arguments. Because the SGLang engine's response does not include the original request context (such as the input_ids), callers can place that context in request_info, which this method returns alongside the engine result for downstream use.

Parameters:
  • prompt (List[str] | str | None) -- Same as SGLang engine's prompt argument.

  • sampling_params (List[Dict] | Dict | None) -- Same as SGLang engine's sampling_params argument.

  • input_ids (List[List[int]] | List[int] | None) -- Same as SGLang engine's input_ids argument.

  • return_logprob (List[bool] | bool | None) -- Same as SGLang engine's return_logprob argument.

  • request_info (Any | None) -- Any additional information you wish to associate with this generation request. This argument is not passed to the SGLang engine; it is returned directly.

  • image_data (list | None)

Returns:

A tuple containing the engine result and the original request_info.

Return type:

Tuple[Dict, Any | None]
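The request_info pass-through of async_generate can be sketched as follows. This is a self-contained illustration, not RLinf code: FakeEngine is a hypothetical stand-in for the SGLang engine that returns a canned result, and SketchWorker only mirrors the wrapper's argument forwarding. It shows that request_info never reaches the engine and comes back unchanged next to the engine result, so concurrent requests can be matched to their inputs afterwards.

```python
import asyncio
from typing import Any, Dict, Optional, Tuple


class FakeEngine:
    """Hypothetical stand-in for the SGLang engine; returns a canned result."""

    async def async_generate(self, input_ids=None, **kwargs) -> Dict[str, Any]:
        await asyncio.sleep(0)  # yield control, as a real async engine call would
        return {"text": f"<{len(input_ids or [])} tokens>", "meta_info": {}}


class SketchWorker:
    """Mirrors only the argument forwarding of SGLangWorker.async_generate."""

    def __init__(self, engine: FakeEngine) -> None:
        self._engine = engine

    async def async_generate(self, prompt=None, sampling_params=None, input_ids=None,
                             image_data=None, return_logprob=False,
                             request_info: Optional[Any] = None) -> Tuple[Dict, Optional[Any]]:
        # Everything except request_info is forwarded to the engine.
        result = await self._engine.async_generate(
            prompt=prompt, sampling_params=sampling_params, input_ids=input_ids,
            image_data=image_data, return_logprob=return_logprob)
        # request_info is returned untouched so the caller can recover its context.
        return result, request_info


async def main():
    worker = SketchWorker(FakeEngine())
    batch = [[1, 2, 3], [4, 5]]
    tasks = [
        worker.async_generate(input_ids=ids, sampling_params={"temperature": 0.0},
                              request_info={"idx": i, "input_ids": ids})
        for i, ids in enumerate(batch)
    ]
    return await asyncio.gather(*tasks)


results = asyncio.run(main())
```

asyncio.gather returns results in submission order, but carrying an index in request_info also works when results are consumed as they finish (for example with asyncio.as_completed).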

async offload_engine()

Release the model weights from the SGLang engine.

async onload_engine()

Load the model weights from CPU memory into the SGLang engine.

async abort_generation()

Abort the generation.

async sync_model_from_actor()

Update the weights of the SGLang engine.

Hugging Face

class rlinf.workers.rollout.hf.huggingface_worker.MultiStepRolloutWorker
__init__(cfg)

Initialize the Worker with the given parent address and world size.

Only non-Ray workers should provide parent_address, world_size and rank. For example, when a Worker is created via multiprocessing by another Worker, the parent address, world size and rank should be provided.

Parameters:
  • parent_address (Optional[WorkerAddress]) -- The address of the parent worker. This is used to set up the WorkerAddress for this worker.

  • world_size (Optional[int]) -- The total number of workers in the group. If not provided, it will be set to the environment variable WORLD_SIZE.

  • rank (Optional[int]) -- The rank of this worker in the group. If not provided, it will be set to the environment variable RANK.

  • cfg (DictConfig)

async sync_model_from_actor()

Sync model parameters from the actor worker.

async recv_env_output(input_channel, mode='train')

Receive env outputs from mapped env ranks and merge if needed.

Parameters:
  • input_channel (Channel) -- Channel carrying env->rollout outputs.

  • mode (Literal['train', 'eval']) -- Rollout mode, either "train" or "eval".

Returns:

A single env output dict. When multiple env ranks are mapped to this rollout worker, their outputs are merged along the batch dimension.

Return type:

dict[str, Any]
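The merge step of recv_env_output can be pictured with plain Python lists. The sketch assumes, as a simplification, that every mapped env rank produces a dict with the same keys and list-valued fields whose first axis is the batch. merge_env_outputs is a hypothetical helper, and receiving from input_channel is omitted because the Channel API is not described here. With torch tensors, torch.cat(values, dim=0) would do the per-field concatenation.

```python
from typing import Any, Dict, List


def merge_env_outputs(outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-env-rank outputs by concatenating each field along the batch axis.

    Assumes all ranks return the same keys and that each value is a list whose
    first axis is the batch dimension.
    """
    if not outputs:
        raise ValueError("expected at least one env output")
    if len(outputs) == 1:
        # A single mapped env rank needs no merging.
        return outputs[0]
    merged: Dict[str, Any] = {}
    for key in outputs[0]:
        # Concatenate rank by rank, preserving rank order within the batch.
        merged[key] = [item for out in outputs for item in out[key]]
    return merged


if __name__ == "__main__":
    rank0 = {"obs": [[0.0], [1.0]], "reward": [0.1, 0.2]}
    rank1 = {"obs": [[2.0]], "reward": [0.3]}
    print(merge_env_outputs([rank0, rank1]))
```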

send_chunk_actions(output_channel, chunk_actions, mode='train')

Send action shards to mapped env ranks.

Parameters:
  • output_channel (Channel) -- Channel carrying rollout->env action chunks.

  • chunk_actions (Tensor | ndarray) -- Batch of predicted action chunks.

  • mode (Literal['train', 'eval']) -- Rollout mode, either "train" or "eval".
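send_chunk_actions performs roughly the inverse of the merge done by recv_env_output: a batch of action chunks is divided into shards, one per mapped env rank. The sketch below assumes contiguous, near-equal shards in rank order, which is one plausible mapping rather than the documented one. shard_actions is a hypothetical helper and the actual channel send is omitted. For a torch tensor, torch.tensor_split(chunk_actions, num_env_ranks, dim=0) produces the same split sizes.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_actions(chunk_actions: Sequence[T], num_env_ranks: int) -> List[List[T]]:
    """Split a batch into contiguous shards; earlier ranks get one extra item if uneven."""
    if num_env_ranks <= 0:
        raise ValueError("num_env_ranks must be positive")
    base, rem = divmod(len(chunk_actions), num_env_ranks)
    shards: List[List[T]] = []
    start = 0
    for rank in range(num_env_ranks):
        size = base + (1 if rank < rem else 0)
        shards.append(list(chunk_actions[start:start + size]))
        start += size
    return shards


if __name__ == "__main__":
    # Five action chunks for two env ranks: rank 0 gets 3, rank 1 gets 2.
    print(shard_actions(["a0", "a1", "a2", "a3", "a4"], 2))
```

Keeping shards contiguous means concatenating them in rank order restores the original batch, which matches how the merged env outputs were assembled.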