The BaseEnv class (located in atroposlib/envs/base.py) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass BaseEnv and implement several key methods.
Every environment in Atropos is a microservice that generates rollout data asynchronously from whatever trainer you attach to it. Environments (possibly many at once) send data to the Atropos API server, which sequesters rollout data. Your trainer of choice grabs batches of data from the API and backpropagates.
Unlike other popular alternatives like Gymnasium, which model environments as MDPs, we think of environments as dataloaders and make no assumptions about how a trajectory is produced. For multi-agent settings, for example, this means our design is agnostic to AEC vs. POSG; both are supported out of the box. To achieve this generality, our environment abstraction deviates from other open-source alternatives in several key ways.
- **Inference-scoring fusion**: A popular design choice in open-source LLM RL trainers is to separate inference and scoring into independent abstractions. While this makes a lot of sense for single-turn environments like one-shot MCQA, we found that it led to awkwardness in multi-turn setups with process rewards. As such, we assume the existence of a single method `collect_trajectories` which is responsible for both inference and scoring. Users are still welcome to call separate inference and scoring methods from within `collect_trajectories`.
- **Groups as atomic units of data**: A natural choice for a data atom in RL is a single trajectory. However, many popular RL methods for fine-tuning LLMs, such as DPO and GRPO, involve packing contrastive data into the same batch. As such, the most fundamental dataloading method in our abstraction is not `collect_trajectory` (singular) but `collect_trajectories` (plural). We do not enforce any definition of what a "group" is other than a set of rollouts. Although a group is most commonly constructed by generating multiple rollouts starting from the same initial state (as in DPO and GRPO), a user could just as easily pack `n` similar-sounding problems with very different solutions into a group. For cases like PPO where advantages don't depend on group statistics, a user can simply use a group size of 1.
- **Environments return tokens (not messages!)**: One of the most peculiar design choices we made is that, at least for text-only environments, environments are responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For multimodal cases where an OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction `ScoredDataGroup`.
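To make the tokens-in, scores-out contract concrete, here is a minimal sketch of packing two rollouts into a group-shaped payload. The key names (`tokens`, `masks`, `scores`) and the `-100` ignore convention are illustrative assumptions, not the exact `ScoredDataGroup` schema; check `atroposlib` for the real fields.

```python
# Illustrative only: the key names and the -100 "ignore" convention below
# are assumptions for demonstration, not the exact ScoredDataGroup schema.
from typing import List, Tuple


def build_group(
    rollouts: List[Tuple[List[int], List[int]]], scores: List[float]
) -> dict:
    """Pack tokenized (prompt, completion) rollouts into a group-shaped dict.

    The mask marks trainable positions: prompt tokens are masked out with
    -100, completion tokens are kept so they can receive token-level loss.
    """
    tokens, masks = [], []
    for prompt, completion in rollouts:
        tokens.append(prompt + completion)
        masks.append([-100] * len(prompt) + completion)
    return {"tokens": tokens, "masks": masks, "scores": scores}


# Two rollouts sharing the same 2-token prompt, with contrastive scores.
group = build_group(
    rollouts=[([1, 2], [3, 4, 5]), ([1, 2], [6, 7])],
    scores=[1.0, 0.0],
)
```

Because the group is just tokens plus per-rollout scores, the same shape can carry completions-style and chat-style data side by side.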
These methods must be implemented in your subclass:
- `async def setup(self)`: This method is called once at the beginning of the environment's lifecycle (`env_manager`). Use it for any initial setup your specific environment requires, such as loading datasets, initializing models, or connecting to external resources.
- `async def get_next_item(self) -> Item`: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause.
- `async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or if you need some coupling between rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.
- `async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define `collect_trajectory` and use the default implementation of `collect_trajectories`, which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a single trajectory collection step based on the input `item`.
  - Return value: a tuple containing:
    - The `ScoredDataItem` for this step (one trajectory). This data can be processed further in `postprocess_histories` if you require additional filtering right before sending to the API.
    - A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).
  - Should I define `collect_trajectory` or override `collect_trajectories`? If you have a way to generate your group more efficiently than a set of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent (as in MCTS), you should override `collect_trajectories`. If simplicity and iteration speed are more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent, then `collect_trajectory` is for you.
- `async def evaluate(self, *args, **kwargs)`: This method is called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment.
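Since `collect_trajectory` owns both inference and scoring, a common pattern is to keep the scoring half as a small pure function that the method calls after inference. The helper below is a hypothetical exact-match scorer; none of these names are part of the `BaseEnv` API.

```python
# Hypothetical scoring helper for use inside collect_trajectory; nothing
# here is part of the BaseEnv API.


def extract_final_answer(completion: str) -> str:
    """Treat the last non-empty line of a completion as its answer."""
    lines = [ln.strip() for ln in completion.splitlines() if ln.strip()]
    return lines[-1] if lines else ""


def score_completion(completion: str, gold: str) -> float:
    """Binary exact-match reward on the final answer line."""
    return 1.0 if extract_final_answer(completion) == gold else 0.0


# Inside collect_trajectory you might then write (illustrative only):
#   completion_text = ...  # text from your inference call
#   reward = score_completion(completion_text, item["answer"])
```

Keeping scoring pure like this makes it trivial to unit-test the reward logic without standing up an inference server.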
These methods have default implementations or are optional based on your needs:
- `async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`: The default implementation of this method runs `collect_trajectory` multiple times in parallel (controlled by `group_size`). You can override this instead of `collect_trajectory` if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item`. It should return the collected group data and a list of backlog items.
- `async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`: This method is called after `collect_trajectories` and before the data is sent to the training server. It receives the collected data from the parallel runs (or from your custom `collect_trajectories` implementation). Use this to perform any final processing, scoring, or formatting you may require before sending to the server. You usually won't need this.
- `async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`: Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, ensure you call `super().wandb_log(wandb_metrics)` at the end of your implementation. This ensures that the base class's performance metrics and rollout tables are also logged.

  ```python
  async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
      if wandb_metrics is None:
          wandb_metrics = {}
      # Add your custom metrics
      wandb_metrics["my_custom_metric"] = calculate_my_metric()
      # ... add more metrics
      # Call the parent method to log base metrics
      await super().wandb_log(wandb_metrics)
  ```

- `save_checkpoint(self, step, data=None)`: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize what data is saved or how it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.
- `@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (a `BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[APIServerConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `APIServerConfig` instances) when run via the CLI `serve` command.
- `async def cleanup(self)`: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.
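As an example of the kind of filtering `postprocess_histories` is suited for, a GRPO-style setup often drops groups whose rollouts all received the same score, since zero score variance yields zero advantage. A minimal sketch, assuming each group is a dict with a `scores` list (the real `ScoredDataGroup` fields may differ):

```python
# Hypothetical filter for postprocess_histories; assumes each group is a
# dict with a "scores" list, which may differ from the real ScoredDataGroup.
from typing import List, Optional


def drop_uniform_groups(groups: List[Optional[dict]]) -> List[dict]:
    """Drop None entries and groups whose rollouts all share one score."""
    kept = []
    for group in groups:
        if group is None:
            continue
        # A group with zero score variance carries no contrastive signal
        # for group-relative advantage estimation.
        if len(set(group["scores"])) > 1:
            kept.append(group)
    return kept
```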
These class-level variables in BaseEnv can be overridden in your subclass to customize its behavior:
- `name: Optional[str]`:
  - Default: `None`
  - Purpose: You can set a string name for your environment. This name is used by default for `wandb_name` in the `BaseEnvConfig` if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.
- `env_config_cls: Type[BaseEnvConfig]`:
  - Default: `BaseEnvConfig`
  - Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what `BaseEnvConfig` offers, you should create a new class that inherits from `BaseEnvConfig` (or a subclass thereof) and assign it to `env_config_cls`. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.

    ```python
    from pydantic import Field

    from atroposlib.envs import BaseEnv, BaseEnvConfig


    class MyEnvConfig(BaseEnvConfig):
        my_custom_param: str = Field(
            default="default_value", description="A custom parameter for MyEnv"
        )


    class MyEnv(BaseEnv):
        env_config_cls = MyEnvConfig
        name = "MyCustomEnvironment"
        # ... other implementations
    ```

- `server_cls: Type[APIServer]`:
  - Default: `APIServer`
  - Purpose: Specifies the class to be used for managing interactions with API servers (e.g., inference endpoints). This should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API you can use it to easily slot in any modifications you need.
BaseEnv provides several helpful features:
- **Parallel Trajectory Collection** (`collect_trajectories`): The base implementation runs your `collect_trajectory` method multiple times in parallel (based on `group_size`) and gathers the results. You can override `collect_trajectories` directly for custom group generation logic (see Optional Methods).
- **Server Interaction**: Handles registration with the rollout server, fetching configuration (like `batch_size`), sending scored data (`handle_send_to_api` with retries), and status updates.
- **WandB Integration**: Sets up WandB logging (if enabled) based on server information and provides the `wandb_log` hook for custom metrics (remember to call `super().wandb_log()`). It uses the helper methods `add_rollouts_for_wandb` (to temporarily store rollout data) and `create_rollout_table` (to format the data into a `wandb.Table`). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured).
- **Checkpointing**:
  - The environment automatically triggers checkpoint saves based on the `checkpoint_interval` received from the server, calling the `save_checkpoint` method (see Optional Methods).
  - `load_checkpoint(self)`: Loads data from the checkpoint file corresponding to the environment's `curr_step`. It attempts to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if `curr_step > 0` during registration.
- **Worker Management**: Manages asynchronous worker tasks for collecting trajectories (`add_train_workers`, `handle_env`).
- **Performance Monitoring**: Tracks and logs various performance statistics (task durations, worker counts, etc.).
- **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.
By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the Atropos framework.
