Skip to content

Commit b38ad84

Browse files
authored
Add simple compaction capabilities, preliminary (#110)
* Add extra compaction capabilities (preliminary)
* Update test parsing behavior; should return None
1 parent bcb9b25 commit b38ad84

File tree

8 files changed

+396
-11
lines changed

8 files changed

+396
-11
lines changed

examples/compaction_example.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Compaction example: run RLM with compaction enabled and a low threshold
3+
to trigger summarization on gpt-5-nano.
4+
5+
Uses a low compaction_threshold_pct so compaction runs after several iterations.
6+
The task forces many separate REPL turns with substantial output so the root
7+
context grows and compaction definitely triggers. The REPL variable `history`
8+
holds trajectory segments and summaries.
9+
"""
10+
11+
import os
12+
13+
from dotenv import load_dotenv
14+
15+
from rlm import RLM
16+
from rlm.logger import RLMLogger
17+
18+
load_dotenv()
19+
20+
# Low threshold so compaction triggers after a few iterations (~2% of context).
21+
# Use 0.85 in production.
22+
COMPACTION_THRESHOLD_PCT = 0.03
23+
24+
logger = RLMLogger()
25+
rlm = RLM(
26+
backend="portkey",
27+
backend_kwargs={
28+
"model_name": "@openai/gpt-5-nano",
29+
"api_key": os.getenv("PORTKEY_API_KEY"),
30+
},
31+
environment="local",
32+
environment_kwargs={},
33+
max_depth=1,
34+
max_iterations=10,
35+
compaction=True,
36+
compaction_threshold_pct=COMPACTION_THRESHOLD_PCT,
37+
verbose=True,
38+
logger=logger,
39+
)
40+
41+
# Hard task that forces many iterations: find the 50th prime using at least 8
42+
# separate REPL blocks, one per turn. Each block must produce visible output.
43+
# This grows message history (long reasoning + code + execution results) so
44+
# compaction triggers.
45+
prompt = (
46+
"Find the 50th prime number. You MUST use at least 8 separate REPL blocks, "
47+
"each in its own response (one block per message). Do NOT combine steps.\n\n"
48+
"Required structure: "
49+
"Block 1: Define a function is_prime(n) and test it on a few numbers; print the results. "
50+
"Block 2: Write a loop that counts primes and print the first 10 primes. "
51+
"Block 3: Extend to count up to 20 primes and print them. "
52+
"Block 4: Count up to 30 primes and print the 25th. "
53+
"Block 5: Count up to 40 primes and print the 35th. "
54+
"Block 6: Count up to 50 primes and print the 45th. "
55+
"Block 7: Print the full list of the first 50 primes (so we see all 50). "
56+
"Block 8: Set answer = (the 50th prime) and call FINAL_VAR(answer).\n\n"
57+
"Each block must run alone. Show your reasoning briefly before each block. "
58+
"After each block, wait for the execution result before writing the next block."
59+
)
60+
61+
result = rlm.completion(prompt, root_prompt=prompt)
62+
63+
print("Response:", result.response)
64+
print("Execution time:", result.execution_time)
65+
print("Metadata:", result.metadata)

rlm/core/rlm.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
build_user_prompt,
2828
)
2929
from rlm.utils.rlm_utils import filter_sensitive_keys
30+
from rlm.utils.token_utils import count_tokens, get_context_limit
3031

3132

3233
class RLM:
@@ -54,6 +55,8 @@ def __init__(
5455
persistent: bool = False,
5556
custom_tools: dict[str, Any] | None = None,
5657
custom_sub_tools: dict[str, Any] | None = None,
58+
compaction: bool = False,
59+
compaction_threshold_pct: float = 0.85,
5760
):
5861
"""
5962
Args:
@@ -74,6 +77,10 @@ def __init__(
7477
values are callable functions. These are injected into the REPL globals.
7578
custom_sub_tools: Dict of custom tools for sub-agents (llm_query calls). If None, inherits
7679
from custom_tools. Pass an empty dict {} to disable tools for sub-agents.
80+
compaction: If True, keep full root model history in REPL variable `history` and compact
81+
when root context reaches compaction_threshold_pct of the model's context limit.
82+
compaction_threshold_pct: When compaction is on, trigger summarization when root
83+
message token count reaches this fraction of the model context limit (default 0.85).
7784
"""
7885
# Store config for spawning per-completion
7986
self.backend = backend
@@ -98,6 +105,9 @@ def __init__(
98105
# Sub-tools: if None, inherit from custom_tools; if {}, no tools for sub-agents
99106
self.custom_sub_tools = custom_sub_tools if custom_sub_tools is not None else custom_tools
100107

108+
self.compaction = compaction
109+
self.compaction_threshold_pct = compaction_threshold_pct
110+
101111
self.depth = depth
102112
self.max_depth = max_depth
103113
self.max_iterations = max_iterations
@@ -181,6 +191,8 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):
181191
env_kwargs["custom_tools"] = self.custom_tools
182192
if self.custom_sub_tools is not None:
183193
env_kwargs["custom_sub_tools"] = self.custom_sub_tools
194+
if self.compaction and self.environment_type == "local":
195+
env_kwargs["compaction"] = True
184196
environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
185197

186198
if self.persistent:
@@ -204,7 +216,11 @@ def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
204216
query_metadata=metadata,
205217
custom_tools=self.custom_tools,
206218
)
207-
219+
if self.compaction:
220+
message_history[0]["content"] += (
221+
"\n\nThe full conversation history (trajectory segments and any summaries) "
222+
"is available in the REPL variable `history` as a list."
223+
)
208224
return message_history
209225

210226
def completion(
@@ -236,6 +252,18 @@ def completion(
236252
message_history = self._setup_prompt(prompt)
237253

238254
for i in range(self.max_iterations):
255+
if self.compaction and hasattr(environment, "append_compaction_entry"):
256+
current_tokens, threshold_tokens, max_tokens = self._get_compaction_status(
257+
message_history
258+
)
259+
self.verbose.print_compaction_status(
260+
current_tokens, threshold_tokens, max_tokens
261+
)
262+
if current_tokens >= threshold_tokens:
263+
self.verbose.print_compaction()
264+
message_history = self._compact_history(
265+
lm_handler, environment, message_history
266+
)
239267
# Current prompt = message history + additional prompt suffix
240268
context_count = (
241269
environment.get_context_count()
@@ -257,8 +285,14 @@ def completion(
257285
environment=environment,
258286
)
259287

260-
# Check if RLM is done and has a final answer.
261-
final_answer = find_final_answer(iteration.response, environment=environment)
288+
# Check if RLM is done and has a final answer. Prefer FINAL_VAR result from REPL execution.
289+
final_answer = None
290+
for block in iteration.code_blocks:
291+
if getattr(block.result, "final_answer", None):
292+
final_answer = block.result.final_answer
293+
break
294+
if final_answer is None:
295+
final_answer = find_final_answer(iteration.response, environment=environment)
262296
iteration.final_answer = final_answer
263297

264298
# If logger is used, log the iteration.
@@ -294,6 +328,8 @@ def completion(
294328

295329
# Update message history with the new messages.
296330
message_history.extend(new_messages)
331+
if self.compaction and hasattr(environment, "append_compaction_entry"):
332+
environment.append_compaction_entry(new_messages)
297333

298334
# Default behavior: we run out of iterations, provide one final answer
299335
time_end = time.perf_counter()
@@ -317,6 +353,50 @@ def completion(
317353
metadata=self.logger.get_trajectory() if self.logger else None,
318354
)
319355

356+
def _get_compaction_status(self, message_history: list[dict[str, Any]]) -> tuple[int, int, int]:
357+
"""Return (current_tokens, threshold_tokens, max_tokens) for compaction."""
358+
model_name = (
359+
self.backend_kwargs.get("model_name", "unknown") if self.backend_kwargs else "unknown"
360+
)
361+
max_tokens = get_context_limit(model_name)
362+
current_tokens = count_tokens(message_history, model_name)
363+
threshold_tokens = int(self.compaction_threshold_pct * max_tokens)
364+
return current_tokens, threshold_tokens, max_tokens
365+
366+
def _should_compact(self, message_history: list[dict[str, Any]]) -> bool:
367+
"""True when root message history is at or over the compaction threshold."""
368+
current_tokens, threshold_tokens, _ = self._get_compaction_status(message_history)
369+
return current_tokens >= threshold_tokens
370+
371+
def _compact_history(
372+
self,
373+
lm_handler: LMHandler,
374+
environment: BaseEnv,
375+
message_history: list[dict[str, Any]],
376+
) -> list[dict[str, Any]]:
377+
"""
378+
Summarize current trajectory, append summary to REPL history, and return
379+
a short message_history with the summary as the new starting point.
380+
"""
381+
summary_prompt = message_history + [
382+
{
383+
"role": "user",
384+
"content": "Very concisely summarize what you have been doing so far in 1–3 short paragraphs. Be extremely brief. This summary will be used to continue the conversation.",
385+
}
386+
]
387+
summary = lm_handler.completion(summary_prompt)
388+
if hasattr(environment, "append_compaction_entry"):
389+
environment.append_compaction_entry({"type": "summary", "content": summary})
390+
# Keep system + initial assistant (metadata), then summary + continue
391+
new_history = message_history[:2] + [
392+
{"role": "assistant", "content": summary},
393+
{
394+
"role": "user",
395+
"content": "Continue from the above summary. The full history (including this summary) is in the REPL variable `history`. Your next action:",
396+
},
397+
]
398+
return new_history
399+
320400
def _completion_turn(
321401
self,
322402
prompt: str | dict[str, Any],

rlm/core/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class REPLResult:
131131
locals: dict
132132
execution_time: float
133133
llm_calls: list["RLMChatCompletion"]
134+
final_answer: str | None = None
134135

135136
def __init__(
136137
self,
@@ -139,12 +140,14 @@ def __init__(
139140
locals: dict,
140141
execution_time: float = None,
141142
rlm_calls: list["RLMChatCompletion"] = None,
143+
final_answer: str | None = None,
142144
):
143145
self.stdout = stdout
144146
self.stderr = stderr
145147
self.locals = locals
146148
self.execution_time = execution_time
147149
self.rlm_calls = rlm_calls or []
150+
self.final_answer = final_answer
148151

149152
def __str__(self):
150153
return f"REPLResult(stdout={self.stdout}, stderr={self.stderr}, locals={self.locals}, execution_time={self.execution_time}, rlm_calls={len(self.rlm_calls)})"
@@ -156,6 +159,7 @@ def to_dict(self):
156159
"locals": {k: _serialize_value(v) for k, v in self.locals.items()},
157160
"execution_time": self.execution_time,
158161
"rlm_calls": [call.to_dict() for call in self.rlm_calls],
162+
"final_answer": self.final_answer,
159163
}
160164

161165

rlm/environments/local_repl.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
depth: int = 1,
133133
custom_tools: dict[str, Any] | None = None,
134134
custom_sub_tools: dict[str, Any] | None = None,
135+
compaction: bool = False,
135136
**kwargs,
136137
):
137138
super().__init__(persistent=persistent, depth=depth, **kwargs)
@@ -142,6 +143,7 @@ def __init__(
142143
self._lock = threading.Lock()
143144
self._context_count: int = 0
144145
self._history_count: int = 0
146+
self.compaction = compaction
145147

146148
# Custom tools: functions available in the REPL
147149
self.custom_tools = custom_tools or {}
@@ -156,6 +158,10 @@ def __init__(
156158
# Setup globals, locals, and modules in environment.
157159
self.setup()
158160

161+
if compaction:
162+
self._compaction_history: list[Any] = []
163+
self.locals["history"] = self._compaction_history
164+
159165
# Load context if provided
160166
if context_payload is not None:
161167
self.load_context(context_payload)
@@ -175,6 +181,8 @@ def setup(self):
175181

176182
# Track LLM calls made during code execution
177183
self._pending_llm_calls: list[RLMChatCompletion] = []
184+
# When FINAL_VAR is called inside a REPL block, we store the value here for the main loop
185+
self._last_final_answer: str | None = None
178186

179187
# Add helper functions
180188
self.globals["FINAL_VAR"] = self._final_var
@@ -192,13 +200,19 @@ def setup(self):
192200
# For non-callable values (constants, data), add to locals
193201
self.locals[name] = value
194202

195-
def _final_var(self, variable_name: str) -> str:
196-
"""Return the value of a variable as a final answer."""
203+
def _final_var(self, variable_name: str | Any) -> str:
204+
"""Return the value of a variable as a final answer for the main model, or stringify a direct value."""
205+
if not isinstance(variable_name, str):
206+
answer = str(variable_name)
207+
self._last_final_answer = answer
208+
return answer
197209
variable_name = variable_name.strip().strip("\"'")
198210
if variable_name in self.locals:
199-
return str(self.locals[variable_name])
211+
answer = str(self.locals[variable_name])
212+
self._last_final_answer = answer
213+
return answer
200214

201-
# Provide helpful error message with available variables
215+
# Provide helpful error message with available variables (do not set _last_final_answer)
202216
available = [k for k in self.locals.keys() if not k.startswith("_")]
203217
if available:
204218
return (
@@ -358,6 +372,17 @@ def get_history_count(self) -> int:
358372
"""Return the number of conversation histories stored."""
359373
return self._history_count
360374

375+
def append_compaction_entry(self, entry: list[dict[str, Any]] | dict[str, Any]) -> None:
376+
"""
377+
Append a trajectory segment or a summary to the compaction history.
378+
379+
Entry is either a list of message dicts (trajectory segment) or
380+
a dict with "type": "summary" and "content": str.
381+
"""
382+
if not self.compaction:
383+
return
384+
self._compaction_history.append(copy.deepcopy(entry))
385+
361386
@contextmanager
362387
def _capture_output(self):
363388
"""Thread-safe context manager to capture stdout/stderr."""
@@ -393,8 +418,10 @@ def _restore_scaffold(self) -> None:
393418
self.globals["SHOW_VARS"] = self._show_vars
394419
elif name == "context" and "context_0" in self.locals:
395420
self.locals["context"] = self.locals["context_0"]
396-
elif name == "history" and "history_0" in self.locals:
421+
elif name == "history" and "history_0" in self.locals and not self.compaction:
397422
self.locals["history"] = self.locals["history_0"]
423+
elif name == "history" and self.compaction:
424+
self.locals["history"] = self._compaction_history
398425

399426
def execute_code(self, code: str) -> REPLResult:
400427
"""Execute code in the persistent namespace and return result."""
@@ -422,12 +449,16 @@ def execute_code(self, code: str) -> REPLResult:
422449
stdout = stdout_buf.getvalue()
423450
stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}"
424451

452+
final_answer = self._last_final_answer
453+
self._last_final_answer = None
454+
425455
return REPLResult(
426456
stdout=stdout,
427457
stderr=stderr,
428458
locals=self.locals.copy(),
429459
execution_time=time.perf_counter() - start_time,
430460
rlm_calls=self._pending_llm_calls.copy(),
461+
final_answer=final_answer,
431462
)
432463

433464
def __enter__(self):

0 commit comments

Comments
 (0)