forked from alexzhang13/rlm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_env.py
More file actions
181 lines (135 loc) · 5.61 KB
/
base_env.py
File metadata and controls
181 lines (135 loc) · 5.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from abc import ABC, abstractmethod
from typing import Any, Protocol, runtime_checkable
from rlm.core.types import REPLResult
class BaseEnv(ABC):
    """Abstract REPL-style execution environment the RLM interacts with.

    Concrete environments come in two families: isolated environments,
    which live on a separate machine from the LM, and non-isolated
    environments, which share the LM's host.
    """

    def __init__(self, persistent: bool = False, **kwargs):
        # Whether the environment keeps its session state alive across turns.
        self.persistent = persistent
        # Any extra configuration the concrete subclass was constructed with.
        self.kwargs = kwargs

    @abstractmethod
    def setup(self):
        """Prepare the environment for use (subclasses must implement)."""
        raise NotImplementedError

    @abstractmethod
    def load_context(self, context_payload: dict | list | str):
        """Make *context_payload* available to code running in the environment."""
        raise NotImplementedError

    @abstractmethod
    def execute_code(self, code: str) -> REPLResult:
        """Execute *code* inside the environment and return its REPLResult."""
        raise NotImplementedError
class IsolatedEnv(BaseEnv, ABC):
    """
    These environments (e.g. Prime Envs, Modal Envs) sit on a completely separate machine from the LM,
    guaranteeing complete isolation from the LM process.

    NOTE: the abstract interface (setup, load_context, execute_code) and the
    ``__init__(persistent, **kwargs)`` signature are inherited from BaseEnv;
    the previous byte-identical re-declarations here were redundant and have
    been dropped. Subclasses must still implement all three abstract methods.
    """
class NonIsolatedEnv(BaseEnv, ABC):
    """
    These environments run on the same machine as the LM, and provide different levels of isolation
    depending on the choice of environment. The simplest, default is a local Python REPL that runs
    as a subprocess.

    NOTE: the abstract interface (setup, load_context, execute_code) and the
    ``__init__(persistent, **kwargs)`` signature are inherited from BaseEnv;
    the previous byte-identical re-declarations here were redundant and have
    been dropped. Subclasses must still implement all three abstract methods.
    """
@runtime_checkable
class SupportsPersistence(Protocol):
"""Protocol for environments that support persistent multi-turn sessions.
CHECKING SUPPORT:
Use isinstance(env, SupportsPersistence) to check if an environment
supports persistence capabilities.
IMPLEMENTING THIS PROTOCOL:
To add persistence to your environment, implement these 5 methods.
See tests/test_local_repl_persistent.py for expected behavior.
VERSIONING BEHAVIOR:
Contexts and histories are versioned with numeric suffixes:
- First context -> context_0, context_1, context_2, ...
- First history -> history_0, history_1, history_2, ...
ALIASING BEHAVIOR:
The unversioned names always point to index 0:
- context -> context_0 (first context)
- history -> history_0 (first history)
EXAMPLE IMPLEMENTATION:
See rlm/environments/local_repl.py for a complete reference.
TESTS:
- Unit tests: tests/test_local_repl_persistent.py
- Integration tests: tests/test_multi_turn_integration.py
Run: uv run pytest tests/test_local_repl_persistent.py -v
"""
def update_handler_address(self, address: tuple[str, int]) -> None:
"""Update the LM handler address for nested LLM calls.
Called by RLM when the handler address changes between completions.
Store the address so llm_query() calls from executed code can reach
the LM handler.
Args:
address: (host, port) tuple for the LM handler server.
"""
...
def add_context(
self, context_payload: dict | list | str, context_index: int | None = None
) -> int:
"""Add a context payload, making it available as context_N in code.
Versioning:
- context_index=None: auto-increment (0, 1, 2, ...)
- context_index=N: use specific index N
Storage:
Must store so executed code can access:
- context_0, context_1, etc. (versioned)
- context (alias to context_0)
Args:
context_payload: The context data (string, dict, or list).
context_index: Optional specific index, or None to auto-increment.
Returns:
The index used (for auto-increment, returns the assigned index).
"""
...
def get_context_count(self) -> int:
"""Return the number of contexts added so far.
Used by RLM to inform the model how many contexts are available.
"""
...
def add_history(
self, message_history: list[dict[str, Any]], history_index: int | None = None
) -> int:
"""Add a message history, making it available as history_N in code.
Versioning:
- history_index=None: auto-increment (0, 1, 2, ...)
- history_index=N: use specific index N
Storage:
Must store so executed code can access:
- history_0, history_1, etc. (versioned)
- history (alias to history_0)
IMPORTANT: Store a deep copy, not a reference. The caller may
modify the list after calling this method.
Args:
message_history: List of message dicts (role, content).
history_index: Optional specific index, or None to auto-increment.
Returns:
The index used.
"""
...
def get_history_count(self) -> int:
"""Return the number of histories added so far.
Used by RLM to inform the model how many conversation histories
are available.
"""
...