
Base Environment (BaseEnv)

The BaseEnv class (located in atroposlib/envs/base.py) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass BaseEnv and implement several key methods.

Design philosophy

Every environment in Atropos is a microservice that generates rollout data asynchronously from whatever trainer you attach to it. Environments (possibly many at once) send data to the Atropos API server, which collects and queues rollout data. Your trainer of choice grabs batches of data from the API and backpropagates.


Unlike other popular alternatives such as Gymnasium, which model environments as MDPs, we think of environments as dataloaders and make no assumptions about how a trajectory is produced. For multi-agent settings, for example, this means our design is agnostic to AEC vs. POSG - both are supported out of the box. To achieve this generality, our environment abstraction deviates from other open-source alternatives in several key ways.

  • Inference-scoring fusion: A popular design choice in open-source LLM RL trainers is to separate inference and scoring into independent abstractions. While this makes a lot of sense for single-turn environments like one-shot MCQA, we found that this led to awkwardness in multi-turn setups with process rewards. As such, we assume the existence of a single method collect_trajectories which is responsible for both inference and scoring. Users are still welcome to call separate inference and scoring methods from within collect_trajectories.

  • Groups as atomic units of data: A natural choice for a data atom in RL is a single trajectory. However, many popular RL methods for fine-tuning LLMs, such as DPO and GRPO, involve packing contrastive data into the same batch. As such, the most fundamental dataloading method in our abstraction is not collect_trajectory (singular) but collect_trajectories (plural). We do not enforce any definition of what a "group" is other than a set of rollouts. Although a "group" is most commonly constructed by generating multiple rollouts starting from the same initial state (as in DPO and GRPO), a user could just as easily pack n similar-sounding problems with very different solutions into a group. For cases like PPO, where advantages don't depend on group statistics, a user can simply use group size 1.

  • Environments return tokens (not messages!): One of the most peculiar design choices we made was that, at least for text-only environments, the environment is responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal, where an OpenAI-formatted message list needs to be passed to a transformers AutoProcessor, we support a list[dict]-valued messages key within our group abstraction ScoredDataGroup.
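To make the token-level contract concrete, here is a minimal sketch of how a group of rollouts sharing one prompt might be packed. This is a toy illustration: the field names ("tokens", "masks", "scores") echo the ScoredDataGroup idea but are assumptions, not the exact atroposlib schema.

```python
def build_group(prompt_tokens, completions, scores):
    """Pack several rollouts from one prompt into a single GRPO-style group."""
    group = {"tokens": [], "masks": [], "scores": scores}
    for completion_tokens in completions:
        # Full sequence is prompt + completion; the mask uses -100 for prompt
        # positions so the trainer only computes loss on completion tokens.
        group["tokens"].append(prompt_tokens + completion_tokens)
        group["masks"].append([-100] * len(prompt_tokens) + completion_tokens)
    return group

group = build_group([1, 2, 3], [[10, 11], [12]], [1.0, 0.0])
```

Because the group carries raw token IDs rather than messages, completions-style and chat-style environments produce data in the same shape.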

Working with Servers and ManagedServer

🎯 Recommended Approach: Use ManagedServer for automatic token and logprob tracking!

When implementing collect_trajectory or collect_trajectories, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The recommended way to do this is using ManagedServer, which automatically handles tokenization, masking, and logprob alignment.

ManagedServer Overview

ManagedServer wraps your APIServer and automatically tracks:

  • Tokens: Full unmasked token sequences
  • Masked Tokens: Training format with -100 for prompt positions, actual token IDs for completion
  • Logprobs: Training format with 1.0 for masked positions, actual logprob values for completion
  • Full Text: Complete text (prompt + completion)
  • Metadata: Finish reasons and other information

Why 1.0 for masked logprobs? A real log probability is always ≤ 0, so 1.0 is an "obviously invalid" value (it would imply a probability of e^1.0 ≈ 2.718 > 1), making masked positions easy to identify and ignore during training.
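A minimal sketch of how this convention can be exploited when filtering (the helper name is hypothetical):

```python
import math

def completion_logprobs(logprobs):
    """Drop masked (prompt) positions, marked by the impossible value 1.0."""
    # Real log probabilities are always <= 0 because probabilities are <= 1;
    # exp(1.0) ~ 2.718 can never be a probability, so 1.0 safely marks masks.
    return [lp for lp in logprobs if lp <= 0.0]

assert math.exp(1.0) > 1.0  # the reason 1.0 is an impossible logprob
```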

Basic Usage Pattern

async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server with context manager
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

    # Extract pre-computed, guaranteed-aligned data
    for choice, node in zip(completion.choices, nodes):
        tokens = node.tokens                # ✅ Automatically computed
        masked_tokens = node.masked_tokens  # ✅ Automatically masked
        logprobs = node.logprobs            # ✅ Automatically aligned
        finish_reason = node.metadata["finish_reason"]

        # Score and return...
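A common next step at the "# Score and return..." point is to drop rollouts that were truncated by max_tokens. A hedged sketch of that filtering, with nodes represented here as plain dicts rather than the real node objects:

```python
def keep_complete(nodes):
    """Keep only rollouts that finished naturally (finish_reason == "stop")."""
    # Rollouts cut off at max_tokens usually shouldn't be scored as complete
    return [n for n in nodes if n["metadata"]["finish_reason"] == "stop"]

nodes = [
    {"tokens": [1, 2], "metadata": {"finish_reason": "stop"}},
    {"tokens": [3, 4], "metadata": {"finish_reason": "length"}},
]
complete = keep_complete(nodes)
```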

Chat Completion Pattern

For chat-based environments, use chat_completion():

async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes...

Benefits Over Manual Handling

Without ManagedServer:

  • Manually tokenize prompts and completions
  • Manually compute prompt lengths
  • Manually apply masking logic
  • Manually extract and align logprobs
  • Prone to off-by-one errors

With ManagedServer:

  • Automatic tokenization
  • Automatic masking
  • Guaranteed alignment
  • Clean, simple code
  • Works with both completion() and chat_completion() APIs
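For contrast, here is roughly the manual bookkeeping that ManagedServer replaces, written as a sketch following the masking conventions described above:

```python
def manual_mask(prompt_ids, completion_ids, completion_logprobs):
    """Align tokens, masks, and logprobs by hand: -100 masks prompt token
    positions and 1.0 marks masked logprob positions."""
    tokens = prompt_ids + completion_ids
    masked_tokens = [-100] * len(prompt_ids) + completion_ids
    logprobs = [1.0] * len(prompt_ids) + completion_logprobs
    # Misaligned lengths are the classic off-by-one failure mode
    assert len(tokens) == len(masked_tokens) == len(logprobs)
    return tokens, masked_tokens, logprobs

tokens, masked, lps = manual_mask([5, 6], [7, 8], [-0.1, -0.9])
```

Every environment doing this by hand has to get the prompt-length arithmetic right for every tokenizer and chat template; ManagedServer centralizes that logic.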

Complete Documentation

For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and migration guide, see:

📚 ManagedServer Complete Guide

Core Methods to Implement

These methods must be implemented in your subclass:

  • async def setup(self): This method is called once at the beginning of the environment's lifecycle (env_manager). Use it for any initial setup required for your specific environment, such as loading datasets, initializing models, or connecting to external resources.

  • async def get_next_item(self) -> Item: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return None to signal the worker to pause.

  • async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]: The default implementation of this method runs collect_trajectory (see below) multiple times in parallel (controlled by group_size). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input item (e.g. the n parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.

  • async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the collect_trajectory method and use the default implementation of collect_trajectories which runs group_size instances of collect_trajectory in parallel. This method defines the logic for a single logical trajectory collection step based on the input item.

    • Return value: It returns a tuple containing:
      1. The ScoredDataItem for this step (one trajectory). This data can be processed further in postprocess_histories if you require additional filtering right before sending to the API.
      2. A list of new Item objects to be added to the backlog for future processing (e.g., follow-up prompts).
    • Should I define collect_trajectory or override collect_trajectories? If you have a way to generate your group more efficiently than a set of separate but parallel calls to collect_trajectory, or if your rollouts aren't independent (as in MCTS), you should override collect_trajectories. If simplicity and iteration speed are more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent, then collect_trajectory is for you.
  • async def evaluate(self, *args, **kwargs): This method is called periodically (controlled by steps_per_eval in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using self.eval_workers for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment.
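Putting the required methods together, a minimal environment skeleton might look like the following. It is written as a plain class so the sketch runs standalone; in practice you would subclass BaseEnv, and the dataset and scoring here are placeholders.

```python
import asyncio

class MyEnvSkeleton:
    """Shape of a BaseEnv subclass, shown without the real base class."""

    def __init__(self, items):
        self._items = iter(items)

    async def setup(self):
        # Load datasets, connect to external resources, etc.
        pass

    async def get_next_item(self):
        # Returning None signals the worker to pause
        return next(self._items, None)

    async def collect_trajectory(self, item):
        # One rollout: run inference, score it, return (scored_item, backlog)
        scored = {"item": item, "score": 1.0}
        return scored, []

async def demo():
    env = MyEnvSkeleton(["q1"])
    await env.setup()
    item = await env.get_next_item()
    scored, backlog = await env.collect_trajectory(item)
    return item, scored, backlog

item, scored, backlog = asyncio.run(demo())
```

With only collect_trajectory defined, the default collect_trajectories would fan out group_size parallel calls and gather the results into a group.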

Optional Methods to Override

These methods have default implementations or are optional based on your needs:

  • async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]: The default implementation of this method runs collect_trajectory multiple times in parallel (controlled by group_size). You can override this instead of collect_trajectory if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input item. It should return the collected group data and a list of backlog items.

  • async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: This method is called after collect_trajectories and before the data is sent to the training server. It receives the collected data from the parallel runs (or your custom collect_trajectories implementation). Use this to perform final processing, scoring, or formatting you may require before sending to the server. You usually won't need this.

  • async def wandb_log(self, wandb_metrics: Optional[Dict] = None): Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, ensure you call super().wandb_log(wandb_metrics) at the end of your implementation. This ensures that the base class's performance metrics and rollout tables are also logged.

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}
        # Add your custom metrics
        wandb_metrics['my_custom_metric'] = calculate_my_metric()
        # ... add more metrics
    
        # Call the parent method to log base metrics
        await super().wandb_log(wandb_metrics)
  • save_checkpoint(self, step, data=None): The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided data dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize what data is saved or how it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.

  • @classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]: This class method is used by the default get_cli_serve_config_cls implementation to get the initial environment configuration (BaseEnvConfig subclass) and server configurations (ServerBaseline or List[APIServerConfig]) when setting up the serve command. The default implementation returns cls.env_config_cls(), ServerBaseline(). You might override this if your environment requires different default configurations or specific server setups (like multiple APIServerConfig instances) when run via the CLI serve command.

  • async def cleanup(self): Called after each call to handle_env. You can implement this for any cleanup needed after processing a single item, though it's often not required.

Overridable Class Variables

These class-level variables in BaseEnv can be overridden in your subclass to customize its behavior:

  • name: Optional[str]:

    • Default: None
    • Purpose: You can set a string name for your environment. This name is used by default for wandb_name in the BaseEnvConfig if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.
  • env_config_cls: Type[BaseEnvConfig]:

    • Default: BaseEnvConfig
    • Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what BaseEnvConfig offers, you should create a new class that inherits from BaseEnvConfig (or a subclass thereof) and assign it to env_config_cls. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.
    from pydantic import Field
    from atroposlib.envs import BaseEnv, BaseEnvConfig
    
    class MyEnvConfig(BaseEnvConfig):
        my_custom_param: str = Field(default="default_value", description="A custom parameter for MyEnv")
    
    class MyEnv(BaseEnv):
        env_config_cls = MyEnvConfig
        name = "MyCustomEnvironment"
        # ... other implementations
  • server_cls: Type[APIServer]:

    • Default: APIServer
    • Purpose: Specifies the class used for managing interactions with API servers (e.g., inference endpoints). This should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API, you can override it to slot in any modifications you need.
    • Note: In most cases, you should use the server_type field in your APIServerConfig instead of overriding this. Set server_type to "openai" (default), "vllm", "sglang", or "trl" to automatically use the appropriate server class with enhanced features like native API access and full token/logprob tracking.
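If you follow the note above and configure the server type rather than overriding server_cls, the pattern might look like this. The import path and any additional required fields of APIServerConfig are assumptions; only the server_type values come from the note above.

```python
from atroposlib.envs import APIServerConfig  # import path assumed

# "vllm" is one of the documented server_type values ("openai", "vllm",
# "sglang", "trl"); other constructor arguments your deployment needs
# (endpoint, model, etc.) are omitted here.
server_config = APIServerConfig(server_type="vllm")
```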

Provided Functionality

BaseEnv provides several helpful features:

  • Parallel Trajectory Collection (collect_trajectories): The base implementation runs your collect_trajectory method multiple times in parallel (based on group_size) and gathers the results. You can override collect_trajectories directly for custom group generation logic (see Optional Methods).
  • Server Interaction: Handles registration with the rollout server, fetching configuration (like batch_size), sending scored data (handle_send_to_api with retries), and status updates.
  • WandB Integration: Sets up WandB logging (if enabled) based on server information and provides the wandb_log hook for custom metrics (remember to call super().wandb_log()). It uses helper methods add_rollouts_for_wandb (to temporarily store rollout data) and create_rollout_table (to format the data into a wandb.Table). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured).
  • Checkpointing:
    • The environment automatically triggers checkpoint saves based on the checkpoint_interval received from the server, calling the save_checkpoint method (see Optional Methods).
    • load_checkpoint(self): Loads data from the checkpoint file corresponding to the environment's curr_step. It attempts to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if curr_step > 0 during registration.
  • Worker Management: Manages asynchronous worker tasks for collecting trajectories (add_train_workers, handle_env).
  • Performance Monitoring: Tracks and logs various performance statistics (task durations, worker counts, etc.).
  • CLI Integration: Provides a cli() class method using pydantic-cli to easily create command-line interfaces for your environment (e.g., python your_env_module.py serve --port 8001 ...). See get_cli_serve_config_cls and get_cli_process_config_cls.

By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the Atropos framework.