Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
# Import presets first to register defaults
# This must come before we expose make and info to ensure presets are available
from ares import presets # noqa: F401
from ares.environments.base import CheckpointableEnvironment
from ares.environments.base import Environment
from ares.environments.base import EnvironmentCheckpoint
from ares.environments.base import TimeStep
from ares.registry import EnvironmentInfo

Expand All @@ -58,7 +60,9 @@

# Define public API
__all__ = [
"CheckpointableEnvironment",
"Environment",
"EnvironmentCheckpoint",
"EnvironmentInfo",
"TimeStep",
"info",
Expand Down
146 changes: 146 additions & 0 deletions src/ares/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
dm_env Environment protocol and utilities for ARES.
"""

import atexit
from collections.abc import Awaitable, Callable
import functools
import logging
import os
Expand Down Expand Up @@ -149,6 +151,150 @@ async def __aexit__(
await self.close()


class EnvironmentCheckpoint[EnvType]:
    """A restorable snapshot of environment state at a step boundary.

    Generic over the environment type so that restore() returns the concrete type.
    The checkpoint is self-contained: it holds closures for restore and release
    that capture whatever data and factories are needed. The environment that
    creates it decides how those closures work.

    restore() creates a brand new environment instance in the checkpointed state.
    The original environment is not affected. Multiple restores from the same
    checkpoint create independent environments.

    Usage::

        checkpoint = await env.checkpoint()
        env_restored, ts = await checkpoint.restore()
        # env_restored is a new, independent environment at the checkpointed state

    Resource management:
        Each checkpoint may hold expensive resources (e.g., a Docker image from
        ``docker commit``). Call release() when no longer needed, or rely on the
        global checkpoint janitor for atexit cleanup.

        NOTE: the janitor's exit-time cleanup only invokes the synchronous
        ``release_fn_sync`` callback. A checkpoint constructed with only an
        async ``release_fn`` is dropped from the registry at exit without its
        resources being released, so supply ``release_fn_sync`` whenever
        exit-time cleanup matters.
    """

def __init__(
    self,
    *,
    restore_fn: Callable[[], Awaitable[EnvType]],
    release_fn: Callable[[], Awaitable[None]] | None = None,
    release_fn_sync: Callable[[], None] | None = None,
    step_count: int,
    timestep: TimeStep,
):
    """Store the restore/release callbacks and the snapshot metadata.

    All arguments are keyword-only.

    Args:
        restore_fn: Awaited by restore() to build a new environment instance.
        release_fn: Optional async callback awaited by release().
        release_fn_sync: Optional synchronous callback; this is the only
            release callback the atexit janitor invokes.
        step_count: Stored unchanged as a public attribute. Presumably the
            number of steps taken when the snapshot was made -- confirm
            with the creating environment.
        timestep: Stored unchanged and returned by restore() alongside the
            new environment.
    """
    self._restore_fn = restore_fn
    self._release_fn = release_fn
    self._release_fn_sync = release_fn_sync
    self.step_count = step_count
    self.timestep = timestep
    # Set by release(); once True, restore() raises RuntimeError.
    self._released = False
Comment on lines +181 to +192

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The atexit janitor only calls checkpoint._release_fn_sync and never awaits the async release_fn, so async-only cleanup can be skipped on exit. Should we run the async release_fn in a temporary event loop, or provide a sync fallback?

Finding type: Logical Bugs | Severity: 🔴 High


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

In src/ares/environments/base.py around lines 181 to 226, the
_CheckpointJanitor._cleanup currently only calls checkpoint._release_fn_sync and
therefore skips async _release_fn, leaving resources unreleased at process exit. Modify
_cleanup so that for each checkpoint: if _release_fn_sync is present, call it as before;
otherwise if _release_fn is present, run it synchronously by creating a temporary
asyncio event loop and calling loop.run_until_complete(_release_fn()), catching and
logging exceptions. If creating/running a loop fails because an event loop is already
running, fall back to running the coroutine in a separate Thread with its own event loop
and join it, also catching and logging exceptions. Keep the existing warning logging and
ensure the _checkpoints dict is cleared afterward.


_CHECKPOINT_JANITOR.register(self)

async def restore(self) -> tuple[EnvType, TimeStep]:
    """Build a fresh environment positioned at this checkpoint.

    Every call awaits the stored restore callback, so each invocation
    yields its own independent environment; the environment that produced
    the checkpoint is left alone. Ownership of the returned environment
    passes to the caller, who is responsible for close()-ing it.

    Returns:
        ``(environment, timestep)`` -- the newly created environment and
        the timestep that was stored on the checkpoint.

    Raises:
        RuntimeError: If the checkpoint has already been released.
    """
    if not self._released:
        restored_env = await self._restore_fn()
        return restored_env, self.timestep
    raise RuntimeError("Cannot restore from a released checkpoint.")

async def release(self) -> None:
    """Release resources held by this checkpoint.

    For environments backed by containers, this deletes the snapshot image.
    For simple environments, this is a no-op. Safe to call multiple times.
    After a successful release(), restore() will raise RuntimeError.

    If the async release callback raises (or the awaiting task is
    cancelled), the checkpoint is put back into the unreleased state and
    the exception propagates, so the caller can retry release() later and
    no post-release bookkeeping runs for a release that did not happen.

    Raises:
        Whatever the release callback raises.
    """
    if self._released:
        return
    # Claim the release *before* awaiting so that a second release() call
    # arriving while the callback is in flight returns early instead of
    # running the callback twice.
    self._released = True
    try:
        if self._release_fn is not None:
            await self._release_fn()
    except BaseException:
        # The resources were not released: undo the claim so the checkpoint
        # is not stuck in a "released" state that blocks both retry and
        # restore(). BaseException also covers asyncio.CancelledError.
        self._released = False
        raise
Comment on lines +215 to +226

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

release() flips _released to True before awaiting _release_fn, so a failure leaves the checkpoint marked as released and prevents retries. Should we postpone marking _released and unregistering from the janitor until after _release_fn succeeds?

self._released = True; await self._release_fn() => await self._release_fn(); self._released = True

Finding type: Logical Bugs | Severity: 🔴 High


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

In src/ares/environments/base.py around lines 215-228, the release() method currently
sets self._released = True before awaiting self._release_fn, which causes the checkpoint
to be marked released and unregistered incorrectly if the async release raises. Change
release() so that it (1) does not set self._released until after await
self._release_fn() completes successfully, (2) only calls
_CHECKPOINT_JANITOR.unregister(self) after a successful release, and (3) keeps the
checkpoint registered and _released False if the release raises (propagate the exception
so callers can retry). Also preserve the no-op path: if _release_fn is None, still set
_released = True and unregister. Ensure no broad exception swallowing; let failures
surface so cleanup can be retried.

_CHECKPOINT_JANITOR.unregister(self)
Comment on lines +222 to +227

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Do not mark checkpoints released before async release succeeds.

On Line 224, _released is set before awaiting release_fn (Line 226). If release_fn raises, the checkpoint becomes permanently “released” even though cleanup failed, so release() cannot be retried and restore() is blocked.

Proposed fix
     async def release(self) -> None:
         """Release resources held by this checkpoint.
@@
         if self._released:
             return
-        self._released = True
-        if self._release_fn is not None:
-            await self._release_fn()
-        _CHECKPOINT_JANITOR.unregister(self)
+        if self._release_fn is not None:
+            await self._release_fn()
+        self._released = True
+        _CHECKPOINT_JANITOR.unregister(self)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self._released:
return
self._released = True
if self._release_fn is not None:
await self._release_fn()
_CHECKPOINT_JANITOR.unregister(self)
if self._released:
return
if self._release_fn is not None:
await self._release_fn()
self._released = True
_CHECKPOINT_JANITOR.unregister(self)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ares/environments/base.py` around lines 222 - 227, The method currently
sets self._released before awaiting the potentially-failing async cleanup, which
prevents retries if release fails; change the order so you await
self._release_fn() first (if _release_fn is not None), only set self._released =
True and call _CHECKPOINT_JANITOR.unregister(self) after the await completes
successfully, and let exceptions propagate (or re-raise) so the checkpoint
remains not released on failure; update the release() method accordingly to move
the _released assignment and unregister call to after the awaited _release_fn()
completes.


@property
def is_released(self) -> bool:
    """Whether this checkpoint's resources have been released.

    Read-only view of the internal ``_released`` flag, which release()
    sets. While it is True, restore() raises RuntimeError.
    """
    return self._released


class CheckpointableEnvironment[ActionType, ObservationType, RewardType: Scalar, DiscountType: Scalar](
    Environment[ActionType, ObservationType, RewardType, DiscountType],
    Protocol,
):
    """An Environment that supports checkpointing for exploration algorithms.

    Extends the base Environment protocol with a single method: checkpoint().
    Everything else (restore, release, cleanup) lives on the returned
    EnvironmentCheckpoint object.

    The implementation constructs the checkpoint with closures that know how
    to create a brand new environment at the checkpointed state. No internal
    protocol methods are imposed -- the implementation is free.

    This is a structural (Protocol) type: any class providing the Environment
    members plus a matching checkpoint() coroutine satisfies it.
    """

    async def checkpoint(self) -> EnvironmentCheckpoint[Self]:
        """Capture the current environment state.

        Must be called at a step boundary (after reset() or step() returned
        a non-LAST timestep, before the next step() call).

        Returns:
            An EnvironmentCheckpoint whose restore() creates a new instance
            of this concrete environment type.
        """
        # Protocol stub: concrete environments supply the implementation.
        ...


class _CheckpointJanitor:
"""Emergency cleanup for checkpoints that weren't explicitly released.

Mirrors the _Janitor pattern in code_env.py for container cleanup.
Registered checkpoints are cleaned up synchronously at interpreter exit.
"""

def __init__(self):
    """Create an empty registry and hook cleanup into interpreter exit."""
    # Exit hook first; it only reads the registry when the interpreter
    # shuts down, by which time the attribute below exists.
    atexit.register(self._cleanup)
    # Live checkpoints keyed by id(checkpoint).
    self._checkpoints: dict[int, EnvironmentCheckpoint] = {}

def register(self, checkpoint: EnvironmentCheckpoint) -> None:
    """Start tracking ``checkpoint`` so exit-time cleanup can reach it.

    The registry is keyed by ``id(checkpoint)``, so registering the same
    object again simply overwrites its existing entry.
    """
    key = id(checkpoint)
    self._checkpoints[key] = checkpoint

def unregister(self, checkpoint: EnvironmentCheckpoint) -> None:
    """Stop tracking ``checkpoint`` for exit-time cleanup.

    Removing a checkpoint that was never registered (or was already
    removed) is silently ignored.
    """
    key = id(checkpoint)
    self._checkpoints.pop(key, None)

def _cleanup(self) -> None:
    """Clean up all registered checkpoints at exit.

    Registered with atexit by __init__, so it runs synchronously during
    interpreter shutdown. Only each checkpoint's ``_release_fn_sync`` is
    invoked; the async ``_release_fn`` is never awaited here, so
    checkpoints that provide only an async release callback are skipped.
    """
    if self._checkpoints:
        _LOGGER.info("Cleaning up %d unreleased checkpoints at exit.", len(self._checkpoints))
    # Iterate over a snapshot of the registry rather than the live dict view.
    for cp in list(self._checkpoints.values()):
        # No sync callback means nothing can be released from this hook.
        if cp._release_fn_sync is not None:
            try:
                cp._release_fn_sync()
Comment on lines +284 to +289

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleanup logs lack checkpoint IDs even though CLAUDE.md requires object IDs for async lifecycle events. Should we include id(cp) in those log lines?

Finding type: AI Coding Guidelines | Severity: 🟢 Low


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

In src/ares/environments/base.py around lines 284 to 289, the
_CheckpointJanitor._cleanup method currently logs only the count and a generic warning
but does not include checkpoint identifiers. Update the info log at line ~285 to include
the IDs of the registered checkpoints (e.g., build a list of ids from self._checkpoints
values and log it alongside the count), and change the warning at line ~291 to include
the specific checkpoint id (id(cp)) when a release fails while keeping exc_info=True.
Keep formatting consistent with existing _LOGGER calls and avoid exposing sensitive data
beyond Python object ids.

except Exception:
_LOGGER.warning("Failed to clean up checkpoint.", exc_info=True)
self._checkpoints.clear()


_CHECKPOINT_JANITOR = _CheckpointJanitor()


async def create_container(
*,
container_factory: containers.ContainerFactory,
Expand Down
Loading
Loading