Communicating via Channels

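The channel module provides a high-level distributed producer-consumer queue abstraction for asynchronous data exchange between Workers. A Channel lets one or more producer Workers put items into a named queue and one or more consumer Workers get those items. It can optionally accumulate batches based on per-item weights.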

Creating and Connecting to a Channel

A new channel can be created as follows:

Worker.create_channel(
    channel_name,
    node_id=0,
    maxsize=0
)

This method:

  • Determines placement: if group_affinity and group_rank_affinity are not specified, the channel is hosted on the current Worker's group and rank (i.e., the same node and GPU).

  • Launches a dedicated channel actor: a ChannelWorker, which actually holds the queues, is started on the selected node/GPU using PackedPlacementStrategy with num_processes=1.

  • Returns a Channel object that wraps this actor. The channel actor's address is channel_name:0.
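The placement rule in the first step can be expressed as a small pure function. This is a hypothetical sketch, not RLinf code. Placement and resolve_channel_placement are assumed names. Treating group_affinity as a group name and group_rank_affinity as a rank is also an assumption. Launching the ChannelWorker actor itself is omitted.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Placement:
    """Hypothetical description of where a channel actor is hosted."""
    group: str
    rank: int
    address: str


def resolve_channel_placement(
    channel_name: str,
    current_group: str,
    current_rank: int,
    group_affinity: Optional[str] = None,
    group_rank_affinity: Optional[int] = None,
) -> Placement:
    # Without affinity hints, co-locate the channel with the calling Worker
    # (same group and rank, i.e. the same node and GPU).
    group = group_affinity if group_affinity is not None else current_group
    rank = group_rank_affinity if group_rank_affinity is not None else current_rank
    # The ChannelWorker group has a single process (num_processes=1),
    # so its address is channel_name:0.
    return Placement(group=group, rank=rank, address=f"{channel_name}:0")
```

Calling resolve_channel_placement("rollout", "actor", 3) with no affinities keeps the channel on group "actor", rank 3.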

To connect to an existing channel from another Worker, use:

Worker.connect_channel(channel_name)

This method looks up the corresponding channel actor in the Ray namespace. It returns a Channel object bound to that actor and to the current Worker.
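Creating and connecting by name can be pictured with an in-process registry. The sketch below stands in for the Ray namespace lookup and rests on several assumptions. ChannelRegistry and ChannelHandle are hypothetical names, and a plain Python object plays the role of the channel actor.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class ChannelHandle:
    """Hypothetical Channel object: binds the shared actor to one Worker."""
    name: str
    actor: Any
    worker: str


class ChannelRegistry:
    """Hypothetical stand-in for named actors in a Ray namespace."""

    def __init__(self) -> None:
        self._actors: Dict[str, Any] = {}

    def create_channel(self, channel_name: str, actor: Any, worker: str) -> ChannelHandle:
        # In this sketch a name may be registered only once.
        if channel_name in self._actors:
            raise ValueError(f"channel {channel_name!r} already exists")
        self._actors[channel_name] = actor
        return ChannelHandle(channel_name, actor, worker)

    def connect_channel(self, channel_name: str, worker: str) -> ChannelHandle:
        # Look up the existing actor by name and bind it to the calling Worker.
        if channel_name not in self._actors:
            raise LookupError(f"no channel named {channel_name!r}")
        return ChannelHandle(channel_name, self._actors[channel_name], worker)
```

Both handles returned for the same name point at the same actor object, so all Workers share one set of queues.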

Putting Data into a Channel

Use channel.put(item, weight=0, key="default", async_op=False) to send data.

  • The sending Worker first transfers item to the ChannelWorker that owns the target queue.

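  • The ChannelWorker wraps the received data in a WeightedItem carrying the specified weight and places it in the specified queue. If the queue has a size limit (maxsize > 0) and is full, the put blocks until space becomes available.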
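The receiving side of a put can be sketched with asyncio.Queue, which already provides the blocking behavior of a bounded queue. This is a minimal sketch under several assumptions:

- ChannelWorkerSketch is a hypothetical name.
- Keeping one queue per key is an assumed layout.
- The WeightedItem fields are assumed; only the name comes from the text above.
- The network transfer from the sending Worker is left out.

```python
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict


@dataclass
class WeightedItem:
    item: Any
    weight: float = 0


class ChannelWorkerSketch:
    """Hypothetical owner of the per-key queues; not the RLinf class."""

    def __init__(self, maxsize: int = 0) -> None:
        # One FIFO queue per key, created lazily inside the running event loop.
        # maxsize=0 means unbounded, as in asyncio.Queue.
        self.queues: DefaultDict[str, asyncio.Queue] = defaultdict(
            lambda: asyncio.Queue(maxsize=maxsize)
        )

    async def put(self, item: Any, weight: float = 0, key: str = "default") -> None:
        # Wrap the payload with its weight, then enqueue it. If the queue is
        # bounded and full, this await suspends until a consumer frees a slot.
        await self.queues[key].put(WeightedItem(item, weight))
```

With maxsize=1, a second put stays suspended until the first item is taken out of the queue.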

Getting Data from a Channel

Use channel.get(key="default", async_op=False) to retrieve data. This is essentially the reverse of put:

  • The ChannelWorker first takes an item from the specified queue.

  • It then sends the item to the requesting Worker, where it is returned to the caller.
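The get path can be sketched the same way. This is again a hypothetical, single-process illustration with its own assumptions:

- ChannelWorkerSketch is not an RLinf class.
- Returning the bare payload rather than the WeightedItem is an assumption.
- The send back to the requesting Worker is omitted.

Using one queue per key also shows how key-based routing keeps streams separate.

```python
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict


@dataclass
class WeightedItem:
    item: Any
    weight: float = 0


class ChannelWorkerSketch:
    """Hypothetical owner of the per-key queues; not the RLinf class."""

    def __init__(self) -> None:
        # One unbounded FIFO queue per key, created on first use.
        self.queues: DefaultDict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

    async def put(self, item: Any, weight: float = 0, key: str = "default") -> None:
        await self.queues[key].put(WeightedItem(item, weight))

    async def get(self, key: str = "default") -> Any:
        # Wait until an item is available under this key, then strip the
        # weight wrapper so the caller only sees the original payload.
        weighted = await self.queues[key].get()
        return weighted.item
```

A consumer reading key "b" never receives items put under key "a", and each key preserves FIFO order.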

Batch Retrieval

Use channel.get_batch(batch_weight, key="default", async_op=False) to retrieve multiple items at once.

  • The ChannelWorker repeatedly takes items from the queue and accumulates their weights.

  • It stops once the accumulated weight reaches or exceeds batch_weight.

  • All retrieved items are combined into a list and sent to the requesting Worker in a single message.

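This is useful for forming batches dynamically when processing experiences or tasks. When items differ in cost or size (weight), weight-based batching keeps batches roughly uniform in total weight.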
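The accumulation loop described above can be sketched as a coroutine over an asyncio.Queue of weighted items. This is an illustration under assumptions, not the ChannelWorker implementation. The function name and the WeightedItem fields are assumed. Waiting on an empty queue until more items arrive is also assumed.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List


@dataclass
class WeightedItem:
    item: Any
    weight: float = 0


async def get_batch(queue: asyncio.Queue, batch_weight: float) -> List[Any]:
    """Hypothetical sketch: pop items until their total weight >= batch_weight."""
    batch: List[Any] = []
    total = 0.0
    while total < batch_weight:
        # Suspends if the queue is empty, until a producer puts more items.
        weighted = await queue.get()
        batch.append(weighted.item)
        total += weighted.weight
    return batch
```

The loop checks the total only after adding an item, so the last item may push the batch past batch_weight. For example, with weights 0, 1 and 2 and batch_weight=3, the batch ends exactly at 3. In this sketch, a stream of zero-weight items would never complete a batch.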

Load Balancing

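During the rollout phase, trajectory lengths often vary widely. Assigning trajectories naively to data-parallel (DP) training groups leads to severe load imbalance.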

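To address this, we implement a channel-based load-balancing mechanism. All generators in the generation phase put completed rollout trajectories, one at a time, into a shared rollout_output_queue. Trajectories are inserted in the order they finish, and shorter ones finish first. As a result, sequence lengths in rollout_output_queue tend to increase over time.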

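We then apply a round-robin strategy: we repeatedly get trajectories from rollout_output_queue and assign them to the DP training groups in turn. Neighboring trajectories in the queue have similar lengths, so each group receives a comparable mix of short and long trajectories. This approximately balances the workload across DP training groups, improving utilization and efficiency during training.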
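The effect of round-robin dispatch on a length-ordered stream can be checked with a few made-up trajectory lengths. These are illustrative numbers, not measurements. The helper names are hypothetical, and contiguous_assign is included only as a naive baseline for comparison.

```python
import math
from typing import List


def round_robin_assign(lengths: List[int], num_dp_groups: int) -> List[List[int]]:
    """Deal trajectories to DP groups in arrival order, one group at a time."""
    groups: List[List[int]] = [[] for _ in range(num_dp_groups)]
    for i, length in enumerate(lengths):
        groups[i % num_dp_groups].append(length)
    return groups


def contiguous_assign(lengths: List[int], num_dp_groups: int) -> List[List[int]]:
    """Naive baseline: split the arrival stream into consecutive chunks."""
    size = math.ceil(len(lengths) / num_dp_groups)
    return [lengths[i * size:(i + 1) * size] for i in range(num_dp_groups)]


# Lengths roughly increase over time, as trajectories arrive in completion order.
arrivals = [10, 12, 15, 18, 40, 45, 90, 100]
rr_loads = [sum(g) for g in round_robin_assign(arrivals, 2)]
chunk_loads = [sum(g) for g in contiguous_assign(arrivals, 2)]
```

With these lengths, round-robin gives per-group totals of 155 and 175. Splitting the stream into contiguous halves gives 55 and 275.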

Example

class rlinf.scheduler.Channel

A FIFO queue-like channel for inter-worker communication.

Creation: A Channel can be created either inside or outside worker contexts. The recommended practice is to create channels outside worker contexts using Channel.create() and then pass them into workers as needed. Inside a worker, you can also create a channel or connect to an existing one with self.create_channel() or self.connect_channel().

Interface: Like asyncio.Queue, Channel provides put, get, put_no_wait, and get_no_wait, as well as query interfaces such as qsize, empty, and full. These interfaces have the same semantics as their asyncio.Queue counterparts.
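For reference, the snippet below shows the asyncio.Queue semantics that the docstring refers to, using only the standard library. The text above does not state whether Channel's put_no_wait and get_no_wait raise exactly asyncio.QueueFull and asyncio.QueueEmpty. Treat that as an assumption to verify.

```python
import asyncio


async def main() -> list:
    q: asyncio.Queue = asyncio.Queue(maxsize=2)
    events = []
    q.put_nowait("a")
    q.put_nowait("b")
    # The queue is now at capacity.
    events.append(("full", q.full(), q.qsize()))
    try:
        q.put_nowait("c")  # non-blocking put on a full queue raises
    except asyncio.QueueFull:
        events.append("put_nowait raised QueueFull")
    events.append(await q.get())   # FIFO: "a" comes out first
    events.append(q.get_nowait())  # then "b"
    events.append(("empty", q.empty()))
    try:
        q.get_nowait()  # non-blocking get on an empty queue raises
    except asyncio.QueueEmpty:
        events.append("get_nowait raised QueueEmpty")
    return events


events = asyncio.run(main())
```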

Features:

  1. Async operation: Channel supports both synchronous and asynchronous put and get operations, similar to Worker's send and recv APIs. Both operations accept any data item, as long as it is serializable. Operations are synchronous by default; set the async_op flag to make them asynchronous. The returned handle can be awaited in an asyncio context with await channel.get(async_op=True).async_wait(). In non-asyncio code it can be waited on later with wait(), similar to the work handle returned by torch.distributed.isend().

  2. Key-based routing: Channel allows specifying a key for each data item, which can be used to identify and route messages. For example, suppose a specific data item should be retrieved and processed by a specific worker. Assign the item a unique key when putting it into the channel, and the target worker can then use that key to retrieve it. This is useful in multi-turn scenarios in agentic and embodied RL, where the same data is processed by a fixed set of workers.

  3. Weight and batch processing: Channel also supports assigning weights to individual data items, allowing for more fine-grained control over how messages are processed. A get_batch method can be used to retrieve a batch of messages which respects the assigned weights.

  4. Debugging: You can inspect a Channel's internal data by printing the Channel object directly.
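The two ways of waiting on an async operation, blocking wait() and awaitable async_wait(), can be sketched with a concurrent.futures.Future. This is a hypothetical illustration of the handle pattern, not RLinf's handle class. AsyncWork and simulated_get are assumed names, and a thread pool stands in for the actual communication.

```python
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any


class AsyncWork:
    """Hypothetical handle for an in-flight put/get, waitable two ways."""

    def __init__(self, future: Future) -> None:
        self._future = future

    def wait(self) -> Any:
        # Blocking wait for non-asyncio code.
        return self._future.result()

    async def async_wait(self) -> Any:
        # Awaitable wait that does not block the running event loop.
        return await asyncio.wrap_future(self._future)


_executor = ThreadPoolExecutor(max_workers=2)


def simulated_get(value: Any, async_op: bool = False) -> Any:
    """Hypothetical stand-in for channel.get(): a value, or a handle if async_op."""
    future = _executor.submit(lambda: value)
    if async_op:
        return AsyncWork(future)
    return future.result()
```

The synchronous path returns the value directly. The async path returns a handle the caller can wait on whenever it chooses, either from plain code or from inside a coroutine.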

Example:

>>> import asyncio
>>> import torch
>>> from rlinf.scheduler import (
...     Channel,
...     Cluster,
...     PackedPlacementStrategy,
...     Worker,
... )
>>>
>>> class Producer(Worker):
...     def __init__(self):
...         super().__init__()
...
...     def produce(self, channel: Channel):
...         # Synchronous put of common object
...         channel.put("Hello from Producer")
...
...         # Synchronous put of tensor
...         tensor = torch.ones(1, device=torch.cuda.current_device())
...         channel.put(tensor)
...
...         # Asynchronous put of common object
...         async_work = channel.put(
...             "Hello from Producer asynchronously", async_op=True
...         )
...         async_work.wait()
...
...         # Asynchronous put using asyncio
...         async_work = channel.put(tensor, async_op=True)
...
...         async def wait_async():
...             await async_work.async_wait()
...
...         asyncio.run(wait_async())
...
...         # Put object with weight
...         channel.put("Hello with weight", weight=1)
...         channel.put(tensor, weight=2)
>>>
>>> class Consumer(Worker):
...     def __init__(self):
...         super().__init__()
...
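...     def consume(self, channel: Channel):
...         # Synchronous get (FIFO: receives the first item put)
...         item = channel.get()
...
...         # Asynchronous get, waited on synchronously
...         async_work = channel.get(async_op=True)
...         async_result = async_work.wait()
...
...         # Asynchronous get, awaited inside asyncio
...         async_work = channel.get(async_op=True)
...
...         async def wait_async():
...             return await async_work.async_wait()
...
...         async_result = asyncio.run(wait_async())
...
...         # Get a batch whose weights sum to at least 3
...         # (the remaining three items, with weights 0, 1 and 2)
...         batch = channel.get_batch(target_weight=3)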
>>>
>>> cluster = Cluster(num_nodes=1)
>>> channel = Channel.create(name="channel")
>>> placement = PackedPlacementStrategy(
...     start_hardware_rank=0, end_hardware_rank=0
... )
>>> producer = Producer.create_group().launch(
...     cluster, name="test", placement_strategy=placement
... )
>>> consumer = Consumer.create_group().launch(
...     cluster, name="test2", placement_strategy=placement
... )
>>> r1 = producer.produce(channel)
>>> r2 = consumer.consume(channel)
>>> res = r1.wait()
>>> res = r2.wait()

Summary

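The Channel component provides a distributed producer-consumer queue for communication between Workers. It is built on top of the collective send/recv mechanism and offers an intuitive interface with support for weighted items and batching. It enables decoupled, asynchronous data flow, which makes it well suited to RL scenarios that combine parallel data collection with batched consumption.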