Adding a New Environment

This document explains how to add a new environment to the RLinf framework. RLinf supports a variety of reinforcement learning environments, including robotic manipulation suites (e.g., ManiSkill3 and LIBERO) as well as others.

The RLinf environment system consists of the following components:

  • Base Environment Classes: Concrete implementations inheriting from gym.Env.

  • Environment Wrappers: Add-on wrappers that provide extra functionality.

  • Task Variants: Implementations of specific tasks or scenarios.

1. Create a Base Environment Class

1.1 Inherit from gym.Env

import gymnasium as gym
import numpy as np
import torch

class YourCustomEnv(gym.Env):
    """Template for a custom vectorized environment in RLinf.

    Args:
        cfg: Environment config. This template reads ``cfg.seed`` and
            ``cfg.group_size``.
        rank: Worker rank; added to ``cfg.seed`` to derive this instance's seed.
        num_envs: Number of environments handled by this instance.
        ret_device: Stored as ``self.ret_device``; presumably the device for
            returned tensors (not used elsewhere in this template).

    Raises:
        ValueError: If ``group_size`` is not positive or does not evenly
            divide ``num_envs``.
    """

    def __init__(self, cfg, rank, num_envs, ret_device="cpu"):
        self.cfg = cfg
        self.rank = rank
        self.ret_device = ret_device
        # Offset the base seed by rank so different workers get different seeds.
        self.seed = self.cfg.seed + rank

        # Initialize environment-related parameters
        self.num_envs = num_envs
        self.group_size = self.cfg.group_size
        # Validate before dividing: a non-positive group_size would crash or
        # misbehave, and any remainder would be silently dropped by ``//``.
        if self.group_size <= 0 or self.num_envs % self.group_size != 0:
            raise ValueError(
                f"num_envs ({self.num_envs}) must be divisible by a positive "
                f"group_size (got {self.group_size})"
            )
        self.num_group = self.num_envs // self.group_size

        # Initialize environment internals
        self._init_environment()
        self._init_reset_state_ids()

    def _init_environment(self):
        """Create the underlying environment instance(s).

        Placeholder: construct your simulator here based on ``self.cfg``.
        """
        # Initialize based on environment type
        pass

    def _init_reset_state_ids(self):
        """Create a seeded torch RNG and set up reset-state IDs."""
        # Seeded with self.seed, so RNG draws are reproducible for a given
        # cfg.seed and rank.
        self._generator = torch.Generator()
        self._generator.manual_seed(self.seed)
        # Set up reset-state logic

    def reset(self, options=None):
        """Reset the environment.

        Args:
            options: Optional reset options. Defaults to ``None`` instead of a
                shared mutable ``{}`` so state cannot leak between calls.

        Returns:
            A ``(obs, info)`` tuple, following the Gymnasium ``reset`` API.
        """
        options = {} if options is None else options
        # Implement environment reset logic (may consult ``options``)
        obs = self._get_observation()
        return obs, {}

    def step(self, actions):
        """Execute actions.

        Returns:
            ``(obs, reward, terminated, truncated, info)``, following the
            Gymnasium ``step`` API.
        """
        # Implement action execution logic
        obs = self._get_observation()
        reward = self._calculate_reward()
        terminated = self._check_termination()
        truncated = self._check_truncation()
        info = self._get_info()
        return obs, reward, terminated, truncated, info

    def _get_observation(self):
        """Return the current observation (placeholder: returns ``None``)."""
        # Implement observation retrieval logic
        pass

    def _calculate_reward(self):
        """Return the reward for the last step (placeholder: returns ``None``)."""
        # Implement reward calculation logic
        pass

    def _check_termination(self):
        """Return termination flags (placeholder: returns ``None``)."""
        # Implement termination condition checks
        pass

    def _check_truncation(self):
        """Return truncation flags (placeholder: returns ``None``)."""
        # Implement truncation condition checks
        pass

    def _get_info(self):
        """Return the step info dict (placeholder: returns ``None``)."""
        # Implement information retrieval logic
        pass

1.2 Implement Required Properties

@property
def total_num_group_envs(self):
    """Total number of environment groups.

    Placeholder: implement for your environment; returns ``None`` until then.
    """
    # Implement based on your environment
    pass

@property
def num_envs(self):
    """Number of vectorized environments.

    Backed by ``self._num_envs``: returning ``self.num_envs`` here would
    re-enter this property and recurse forever.
    """
    return self._num_envs

@num_envs.setter
def num_envs(self, value):
    # A setter is required because ``__init__`` assigns ``self.num_envs``;
    # a read-only property would make that assignment raise AttributeError.
    self._num_envs = value

@property
def device(self):
    """Active device, taken from ``self.env.unwrapped.device``."""
    return self.env.unwrapped.device

2. Implement Environment Offload Support (Optional)

If your environment must support saving and restoring its state (offloading), also inherit from EnvOffloadMixin and implement get_state and load_state:

from rlinf.envs.env_offload_wrapper import EnvOffloadMixin
import io
import torch

