-
Notifications
You must be signed in to change notification settings - Fork 14
Add EnvironmentCheckpoint and CheckpointableEnvironment protocol #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,8 @@ | |||||||||||||||||||||||||
| dm_env Environment protocol and utilities for ARES. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import atexit | ||||||||||||||||||||||||||
| from collections.abc import Awaitable, Callable | ||||||||||||||||||||||||||
| import functools | ||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||
|
|
@@ -149,6 +151,150 @@ async def __aexit__( | |||||||||||||||||||||||||
| await self.close() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class EnvironmentCheckpoint[EnvType]: | ||||||||||||||||||||||||||
| """A restorable snapshot of environment state at a step boundary. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Generic over the environment type so that restore() returns the concrete type. | ||||||||||||||||||||||||||
| The checkpoint is self-contained: it holds closures for restore and release | ||||||||||||||||||||||||||
| that capture whatever data and factories are needed. The environment that | ||||||||||||||||||||||||||
| creates it decides how those closures work. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| restore() creates a brand new environment instance in the checkpointed state. | ||||||||||||||||||||||||||
| The original environment is not affected. Multiple restores from the same | ||||||||||||||||||||||||||
| checkpoint create independent environments. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Usage:: | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| checkpoint = await env.checkpoint() | ||||||||||||||||||||||||||
| env_restored, ts = await checkpoint.restore() | ||||||||||||||||||||||||||
| # env_restored is a new, independent environment at the checkpointed state | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Resource management: | ||||||||||||||||||||||||||
| Each checkpoint may hold expensive resources (e.g., a Docker image from | ||||||||||||||||||||||||||
| ``docker commit``). Call release() when no longer needed, or rely on the | ||||||||||||||||||||||||||
| global checkpoint janitor for atexit cleanup. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||
| restore_fn: Callable[[], Awaitable[EnvType]], | ||||||||||||||||||||||||||
| release_fn: Callable[[], Awaitable[None]] | None = None, | ||||||||||||||||||||||||||
| release_fn_sync: Callable[[], None] | None = None, | ||||||||||||||||||||||||||
| step_count: int, | ||||||||||||||||||||||||||
| timestep: TimeStep, | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| self._restore_fn = restore_fn | ||||||||||||||||||||||||||
| self._release_fn = release_fn | ||||||||||||||||||||||||||
| self._release_fn_sync = release_fn_sync | ||||||||||||||||||||||||||
| self.step_count = step_count | ||||||||||||||||||||||||||
| self.timestep = timestep | ||||||||||||||||||||||||||
| self._released = False | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| _CHECKPOINT_JANITOR.register(self) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def restore(self) -> tuple[EnvType, TimeStep]: | ||||||||||||||||||||||||||
| """Create a new environment at this checkpoint's state. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Constructs a brand new environment instance initialised to the | ||||||||||||||||||||||||||
| checkpointed state. The original environment is not touched. | ||||||||||||||||||||||||||
| The caller owns the returned environment and must close() it. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Can be called multiple times -- each call creates an independent | ||||||||||||||||||||||||||
| environment. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| A tuple of (environment, timestep) where the environment is a new | ||||||||||||||||||||||||||
| instance and the timestep is the observation at the checkpointed state. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| if self._released: | ||||||||||||||||||||||||||
| raise RuntimeError("Cannot restore from a released checkpoint.") | ||||||||||||||||||||||||||
| env = await self._restore_fn() | ||||||||||||||||||||||||||
| return env, self.timestep | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def release(self) -> None: | ||||||||||||||||||||||||||
| """Release resources held by this checkpoint. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| For environments backed by containers, this deletes the snapshot image. | ||||||||||||||||||||||||||
| For simple environments, this is a no-op. Safe to call multiple times. | ||||||||||||||||||||||||||
| After release(), restore() will raise RuntimeError. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| if self._released: | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
| self._released = True | ||||||||||||||||||||||||||
| if self._release_fn is not None: | ||||||||||||||||||||||||||
| await self._release_fn() | ||||||||||||||||||||||||||
|
Comment on lines
+215
to
+226
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. release() flips _released to True before awaiting _release_fn so failures leave the checkpoint marked and prevent retries, should we postpone marking _released and unregistering the janitor until after _release_fn succeeds? self._released = True; await self._release_fn() => await self._release_fn(); self._released = True Finding type: Want Baz to fix this for you? Activate Fixer Other fix methodsPrompt for AI Agents: |
||||||||||||||||||||||||||
| _CHECKPOINT_JANITOR.unregister(self) | ||||||||||||||||||||||||||
|
Comment on lines
+222
to
+227
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not mark checkpoints released before async release succeeds. On Line 224, Proposed fix async def release(self) -> None:
"""Release resources held by this checkpoint.
@@
if self._released:
return
- self._released = True
- if self._release_fn is not None:
- await self._release_fn()
- _CHECKPOINT_JANITOR.unregister(self)
+ if self._release_fn is not None:
+ await self._release_fn()
+ self._released = True
+ _CHECKPOINT_JANITOR.unregister(self)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def is_released(self) -> bool: | ||||||||||||||||||||||||||
| """Whether this checkpoint's resources have been released.""" | ||||||||||||||||||||||||||
| return self._released | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class CheckpointableEnvironment[ActionType, ObservationType, RewardType: Scalar, DiscountType: Scalar]( | ||||||||||||||||||||||||||
| Environment[ActionType, ObservationType, RewardType, DiscountType], | ||||||||||||||||||||||||||
| Protocol, | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| """An Environment that supports checkpointing for exploration algorithms. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Extends the base Environment protocol with a single method: checkpoint(). | ||||||||||||||||||||||||||
| Everything else (restore, release, cleanup) lives on the returned | ||||||||||||||||||||||||||
| EnvironmentCheckpoint object. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| The implementation constructs the checkpoint with closures that know how | ||||||||||||||||||||||||||
| to create a brand new environment at the checkpointed state. No internal | ||||||||||||||||||||||||||
| protocol methods are imposed -- the implementation is free. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def checkpoint(self) -> EnvironmentCheckpoint[Self]: | ||||||||||||||||||||||||||
| """Capture the current environment state. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Must be called at a step boundary (after reset() or step() returned | ||||||||||||||||||||||||||
| a non-LAST timestep, before the next step() call). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| An EnvironmentCheckpoint whose restore() creates a new instance | ||||||||||||||||||||||||||
| of this concrete environment type. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| ... | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class _CheckpointJanitor: | ||||||||||||||||||||||||||
| """Emergency cleanup for checkpoints that weren't explicitly released. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Mirrors the _Janitor pattern in code_env.py for container cleanup. | ||||||||||||||||||||||||||
| Registered checkpoints are cleaned up synchronously at interpreter exit. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||
| self._checkpoints: dict[int, EnvironmentCheckpoint] = {} | ||||||||||||||||||||||||||
| atexit.register(self._cleanup) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def register(self, checkpoint: EnvironmentCheckpoint) -> None: | ||||||||||||||||||||||||||
| """Register a checkpoint for emergency cleanup.""" | ||||||||||||||||||||||||||
| self._checkpoints[id(checkpoint)] = checkpoint | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def unregister(self, checkpoint: EnvironmentCheckpoint) -> None: | ||||||||||||||||||||||||||
| """Unregister a checkpoint from emergency cleanup.""" | ||||||||||||||||||||||||||
| self._checkpoints.pop(id(checkpoint), None) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _cleanup(self) -> None: | ||||||||||||||||||||||||||
| """Clean up all registered checkpoints at exit.""" | ||||||||||||||||||||||||||
| if self._checkpoints: | ||||||||||||||||||||||||||
| _LOGGER.info("Cleaning up %d unreleased checkpoints at exit.", len(self._checkpoints)) | ||||||||||||||||||||||||||
| for cp in list(self._checkpoints.values()): | ||||||||||||||||||||||||||
| if cp._release_fn_sync is not None: | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| cp._release_fn_sync() | ||||||||||||||||||||||||||
|
Comment on lines
+284
to
+289
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cleanup logs lack checkpoint IDs despite CLAUDE.md requiring object IDs for async lifecycle events — should we include id(cp) in those log lines? Finding type: Want Baz to fix this for you? Activate Fixer Other fix methodsPrompt for AI Agents: |
||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||
| _LOGGER.warning("Failed to clean up checkpoint.", exc_info=True) | ||||||||||||||||||||||||||
| self._checkpoints.clear() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| _CHECKPOINT_JANITOR = _CheckpointJanitor() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def create_container( | ||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||
| container_factory: containers.ContainerFactory, | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The atexit janitor only calls checkpoint._release_fn_sync and never awaits the async release_fn, so async-only cleanup can be skipped on exit — should we run the async release_fn in a temporary event loop or provide a sync fallback?
Finding type:
Logical Bugs| Severity: 🔴 HighWant Baz to fix this for you? Activate Fixer
Other fix methods
Prompt for AI Agents: