Channel Interface
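This section describes the Channel in RLinf in detail. Channel is a high-level abstraction for asynchronous communication, implemented as a producer-consumer queue.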
Channel
- class rlinf.scheduler.Channel
A FIFO queue-like channel for inter-worker communication.
Creation: A Channel can be created both inside and outside worker contexts. The recommended practice is to create channels outside worker contexts using Channel.create() and then pass them into workers as needed. You can also create channels inside worker contexts, or connect to existing channels, using self.create_channel() or self.connect_channel().
Interface: Similar to asyncio.Queue, Channel provides the operations put, get, put_nowait, and get_nowait, as well as the query methods qsize, empty, and full. Their semantics match those of asyncio.Queue.
Features:
Async operation: Channel supports both synchronous and asynchronous put and get operations, similar to the send and recv APIs of Worker. Both operations accept any serializable data item. The default behavior is synchronous; asynchronous operation is enabled by setting the async_op flag. An asynchronous call returns a communication handle. In an asyncio context the handle can be awaited, e.g. await channel.get(async_op=True).async_wait(); in a non-asyncio context it can be waited on later with wait(), similar to torch.distributed.isend().
Key-based routing: Channel allows specifying a key for each data item, which is used to identify and route it. For example, if a specific data item should be retrieved and processed by a specific worker, assign a unique key to that item when putting it into the channel; the target worker then uses the same key to retrieve it. This is useful in multi-turn scenarios in agentic and embodied RL, where a piece of data is processed by a fixed set of workers.
Weight and batch processing: Channel also supports assigning a weight to each data item, allowing finer-grained control over how messages are processed. The get_batch method retrieves a batch of items based on their assigned weights.
Debugging: You can inspect a Channel's internal data by printing the Channel object directly.
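The key-based routing described above can be pictured with a minimal in-process sketch, assuming each key maps to its own FIFO queue that is created on first use (as the put documentation below describes). ToyKeyedChannel is a hypothetical name; it is not RLinf's implementation and omits blocking, weights, async handles, and cross-process communication (its get raises IndexError on an empty queue instead of waiting).

```python
from collections import defaultdict, deque
from typing import Any, Deque, Dict, Hashable

DEFAULT_KEY = "default_queue"


class ToyKeyedChannel:
    """Hypothetical, single-process model of key-based routing (not RLinf code)."""

    def __init__(self) -> None:
        # Each key owns an independent FIFO queue, created lazily on first put.
        self._queues: Dict[Hashable, Deque[Any]] = defaultdict(deque)

    def put(self, item: Any, key: Hashable = DEFAULT_KEY) -> None:
        self._queues[key].append(item)

    def get(self, key: Hashable = DEFAULT_KEY) -> Any:
        # Only items put under the same key are visible here.
        # Unlike a real blocking get, an empty queue raises IndexError.
        return self._queues[key].popleft()

    def qsize(self, key: Hashable = DEFAULT_KEY) -> int:
        # Look up without creating an empty queue as a side effect.
        queue = self._queues.get(key)
        return len(queue) if queue is not None else 0


# Two multi-turn episodes share one channel but never see each other's data.
channel = ToyKeyedChannel()
channel.put("episode A, turn 1", key="episode-A")
channel.put("episode B, turn 1", key="episode-B")
channel.put("episode A, turn 2", key="episode-A")

first_a = channel.get(key="episode-A")
first_b = channel.get(key="episode-B")
remaining_a = channel.qsize(key="episode-A")
```

Keeping one queue per key means consumers of different keys never contend for each other's items, which is the property the multi-turn use case relies on.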
Example:
>>> import asyncio
>>> import torch
>>> from rlinf.scheduler import (
...     Channel,
...     Worker,
...     Cluster,
...     PackedPlacementStrategy,
... )
>>>
>>> class Producer(Worker):
...     def __init__(self):
...         super().__init__()
...
...     def produce(self, channel: Channel):
...         # Synchronous put of a plain Python object
...         channel.put("Hello from Producer")
...
...         # Synchronous put of a tensor
...         tensor = torch.ones(1, device=torch.cuda.current_device())
...         channel.put(tensor)
...
...         # Asynchronous put: returns a handle that is waited on later
...         async_work = channel.put(
...             "Hello from Producer asynchronously", async_op=True
...         )
...         async_work.wait()
...
...         # Asynchronous put awaited inside an asyncio coroutine
...         async_work = channel.put(tensor, async_op=True)
...
...         async def wait_async():
...             await async_work.async_wait()
...
...         asyncio.run(wait_async())
...
...         # Put objects with weights (consumed later by get_batch)
...         channel.put("Hello with weight", weight=1)
...         channel.put(tensor, weight=2)
>>>
>>> class Consumer(Worker):
...     def __init__(self):
...         super().__init__()
...
...     def consume(self, channel: Channel):
...         # Synchronous get: blocks until an item is available (FIFO order)
...         item = channel.get()
...
...         # Asynchronous get, waited on synchronously
...         async_work = channel.get(async_op=True)
...         async_result = async_work.wait()
...
...         # Asynchronous get, awaited inside an asyncio coroutine
...         async_work = channel.get(async_op=True)
...
...         async def wait_async():
...             return await async_work.async_wait()
...
...         result = asyncio.run(wait_async())
...
...         # Get a batch of items based on their weights
...         batch = channel.get_batch(target_weight=3)
>>>
>>> cluster = Cluster(num_nodes=1)
>>> channel = Channel.create(name="channel")
>>> placement = PackedPlacementStrategy(
...     start_hardware_rank=0, end_hardware_rank=0
... )
>>> producer = Producer.create_group().launch(
...     cluster, name="test", placement_strategy=placement
... )
>>> consumer = Consumer.create_group().launch(
...     cluster, name="test2", placement_strategy=placement
... )
>>> r1 = producer.produce(channel)
>>> r2 = consumer.consume(channel)
>>> res = r1.wait()
>>> res = r2.wait()
- classmethod create(name, maxsize=0, distributed=False, node_rank=0, local=False, disable_distributed_log=True)
Create a new channel with the specified name, queue size, and placement options.
- Parameters:
  - name (str) -- The name of the channel.
  - maxsize (int) -- The maximum size of the channel queue. Defaults to 0 (unbounded).
  - distributed (bool) -- Whether the channel should be distributed. A distributed channel creates a channel worker on each node and routes communication to the channel worker on the same node as the current worker, benefiting from data locality. Routing is based on the key passed to the put/get APIs, so if you expect keys to be randomly distributed, set this to False to avoid unnecessary routing overhead.
  - node_rank (int) -- The node rank of the current worker. Only valid when distributed is False.
  - local (bool) -- Create the channel for intra-process communication. A local channel cannot be connected to by other workers, and its data cannot be shared among different processes.
  - disable_distributed_log (bool) -- Whether to disable distributed logging for the channel.
- Returns:
A new instance of the Channel class.
- Return type:
Channel
- classmethod connect(name, current_worker)
Connect to an existing channel.
- Parameters:
  - name (str) -- The name of the channel to connect to.
  - current_worker (Worker) -- The current worker that is connecting to the channel.
- Returns:
An instance of the Channel class connected to the specified channel.
- Return type:
Channel
- property is_local
Check if the channel is a local channel.
- qsize(key='default_queue')
Get the size of the channel queue.
- Parameters:
  - key (Any) -- The key whose associated queue is checked.
- Returns:
The number of items in the channel queue.
- Return type:
int
- empty(key='default_queue')
Check if the channel queue is empty.
- Parameters:
  - key (Any) -- The key whose associated queue is checked for emptiness.
- Returns:
True if the channel queue is empty, False otherwise.
- Return type:
bool
- full(key='default_queue')
Check if the channel queue is full.
- Parameters:
  - key (Any) -- The key whose associated queue is checked for fullness.
- Returns:
True if the channel queue is full, False otherwise.
- Return type:
bool
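Since these query methods are documented to behave like their asyncio.Queue counterparts, their semantics, including how the maxsize argument of create() affects full(), can be checked against a plain asyncio.Queue. This sketch uses only the standard library and assumes Channel follows the same rules; a Channel additionally tracks these values per key.

```python
import asyncio

# maxsize=0 (the default) means unbounded, as documented for Channel.create().
unbounded = asyncio.Queue()
bounded = asyncio.Queue(maxsize=2)

# State of the bounded queue before anything is added.
before = (bounded.qsize(), bounded.empty(), bounded.full())

for i in range(2):
    unbounded.put_nowait(i)
    bounded.put_nowait(i)

# An unbounded queue never reports full(); a bounded one does once qsize() reaches maxsize.
unbounded_state = (unbounded.qsize(), unbounded.empty(), unbounded.full())
bounded_state = (bounded.qsize(), bounded.empty(), bounded.full())
```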
- put(item, weight=0, key='default_queue', async_op=False)
Put an item into the channel queue.
- Parameters:
  - item (Any) -- The item to put into the channel queue.
  - weight (int) -- The weight of the item, used by get_batch. Defaults to 0.
  - key (Any) -- The key of the queue to put the item into: a unique identifier for a specific set of items. When a key is given, the channel puts the item into the queue associated with that key. If that queue does not exist, it is created.
  - async_op (bool) -- Whether to perform the operation asynchronously. If True, an AsyncWork handle is returned that can be waited on later.
- Return type:
AsyncWork | None
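The AsyncWork return type reflects the async_op pattern from the feature list: the call returns a handle that can be waited on synchronously with wait() or awaited with async_wait(). The sketch below models only that handle pattern, using a thread pool and a thread-safe queue.Queue. ToyAsyncWork and toy_put are hypothetical names; they are not RLinf's AsyncWork and say nothing about how RLinf actually transfers data.

```python
import asyncio
import concurrent.futures
import queue
from typing import Any, Optional

# Background thread standing in for the channel's communication machinery.
_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)


class ToyAsyncWork:
    """Hypothetical handle offering the two wait styles used in the Channel example."""

    def __init__(self, future: concurrent.futures.Future) -> None:
        self._future = future

    def wait(self) -> Any:
        # Block the calling thread until the operation completes.
        return self._future.result()

    async def async_wait(self) -> Any:
        # Suspend the coroutine without blocking the event loop.
        return await asyncio.wrap_future(self._future)


def toy_put(q: queue.Queue, item: Any, async_op: bool = False) -> Optional[ToyAsyncWork]:
    if not async_op:
        q.put(item)  # synchronous: the item is enqueued when this returns
        return None
    # Asynchronous: schedule the put and hand back a waitable handle.
    return ToyAsyncWork(_executor.submit(q.put, item))


q = queue.Queue()
sync_result = toy_put(q, "sync")

work = toy_put(q, "waited", async_op=True)
work.wait()  # non-asyncio style

work = toy_put(q, "awaited", async_op=True)
asyncio.run(work.async_wait())  # asyncio style

received = [q.get_nowait() for _ in range(q.qsize())]
_executor.shutdown(wait=True)
```

Returning None for the synchronous path and a handle otherwise matches the documented AsyncWork | None return type.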
- put_nowait(item, weight=0, key='default_queue')
Put an item into the channel queue without waiting. Raises asyncio.QueueFull if the queue is full.
- Parameters:
  - item (Any) -- The item to put into the channel queue.
  - weight (int) -- The weight of the item, used by get_batch. Defaults to 0.
  - key (Any) -- The key of the queue to put the item into: a unique identifier for a specific set of items. When a key is given, the channel puts the item into the queue associated with that key. If that queue does not exist, it is created.
- Raises:
asyncio.QueueFull -- If the queue is full.
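Because put_nowait raises asyncio.QueueFull instead of blocking, a producer that must not stall typically catches the exception. The sketch below shows that pattern against a plain bounded asyncio.Queue, assuming Channel's put_nowait raises in the same situations; try_put is a hypothetical helper, and with a real Channel you would also forward weight and key.

```python
import asyncio
from typing import Any


def try_put(q: asyncio.Queue, item: Any) -> bool:
    """Enqueue item without waiting; return False instead of raising when full."""
    try:
        q.put_nowait(item)
        return True
    except asyncio.QueueFull:
        # The caller decides what to do: drop, retry later, or fall back to a blocking put.
        return False


q = asyncio.Queue(maxsize=2)
accepted = [try_put(q, i) for i in range(3)]
```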
- get(key='default_queue', async_op=False)
Get an item from the channel queue.
- Parameters:
  - key (Any) -- The key of the queue to get the item from: a unique identifier for a specific set of items. When a key is given, the channel looks for the item in the queue associated with that key.
  - async_op (bool) -- Whether to perform the operation asynchronously. If True, a handle is returned and the item is obtained by waiting on it with wait() or async_wait().
- Returns:
The item retrieved from the channel queue.
- Return type:
Any
- get_nowait(key='default_queue')
Get an item from the channel queue without waiting. Raises asyncio.QueueEmpty if the queue is empty.
- Parameters:
  - key (Any) -- The key of the queue to get the item from: a unique identifier for a specific set of items. When a key is given, the channel looks for the item in the queue associated with that key.
- Returns:
The item retrieved from the channel queue.
- Return type:
Any
- Raises:
asyncio.QueueEmpty -- If the queue is empty.
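Because get_nowait raises asyncio.QueueEmpty instead of blocking, it is convenient for collecting whatever is currently queued without waiting. This sketch uses a plain asyncio.Queue under the assumption that Channel's get_nowait behaves the same way; drain is a hypothetical helper name.

```python
import asyncio
from typing import Any, List


def drain(q: asyncio.Queue) -> List[Any]:
    """Remove and return every item currently in q, oldest first, without waiting."""
    items: List[Any] = []
    while True:
        try:
            items.append(q.get_nowait())
        except asyncio.QueueEmpty:
            # Nothing left right now; a blocking get() would wait here instead.
            return items


q = asyncio.Queue()
for item in ("a", "b", "c"):
    q.put_nowait(item)

drained = drain(q)
```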
- get_batch(target_weight=0, key='default_queue', async_op=False)
Get a batch of items from the channel queue based on the target weight.
Items are retrieved until adding the next item would make the total weight of the batch exceed target_weight.
- Parameters:
  - target_weight (int) -- The target total weight of the batch.
  - key (Any) -- The key of the queue to get the items from: a unique identifier for a specific set of items. When a key is given, the channel looks for the items in the queue associated with that key.
  - async_op (bool) -- Whether to perform the operation asynchronously.
- Returns:
A list of items retrieved from the channel queue.
- Return type:
List[Any]
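The stopping rule of get_batch can be made concrete with a small pure-Python sketch over (item, weight) pairs. It assumes items are considered in FIFO order and that the first item whose weight would push the running total above target_weight stays in the queue; take_batch is a hypothetical helper, and unlike the real method it never blocks waiting for more items.

```python
from collections import deque
from typing import Any, Deque, List, Tuple


def take_batch(pending_items: Deque[Tuple[Any, int]], target_weight: int) -> List[Any]:
    """Pop items in FIFO order until the next one would exceed target_weight."""
    batch: List[Any] = []
    total = 0
    # Stop as soon as the head item would push the total past the target.
    while pending_items and total + pending_items[0][1] <= target_weight:
        item, weight = pending_items.popleft()
        batch.append(item)
        total += weight
    return batch


# A zero-weight item, then weights 1 and 2 (total 3), then one more item that does not fit.
pending = deque([("plain", 0), ("weighted-1", 1), ("weighted-2", 2), ("next", 1)])
batch = take_batch(pending, target_weight=3)
```

Here "next" is left in the queue because 3 + 1 would exceed the target of 3.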