class YourCustomEnv(gym.Env, EnvOffloadMixin):
    """Environment whose state can be serialized to bytes and restored."""

    def get_state(self) -> bytes:
        """Serialize environment state to bytes.

        Captures the wrapped env's state (``self.env.get_state()``) and the
        torch RNG state (``self._generator``) so ``load_state`` can restore both.
        """
        state = {
            "env_state": self.env.get_state(),
            "rng_state": self._generator.get_state(),
            # Add other states as needed
        }
        buffer = io.BytesIO()
        torch.save(state, buffer)
        return buffer.getvalue()

    def load_state(self, state_buffer: bytes):
        """Restore environment state from bytes produced by ``get_state``.

        Only pass buffers created by ``get_state`` in a trusted process:
        loading with ``weights_only=False`` unpickles arbitrary objects.
        """
        buffer = io.BytesIO(state_buffer)
        # Explicit weights_only=False: since PyTorch 2.6 the default is True,
        # which can reject non-tensor objects that env.get_state() may return.
        # SECURITY: never feed untrusted bytes to this method.
        state = torch.load(buffer, map_location="cpu", weights_only=False)
        self.env.set_state(state["env_state"])
        # The generator state was mapped to CPU above, as torch.Generator
        # requires for set_state on a CPU generator.
        self._generator.set_state(state["rng_state"])
        # Restore other states as needed

3. Create an Environment Wrapper

If you implement offload support, also define a corresponding class in env_offload_wrapper.py that combines your base environment class with EnvOffloadMixin:

# In env_offload_wrapper.py
class YourCustomEnv(BaseYourCustomEnv, EnvOffloadMixin):
    """Offload-capable variant of ``BaseYourCustomEnv``.

    Both methods are placeholders to implement: ``get_state`` saves the
    environment state and ``load_state`` restores it.
    """

    def get_state(self) -> bytes:
        """Serialize the environment state (to be implemented; returns None)."""

    def load_state(self, state_buffer: bytes):
        """Restore the environment state from ``state_buffer`` (to be implemented)."""

4. Add Action-Processing Utilities

Add an action-processing helper for your environment to action_utils.py and dispatch to it from prepare_actions:

def prepare_actions_for_your_env(
    raw_chunk_actions,
    num_action_chunks,
    action_dim,
    action_scale,
    policy,
):
    """Turn raw action chunks into actions your environment can execute.

    Placeholder: implement the conversion and return the processed actions.
    Until then this function returns ``None``.

    Args:
        raw_chunk_actions: Raw action chunks to process.
        num_action_chunks: Number of action chunks.
        action_dim: Dimensionality of a single action.
        action_scale: Scale factor forwarded from ``prepare_actions``.
        policy: Policy identifier forwarded from ``prepare_actions``.
    """
    ...

def prepare_actions(
    env_type,
    raw_chunk_actions,
    num_action_chunks,
    action_dim,
    action_scale: float = 1.0,
    policy: str = "default",
):
    """Dispatch raw action chunks to the env-specific preparation helper.

    Args:
        env_type: Environment identifier, e.g. ``"your_env"``.
        raw_chunk_actions: Raw action chunks to process.
        num_action_chunks: Number of action chunks.
        action_dim: Dimensionality of a single action.
        action_scale: Scale factor forwarded to the helper.
        policy: Policy identifier forwarded to the helper.

    Returns:
        Whatever the selected env-specific helper returns.

    Raises:
        NotImplementedError: If ``env_type`` has no registered handler.
    """
    if env_type == "your_env":
        chunk_actions = prepare_actions_for_your_env(
            raw_chunk_actions=raw_chunk_actions,
            num_action_chunks=num_action_chunks,
            action_dim=action_dim,
            action_scale=action_scale,
            policy=policy,
        )
    # elif env_type == "...": ... other environment types
    else:
        # Without this branch an unknown env_type would surface as a
        # confusing UnboundLocalError on ``chunk_actions`` below.
        raise NotImplementedError(f"Unsupported env_type: {env_type!r}")
    return chunk_actions

5. Create Task Variants (Optional)

If you require specific task variants, place them under envs/YOUR_ENV/tasks/variants/:

# envs/YOUR_ENV/tasks/variants/your_task_variant.py
class YourTaskVariant:
    """Skeleton for a task variant; fill in the hooks below."""

    def __init__(self, config):
        # Keep the variant's config for use by the hooks.
        self.config = config

    def setup_task(self):
        """Prepare task assets and the initial state (not yet implemented)."""
        return None

    def get_task_description(self):
        """Describe the task in natural language (not yet implemented)."""
        return None

    def check_success(self, obs, action):
        """Report whether the task succeeded (not yet implemented)."""
        return None

6. Update Configuration Files

Add your environment configuration:

your_env:
  env_type: "your_env"
  total_num_envs: 8
  # NOTE(review): presumably should evenly divide the number of envs each
  # worker constructs, since YourCustomEnv computes num_envs // group_size — confirm.
  group_size: 4
  # Base seed; YourCustomEnv derives its seed as seed + rank.
  seed: 42
  # Other environment-specific settings

7. Register the Environment

Expose the new environment in the package:

# In __init__.py or the relevant module
from .your_custom_env import YourCustomEnv

__all__ = ["YourCustomEnv"]

Testing and Validation

import numpy as np

def test_your_env():
    """Basic smoke test: one reset() and one step() on a small env instance."""
    import torch

    cfg = get_test_config()
    # num_envs is a required constructor argument; a single full group keeps
    # num_envs divisible by group_size.
    env = YourCustomEnv(cfg, rank=0, num_envs=cfg.group_size)
    try:
        # Reset
        obs, info = env.reset()
        assert obs is not None

        # Step
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        assert obs is not None
        # Torch-based vectorized envs may return rewards as tensors.
        assert isinstance(reward, (float, np.ndarray, torch.Tensor))
    finally:
        # Release simulator resources even if an assertion fails.
        env.close()