
Commit cb9cab7

proposal(environments): SupportsPersistence Protocol for multi-turn sessions (#40)
* feat(local-repl): add multi-context support for persistent REPL

* feat(core): add persistent flag for multi-turn REPL conversations

  Add a persistent=True option to RLM that reuses the environment across
  completion() calls instead of creating/destroying it for each call. This
  enables multi-turn conversations where variables and contexts persist.

  - Add persistent parameter to the RLM constructor
  - Reuse the environment when persistent=True, stored as _persistent_env
  - Add close() method and context-manager support (__enter__/__exit__)
  - Clean up the environment only on explicit close() when persistent

* feat(prompts): inform model about multiple contexts in multi-turn sessions

  Add a context_count parameter to build_user_prompt() so the model knows
  when multiple contexts are available (context_0, context_1, etc.) during
  persistent REPL sessions.

* fix(core): validate environment supports persistent mode before use

  Add validation to prevent AttributeError when persistent=True is used with
  environments that don't implement the required methods
  (update_handler_address, add_context, get_context_count).

  - Add _validate_persistent_environment_support(), called at init time
  - Add _env_supports_persistence() for runtime capability checking
  - Add a defensive runtime check before calling persistence methods
  - Raise a clear ValueError if an unsupported environment is configured

* feat(local-repl): add message history storage for multi-turn sessions

  Store conversation histories as versioned variables (history_0, history_1,
  etc.) in the REPL environment, making them accessible to subsequent queries.
  This enables models to reference prior conversation context in persistent
  sessions.

* feat(environments): add persistent parameter to all REPL classes

  Add a persistent parameter to BaseEnv, IsolatedEnv, and NonIsolatedEnv.
  DockerREPL, ModalREPL, and PrimeREPL raise NotImplementedError when
  persistent=True since they don't yet support it. LocalREPL passes the
  parameter through to the base class. Also makes DockerREPL cleanup more
  defensive with hasattr checks.

* test(local-repl): replace multi-context tests with non-persistent simulation tests

  Replace TestLocalREPLMultiContext and TestLocalREPLHistory with
  TestLocalREPLSimulatingRLMNoPersistence. The new tests verify that
  environments reset between RLM completions when persistent=False.

* test(local-repl): add persistent mode unit tests

  Add TestLocalREPLMultiContext, TestLocalREPLHistory, and
  TestLocalREPLPersistentState test classes covering LocalREPL's persistent
  mode features, including multi-context versioning and message history
  storage.

* test(rlm): add multi-turn integration tests

  Add comprehensive integration tests for persistent RLM sessions:
  - Environment reuse across completion calls
  - Context and history accumulation
  - Variable persistence between completions
  - Prompt awareness of contexts/histories
  - Resource cleanup on close()
  - Validation of unsupported environments

* refactor(tests): replace ScriptedMockLM with standard unittest.mock

  Remove the custom ScriptedMockLM class (58 lines) and replace it with a
  create_mock_lm() helper (14 lines) using standard Mock patterns.

* refactor(environments): add SupportsPersistence Protocol

  Replace scattered hasattr checks with a runtime-checkable Protocol that
  defines the persistence capability contract. This provides type-checker
  enforcement and IDE autocomplete support.

* docs(environments): add comprehensive SupportsPersistence Protocol docs

  Document the expected behavior when implementing persistence:
  - Versioning behavior (context_0, context_1, ...)
  - Aliasing behavior (context -> context_0)
  - Method contracts with detailed docstrings
  - References to tests and an example implementation

* add ruff lint on tests

* fix linting on tests

---------

Co-authored-by: Alex Zhang <alex.lx.zhang@gmail.com>
1 parent a30805b commit cb9cab7
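The lifecycle described in the commit message (reuse the environment across completion() calls, free it only on explicit close(), support the with-statement) can be sketched in isolation. This is a minimal illustration of the pattern, not the actual RLM class; `Env` and `PersistentSession` are hypothetical stand-ins whose method names mirror the diff.

```python
# Hypothetical stand-in for a persistent environment.
class Env:
    def __init__(self):
        self.cleaned_up = False

    def cleanup(self):
        self.cleaned_up = True


class PersistentSession:
    """Reuses one environment across calls; frees it only on close()."""

    def __init__(self):
        self._persistent_env = None

    def completion(self):
        # Reuse the environment if it already exists, otherwise create it.
        if self._persistent_env is None:
            self._persistent_env = Env()
        return self._persistent_env

    def close(self):
        # Mirrors RLM.close(): cleanup happens here, not per call.
        if self._persistent_env is not None:
            if hasattr(self._persistent_env, "cleanup"):
                self._persistent_env.cleanup()
            self._persistent_env = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # do not swallow exceptions


with PersistentSession() as session:
    env_a = session.completion()
    env_b = session.completion()
    assert env_a is env_b  # same environment reused across calls
assert session._persistent_env is None  # released on exiting the with-block
```

Returning False from `__exit__` matters: exceptions raised inside the with-block still propagate after cleanup runs.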

File tree: 11 files changed (+1017 −48 lines)


rlm/core/rlm.py

Lines changed: 101 additions & 11 deletions

@@ -13,7 +13,7 @@
     RLMIteration,
     RLMMetadata,
 )
-from rlm.environments import BaseEnv, get_environment
+from rlm.environments import BaseEnv, SupportsPersistence, get_environment
 from rlm.logger import RLMLogger, VerbosePrinter
 from rlm.utils.parsing import (
     find_code_blocks,
@@ -51,6 +51,7 @@ def __init__(
         other_backend_kwargs: list[dict[str, Any]] | None = None,
         logger: RLMLogger | None = None,
         verbose: bool = False,
+        persistent: bool = False,
     ):
         """
         Args:
@@ -66,6 +67,7 @@ def __init__(
             other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
             logger: The logger to use for the RLM.
             verbose: Whether to print verbose output in rich to console.
+            persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
         """
         # Store config for spawning per-completion
         self.backend = backend
@@ -84,6 +86,14 @@ def __init__(
         self.logger = logger
         self.verbose = VerbosePrinter(enabled=verbose)
 
+        # Persistence support
+        self.persistent = persistent
+        self._persistent_env: SupportsPersistence | None = None
+
+        # Validate persistence support at initialization
+        if self.persistent:
+            self._validate_persistent_environment_support()
+
         # Log metadata if logger is provided
         if self.logger or verbose:
             metadata = RLMMetadata(
@@ -108,7 +118,9 @@
     def _spawn_completion_context(self, prompt: str | dict[str, Any]):
         """
         Spawn an LM handler and environment for a single completion call.
-        Cleans up both when the context exits.
+
+        When persistent=True, the environment is reused across calls.
+        When persistent=False (default), creates fresh environment each call.
         """
         # Create client and wrap in handler
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
@@ -122,20 +134,32 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):
 
         lm_handler.start()
 
-        # Pass handler address to environment so it can make llm_query() calls
-        env_kwargs = self.environment_kwargs.copy()
-        env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
-        env_kwargs["context_payload"] = prompt
+        # Environment: reuse if persistent, otherwise create fresh
+        if self.persistent and self._persistent_env is not None:
+            environment = self._persistent_env
+            # Defensive check: ensure environment supports persistence methods
+            if not self._env_supports_persistence(environment):
+                raise RuntimeError(
+                    f"Persistent environment of type '{type(environment).__name__}' does not "
+                    f"implement required methods (update_handler_address, add_context, get_context_count). "
+                    f"This should have been caught at initialization."
+                )
+            environment.update_handler_address((lm_handler.host, lm_handler.port))
+            environment.add_context(prompt)
+        else:
+            env_kwargs = self.environment_kwargs.copy()
+            env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
+            env_kwargs["context_payload"] = prompt
+            environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
 
-        # Initialize the environment
-        environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
+            if self.persistent:
+                self._persistent_env = environment
 
         try:
             yield lm_handler, environment
         finally:
-            # Cleanup
             lm_handler.stop()
-            if hasattr(environment, "cleanup"):
+            if not self.persistent and hasattr(environment, "cleanup"):
                 environment.cleanup()
 
     def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
@@ -177,7 +201,19 @@ def completion(
 
         for i in range(self.max_iterations):
             # Current prompt = message history + additional prompt suffix
-            current_prompt = message_history + [build_user_prompt(root_prompt, i)]
+            context_count = (
+                environment.get_context_count()
+                if isinstance(environment, SupportsPersistence)
+                else 1
+            )
+            history_count = (
+                environment.get_history_count()
+                if isinstance(environment, SupportsPersistence)
+                else 0
+            )
+            current_prompt = message_history + [
+                build_user_prompt(root_prompt, i, context_count, history_count)
+            ]
 
             iteration: RLMIteration = self._completion_turn(
                 prompt=current_prompt,
@@ -201,6 +237,11 @@ def completion(
                 usage = lm_handler.get_usage_summary()
                 self.verbose.print_final_answer(final_answer)
                 self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())
+
+                # Store message history in persistent environment
+                if self.persistent and isinstance(environment, SupportsPersistence):
+                    environment.add_history(message_history)
+
                 return RLMChatCompletion(
                     root_model=self.backend_kwargs.get("model_name", "unknown")
                     if self.backend_kwargs
@@ -223,6 +264,11 @@ def completion(
         usage = lm_handler.get_usage_summary()
         self.verbose.print_final_answer(final_answer)
         self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
+
+        # Store message history in persistent environment
+        if self.persistent and isinstance(environment, SupportsPersistence):
+            environment.add_history(message_history)
+
         return RLMChatCompletion(
             root_model=self.backend_kwargs.get("model_name", "unknown")
             if self.backend_kwargs
@@ -292,3 +338,47 @@ def _fallback_answer(self, message: str | dict[str, Any]) -> str:
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
         response = client.completion(message)
         return response
+
+    def _validate_persistent_environment_support(self) -> None:
+        """
+        Validate that the configured environment type supports persistent mode.
+
+        Persistent mode requires environments to implement:
+        - update_handler_address(address): Update LM handler address between calls
+        - add_context(payload, index): Add new context for multi-turn conversations
+        - get_context_count(): Return the number of loaded contexts
+
+        Currently only 'local' (LocalREPL) supports these methods.
+
+        Raises:
+            ValueError: If the environment type does not support persistent mode.
+        """
+        # Known environments that support persistence
+        persistent_supported_environments = {"local"}
+
+        if self.environment_type not in persistent_supported_environments:
+            raise ValueError(
+                f"persistent=True is not supported for environment type '{self.environment_type}'. "
+                f"Persistent mode requires environments that implement update_handler_address(), "
+                f"add_context(), and get_context_count(). "
+                f"Supported environments: {sorted(persistent_supported_environments)}"
+            )
+
+    @staticmethod
+    def _env_supports_persistence(env: BaseEnv) -> bool:
+        """Check if an environment instance supports persistent mode methods."""
+        return isinstance(env, SupportsPersistence)
+
+    def close(self) -> None:
+        """Clean up persistent environment. Call when done with multi-turn conversations."""
+        if self._persistent_env is not None:
+            if hasattr(self._persistent_env, "cleanup"):
+                self._persistent_env.cleanup()
+            self._persistent_env = None
+
+    def __enter__(self) -> "RLM":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+        self.close()
+        return False
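The `_env_supports_persistence` check above relies on `typing.Protocol` with `@runtime_checkable`, which turns `isinstance()` into a structural test: any object with the required method names passes, no inheritance needed. The sketch below illustrates that mechanism with a trimmed-down protocol and two hypothetical environments (`LocalLikeEnv`, `DockerLikeEnv` are stand-ins, not classes from this repo).

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsPersistenceDemo(Protocol):
    # Trimmed to three of the contract's methods for illustration.
    def update_handler_address(self, address: tuple[str, int]) -> None: ...
    def add_context(self, payload, context_index=None) -> int: ...
    def get_context_count(self) -> int: ...


class LocalLikeEnv:
    """Implements the methods, so the structural isinstance() check passes."""

    def __init__(self):
        self._contexts = []

    def update_handler_address(self, address):
        self._address = address

    def add_context(self, payload, context_index=None):
        self._contexts.append(payload)
        return len(self._contexts) - 1

    def get_context_count(self):
        return len(self._contexts)


class DockerLikeEnv:
    pass  # implements none of the persistence methods


# No subclassing of the protocol anywhere, yet isinstance() works:
assert isinstance(LocalLikeEnv(), SupportsPersistenceDemo)
assert not isinstance(DockerLikeEnv(), SupportsPersistenceDemo)
```

One caveat worth noting: a runtime-checkable protocol only verifies that the method *names* exist, not their signatures or behavior, which is presumably why the commit keeps the init-time `ValueError` against a known-good allowlist rather than relying on `isinstance()` alone.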

rlm/environments/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,8 +1,10 @@
 from typing import Any, Literal
 
-from rlm.environments.base_env import BaseEnv
+from rlm.environments.base_env import BaseEnv, SupportsPersistence
 from rlm.environments.local_repl import LocalREPL
 
+__all__ = ["BaseEnv", "LocalREPL", "SupportsPersistence", "get_environment"]
+
 
 def get_environment(
     environment: Literal["local", "modal", "docker"],

rlm/environments/base_env.py

Lines changed: 116 additions & 5 deletions

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Protocol, runtime_checkable
 
 from rlm.core.types import REPLResult
 
@@ -9,7 +10,8 @@ class BaseEnv(ABC):
     where isolated environments are on a separate machine from the LM.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, persistent: bool = False, **kwargs):
+        self.persistent = persistent
         self.kwargs = kwargs
 
     @abstractmethod
@@ -31,8 +33,8 @@ class IsolatedEnv(BaseEnv, ABC):
     guaranteeing complete isolation from the LM process.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
     def setup(self):
@@ -54,8 +56,8 @@ class NonIsolatedEnv(BaseEnv, ABC):
     as a subprocess.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
     def setup(self):
@@ -68,3 +70,112 @@ def load_context(self, context_payload: dict | list | str):
 
     @abstractmethod
     def execute_code(self, code: str) -> REPLResult:
         raise NotImplementedError
+
+
+@runtime_checkable
+class SupportsPersistence(Protocol):
+    """Protocol for environments that support persistent multi-turn sessions.
+
+    CHECKING SUPPORT:
+        Use isinstance(env, SupportsPersistence) to check if an environment
+        supports persistence capabilities.
+
+    IMPLEMENTING THIS PROTOCOL:
+        To add persistence to your environment, implement these 5 methods.
+        See tests/test_local_repl_persistent.py for expected behavior.
+
+    VERSIONING BEHAVIOR:
+        Contexts and histories are versioned with numeric suffixes:
+        - First context -> context_0, context_1, context_2, ...
+        - First history -> history_0, history_1, history_2, ...
+
+    ALIASING BEHAVIOR:
+        The unversioned names always point to index 0:
+        - context -> context_0 (first context)
+        - history -> history_0 (first history)
+
+    EXAMPLE IMPLEMENTATION:
+        See rlm/environments/local_repl.py for a complete reference.
+
+    TESTS:
+        - Unit tests: tests/test_local_repl_persistent.py
+        - Integration tests: tests/test_multi_turn_integration.py
+
+        Run: uv run pytest tests/test_local_repl_persistent.py -v
+    """
+
+    def update_handler_address(self, address: tuple[str, int]) -> None:
+        """Update the LM handler address for nested LLM calls.
+
+        Called by RLM when the handler address changes between completions.
+        Store the address so llm_query() calls from executed code can reach
+        the LM handler.
+
+        Args:
+            address: (host, port) tuple for the LM handler server.
+        """
+        ...
+
+    def add_context(
+        self, context_payload: dict | list | str, context_index: int | None = None
+    ) -> int:
+        """Add a context payload, making it available as context_N in code.
+
+        Versioning:
+            - context_index=None: auto-increment (0, 1, 2, ...)
+            - context_index=N: use specific index N
+
+        Storage:
+            Must store so executed code can access:
+            - context_0, context_1, etc. (versioned)
+            - context (alias to context_0)
+
+        Args:
+            context_payload: The context data (string, dict, or list).
+            context_index: Optional specific index, or None to auto-increment.
+
+        Returns:
+            The index used (for auto-increment, returns the assigned index).
+        """
+        ...
+
+    def get_context_count(self) -> int:
+        """Return the number of contexts added so far.
+
+        Used by RLM to inform the model how many contexts are available.
+        """
+        ...
+
+    def add_history(
+        self, message_history: list[dict[str, Any]], history_index: int | None = None
+    ) -> int:
+        """Add a message history, making it available as history_N in code.
+
+        Versioning:
+            - history_index=None: auto-increment (0, 1, 2, ...)
+            - history_index=N: use specific index N
+
+        Storage:
+            Must store so executed code can access:
+            - history_0, history_1, etc. (versioned)
+            - history (alias to history_0)
+
+        IMPORTANT: Store a deep copy, not a reference. The caller may
+        modify the list after calling this method.
+
+        Args:
+            message_history: List of message dicts (role, content).
+            history_index: Optional specific index, or None to auto-increment.
+
+        Returns:
+            The index used.
+        """
+        ...
+
+    def get_history_count(self) -> int:
+        """Return the number of histories added so far.
+
+        Used by RLM to inform the model how many conversation histories
+        are available.
+        """
+        ...
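The versioning, aliasing, and deep-copy rules in the protocol docstrings can be satisfied by a very small in-memory implementation. The sketch below is an illustrative stand-in (not the LocalREPL reference implementation); `InMemoryPersistentEnv` and its `globals` dict are hypothetical names for "what executed code would see".

```python
import copy


class InMemoryPersistentEnv:
    """Illustrative implementation of the SupportsPersistence contract:
    context_N / history_N versioning with unversioned aliases to index 0."""

    def __init__(self):
        self.globals = {}  # stand-in for the namespace executed code sees
        self._context_count = 0
        self._history_count = 0
        self._lm_handler_address = None

    def update_handler_address(self, address):
        # Stored so llm_query() calls from executed code could reach the handler.
        self._lm_handler_address = address

    def add_context(self, context_payload, context_index=None):
        idx = self._context_count if context_index is None else context_index
        self.globals[f"context_{idx}"] = context_payload
        self._context_count = max(self._context_count, idx + 1)
        # Unversioned alias always points at index 0.
        if "context_0" in self.globals:
            self.globals["context"] = self.globals["context_0"]
        return idx

    def get_context_count(self):
        return self._context_count

    def add_history(self, message_history, history_index=None):
        idx = self._history_count if history_index is None else history_index
        # Deep copy: the caller may mutate the list after this call.
        self.globals[f"history_{idx}"] = copy.deepcopy(message_history)
        self._history_count = max(self._history_count, idx + 1)
        if "history_0" in self.globals:
            self.globals["history"] = self.globals["history_0"]
        return idx

    def get_history_count(self):
        return self._history_count


env = InMemoryPersistentEnv()
assert env.add_context("first") == 0
assert env.add_context("second") == 1
assert env.globals["context"] == "first"  # alias -> context_0

msgs = [{"role": "user", "content": "hi"}]
env.add_history(msgs)
msgs.append({"role": "assistant", "content": "hello"})
assert len(env.globals["history_0"]) == 1  # deep copy: later mutation ignored
```

The deep copy in `add_history` is the subtle part: RLM keeps appending to `message_history` during a completion, so storing a reference would let stored histories change retroactively.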

rlm/environments/docker_repl.py

Lines changed: 9 additions & 4 deletions

@@ -180,9 +180,14 @@ def __init__(
         lm_handler_address: tuple[str, int] | None = None,
         context_payload: dict | list | str | None = None,
         setup_code: str | None = None,
+        persistent: bool = False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        if persistent:
+            raise NotImplementedError(
+                "Persistent REPLs are currently not supported for environment: DockerREPL"
+            )
+        super().__init__(persistent=persistent, **kwargs)
 
         self.image = image
         self.lm_handler_address = lm_handler_address
@@ -292,13 +297,13 @@ def execute_code(self, code: str) -> REPLResult:
         )
 
     def cleanup(self):
-        if self.container_id:
+        if hasattr(self, "container_id") and self.container_id:
             subprocess.run(["docker", "stop", self.container_id], capture_output=True)
             self.container_id = None
-        if self.proxy_server:
+        if hasattr(self, "proxy_server") and self.proxy_server:
             self.proxy_server.shutdown()
             self.proxy_server = None
-        if os.path.exists(self.temp_dir):
+        if hasattr(self, "temp_dir") and os.path.exists(self.temp_dir):
             import shutil
 
             shutil.rmtree(self.temp_dir, ignore_errors=True)
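The `hasattr` guards added to `cleanup()` above protect against a specific failure mode: `__init__` now raises `NotImplementedError` before the container/proxy attributes are assigned, so `cleanup()` can be invoked on a partially initialized object. The sketch below demonstrates the hazard with hypothetical stand-in classes (`FragileEnv`, `SafeEnv` are not from this repo).

```python
class FragileEnv:
    """Cleanup assumes __init__ completed and assigned all attributes."""

    def __init__(self):
        self.container_id = "abc123"

    def cleanup(self):
        if self.container_id:  # AttributeError on a half-built object
            self.container_id = None


class SafeEnv(FragileEnv):
    """Defensive cleanup, mirroring the hasattr pattern in the diff."""

    def cleanup(self):
        if hasattr(self, "container_id") and self.container_id:
            self.container_id = None


# Simulate __init__ raising before any attribute was set: __new__ allocates
# the instance but runs no initializer, like an object whose __init__ died early.
half_built = FragileEnv.__new__(FragileEnv)
try:
    half_built.cleanup()
    fragile_raised = False
except AttributeError:
    fragile_raised = True
assert fragile_raised  # the unguarded version blows up

safe_half_built = SafeEnv.__new__(SafeEnv)
safe_half_built.cleanup()  # no AttributeError: hasattr guard makes it a no-op
```

This is why defensive guards belong in any `cleanup()` that may be reached from an exception path in the constructor.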
