Adding a New Environment

This guide explains how to add a new reinforcement learning environment to RLinf. RLinf supports various reinforcement learning environments, including robotic manipulation environments (e.g., ManiSkill3, LIBERO) and others.

The RLinf environment system consists of the following components:

  • Base Environment Classes: Concrete implementations inheriting from gym.Env.

  • Environment Wrappers: Add-on wrappers that provide extra functionality.

  • Task Variants: Implementations of specific tasks or scenarios.

1. Create Base Environment Class

1.1 Inherit from gym.Env

import gymnasium as gym
import numpy as np
import torch

class YourCustomEnv(gym.Env):
    """Template for a custom RLinf environment built on ``gym.Env``.

    The underscore-prefixed methods are placeholders: each one marks a spot
    where environment-specific logic has to be filled in. ``reset`` and
    ``step`` show how those hooks are combined into the Gymnasium-style
    return tuples.
    """

    def __init__(self, cfg, rank, num_envs, ret_device="cpu"):
        """Store the configuration and initialize the environment.

        Args:
            cfg: Configuration object. ``cfg.seed`` and ``cfg.group_size``
                are read here.
            rank: Rank of this instance; added to ``cfg.seed`` so each rank
                gets a distinct seed.
            num_envs: Number of environments handled by this instance.
            ret_device: Stored as ``self.ret_device``. Presumably the device
                on which returned data should live; this template does not
                use it itself.
        """
        self.cfg = cfg
        self.rank = rank
        self.ret_device = ret_device
        # Offset the base seed by the rank so ranks do not share RNG streams.
        self.seed = self.cfg.seed + rank

        # Initialize environment-related parameters
        self.num_envs = num_envs
        self.group_size = self.cfg.group_size
        # Integer division: a trailing partial group is not counted.
        self.num_group = self.num_envs // self.group_size

        # Initialize environment internals
        self._init_environment()
        self._init_reset_state_ids()

    def _init_environment(self):
        """Initialize the specific environment instance."""
        # Initialize based on environment type
        pass

    def _init_reset_state_ids(self):
        """Initialize reset state IDs and RNG."""
        # Dedicated generator seeded with the per-rank seed.
        self._generator = torch.Generator()
        self._generator.manual_seed(self.seed)
        # Set up reset-state logic

    def reset(self, options=None):
        """Reset the environment.

        Args:
            options: Optional dict of reset options. ``None`` (the default)
                is treated as an empty dict; a literal ``{}`` default is
                avoided because it would be shared across calls.

        Returns:
            A tuple ``(obs, info)`` where ``info`` is an empty dict.
        """
        options = {} if options is None else options
        # Implement environment reset logic
        obs = self._get_observation()
        return obs, {}

    def step(self, actions):
        """Execute actions.

        Args:
            actions: Actions to apply to the environment.

        Returns:
            A tuple ``(obs, reward, terminated, truncated, info)``.
        """
        # Implement action execution logic
        obs = self._get_observation()
        reward = self._calculate_reward()
        terminated = self._check_termination()
        truncated = self._check_truncation()
        info = self._get_info()
        return obs, reward, terminated, truncated, info

    def _get_observation(self):
        """Retrieve observation."""
        # Implement observation retrieval logic
        pass

    def _calculate_reward(self):
        """Compute reward."""
        # Implement reward calculation logic
        pass

    def _check_termination(self):
        """Check termination conditions."""
        # Implement termination condition checks
        pass

    def _check_truncation(self):
        """Check truncation conditions."""
        # Implement truncation condition checks
        pass

    def _get_info(self):
        """Retrieve info dict."""
        # Implement information retrieval logic
        pass

1.2 Implement Required Properties

@property
def total_num_group_envs(self):
    """Return the total number of environment groups (placeholder)."""
    # Environment-specific: supply the real value for your environment.
    return None

@property
def num_envs(self):
    """Number of vectorized environments.

    Reads the private backing field ``self._num_envs``. Returning
    ``self.num_envs`` here would re-enter this property and recurse until
    ``RecursionError``.
    """
    return self._num_envs


@num_envs.setter
def num_envs(self, value):
    """Store the environment count.

    The setter keeps plain assignments such as ``self.num_envs = num_envs``
    working; without it the read-only property would raise
    ``AttributeError`` on assignment.
    """
    self._num_envs = value

@property
def device(self):
    """Device reported by the unwrapped underlying environment."""
    base_env = self.env.unwrapped
    return base_env.device

2. Implement Environment Offload Support (Optional)

If you need to support saving/restoring environment state, inherit from EnvOffloadMixin:

from rlinf.envs.env_offload_wrapper import EnvOffloadMixin
import io
import torch

class YourCustomEnv(gym.Env, EnvOffloadMixin):
    """Environment whose state can be packed into bytes and restored later."""

    def get_state(self) -> bytes:
        """Return the environment and RNG state serialized as bytes."""
        snapshot = {
            "env_state": self.env.get_state(),
            "rng_state": self._generator.get_state(),
            # Add other states as needed
        }
        stream = io.BytesIO()
        torch.save(snapshot, stream)
        return stream.getvalue()

    def load_state(self, state_buffer: bytes):
        """Restore the environment and RNG state from ``get_state`` output."""
        # NOTE(review): torch.load unpickles the buffer, so only feed it bytes
        # produced by get_state(). Recent PyTorch releases default to
        # weights_only=True, which may reject non-tensor env_state; confirm
        # against the PyTorch version in use.
        snapshot = torch.load(io.BytesIO(state_buffer), map_location="cpu")
        self.env.set_state(snapshot["env_state"])
        self._generator.set_state(snapshot["rng_state"])
        # Restore other states as needed
        # Restore other states as needed

3. Create Environment Wrapper

If you implement offload functionality, create a corresponding wrapper:

# In env_offload_wrapper.py
class YourCustomEnv(BaseYourCustomEnv, EnvOffloadMixin):
    """Offload-capable variant of ``BaseYourCustomEnv`` (template)."""

    def get_state(self) -> bytes:
        """Serialize the environment state to bytes (placeholder)."""
        # Implement state saving
        return None

    def load_state(self, state_buffer: bytes):
        """Restore the environment state from ``state_buffer`` (placeholder)."""
        # Implement state restoration
        return None

4. Add Action Processing Tools

Add action processing utilities in action_utils.py:

def prepare_actions_for_your_env(
    raw_chunk_actions,
    num_action_chunks,
    action_dim,
    action_scale,
    policy,
):
    """Convert raw chunked actions into your environment's format (placeholder)."""
    # Implement action processing logic
    return None

def prepare_actions(
    env_type,
    raw_chunk_actions,
    num_action_chunks,
    action_dim,
    action_scale: float = 1.0,
    policy: str = "default",
):
    """Dispatch raw chunked actions to the preprocessor for ``env_type``.

    Args:
        env_type: Identifier of the target environment, e.g. ``"your_env"``.
        raw_chunk_actions: Raw actions to be processed.
        num_action_chunks: Forwarded unchanged to the preprocessor.
        action_dim: Forwarded unchanged to the preprocessor.
        action_scale: Forwarded unchanged to the preprocessor.
        policy: Forwarded unchanged to the preprocessor.

    Returns:
        Whatever the selected environment-specific preprocessor returns.

    Raises:
        NotImplementedError: If ``env_type`` has no matching branch. Without
            this guard the function would fail with ``UnboundLocalError`` on
            the ``return`` line.
    """
    if env_type == "your_env":
        chunk_actions = prepare_actions_for_your_env(
            raw_chunk_actions=raw_chunk_actions,
            num_action_chunks=num_action_chunks,
            action_dim=action_dim,
            action_scale=action_scale,
            policy=policy,
        )
    # Add ``elif env_type == ...:`` branches for other environment types here.
    else:
        raise NotImplementedError(f"Unsupported env_type: {env_type!r}")
    return chunk_actions

5. Create Task Variants (Optional)

If you require specific task variants, place them under envs/YOUR_ENV/tasks/variants/:

# envs/YOUR_ENV/tasks/variants/your_task_variant.py
class YourTaskVariant:
    """Template for a task variant; every hook is a placeholder."""

    def __init__(self, config):
        """Keep a reference to the variant's configuration."""
        self.config = config

    def setup_task(self):
        """Prepare the task's assets and initial state (placeholder)."""
        return None

    def get_task_description(self):
        """Provide a natural-language description of the task (placeholder)."""
        return None

    def check_success(self, obs, action):
        """Report whether the task has succeeded (placeholder)."""
        return None

6. Update Configuration Files

Add your environment configuration:

# Example configuration entry for the new environment.
your_env:
  # Presumably the identifier matched by prepare_actions(env_type=...) -- confirm.
  env_type: "your_env"
  # Presumably the environment count across all workers -- TODO confirm.
  total_num_envs: 8
  # Read as cfg.group_size; the env class computes num_group = num_envs // group_size.
  group_size: 4
  # Read as cfg.seed; each instance seeds its RNG with seed + rank.
  seed: 42
  # Other environment-specific settings

7. Register Environment

Expose the new environment in the package:

# In __init__.py or the relevant module
from .your_custom_env import YourCustomEnv

__all__ = ["YourCustomEnv"]

Testing and Validation

import numpy as np

def test_your_env():
    """Basic smoke test for your environment.

    Builds one environment, resets it, takes a single sampled action, and
    checks that observations and the reward come back in a usable form.
    """
    cfg = get_test_config()
    # ``num_envs`` is a required constructor argument. One full group is the
    # smallest count for which ``num_envs // group_size`` is non-zero.
    env = YourCustomEnv(cfg, rank=0, num_envs=cfg.group_size)

    # Reset
    obs, info = env.reset()
    assert obs is not None

    # Step
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    assert obs is not None
    assert isinstance(reward, (float, np.ndarray))