diff --git a/AGENTS.md b/AGENTS.md index 289db1b1..cc86cdf0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -156,7 +156,7 @@ Environment implementations live in `rlm/environments/`. Choose the appropriate - Inherit from `NonIsolatedEnv` or `IsolatedEnv` in `rlm/environments/base_env.py` - Implement all abstract methods: `setup`, `load_context`, `execute_code` - Return `REPLResult` from `execute_code` -- Handle `lm_handler_address` for sub-LM calls via `llm_query()` +- Handle `lm_handler_address` for LM calls via `llm_query()` and `rlm_query()` - Implement `cleanup()` for resource management - Register environment in `rlm/environments/__init__.py` @@ -164,14 +164,17 @@ Environment implementations live in `rlm/environments/`. Choose the appropriate - `setup()`: Initialize globals, locals, and helper functions - `load_context()`: Make context available as `context` variable - `execute_code()`: Execute code, capture stdout/stderr, return `REPLResult` -- Always provide `llm_query` and `llm_query_batched` functions in environment globals +- Always provide `llm_query`, `llm_query_batched`, `rlm_query`, and `rlm_query_batched` functions in environment globals ### State Management Environments must provide these globals to executed code: - `context`: The loaded context payload -- `llm_query(prompt, model=None)`: For sub-LM calls -- `llm_query_batched(prompts, model=None)`: For batched sub-LM calls +- `llm_query(prompt, model=None)`: Plain single LM completion (no REPL, no iteration) +- `llm_query_batched(prompts, model=None)`: Batched plain LM completions +- `rlm_query(prompt, model=None)`: Recursive child RLM call (own REPL + iteration). Falls back to `llm_query` at max depth. +- `rlm_query_batched(prompts, model=None)`: Batched recursive child RLM calls - `FINAL_VAR(variable_name)`: For returning final answers +- `SHOW_VARS()`: For listing available variables ### Example Structure ```python @@ -204,7 +207,8 @@ class MyEnvironment(NonIsolatedEnv): - Guidelines here are followed - Environment works with basic RLM completion calls - `cleanup()` properly releases all resources -- Sub-LM calls work via `llm_query()` +- Sub-LM calls work via `llm_query()` and `rlm_query()` +- Reserved names (`llm_query`, `rlm_query`, `context`, `history`, `FINAL_VAR`, `SHOW_VARS`) are restored after each execution ## Architecture: Environment ↔ LM Handler Communication @@ -223,7 +227,7 @@ Understanding how environments communicate with the LM Handler is essential for │ ▼ │ │ │ ┌─────────────┐ Socket (TCP) │ │ │ │ LocalREPL │────────────────────────────────────┘ │ -│ │ (exec code) │ llm_query() → send_lm_request() │ +│ │ (exec code) │ llm_query() / rlm_query() → LM calls │ │ └─────────────┘ │ └─────────────────────────────────────────────────────────────────────┘ ``` @@ -242,8 +246,8 @@ def socket_send(sock: socket.socket, data: dict) -> None: ``` **Request Flow**: -1. Environment's `llm_query(prompt)` is called during code execution -2. Creates `LMRequest` dataclass and calls `send_lm_request(address, request)` +1. Environment's `llm_query(prompt)` or `rlm_query(prompt)` is called during code execution +2. For `llm_query`: creates `LMRequest` and calls `send_lm_request(address, request)`. For `rlm_query`: invokes `subcall_fn` to spawn a child RLM (or falls back to `llm_query` at max depth). 3. Opens TCP connection to `LMHandler` at `(host, port)` 4. Sends length-prefixed JSON request 5. `LMHandler` processes via `LMRequestHandler.handle()` diff --git a/README.md b/README.md index d4fdb5fa..3f8541ab 100644 --- a/README.md +++ b/README.md @@ -77,11 +77,11 @@ make quickstart ## REPL Environments -We support two types of REPL environments -- isolated, and non-isolated. Non-isolated environments (default) run code execution on the same machine as the RLM (e.g. through `exec`), which is pretty reasonable for some local low-risk tasks, like simple benchmarking, but can be problematic if the prompts or tool calls can interact with malicious users. Fully isolated environments used Cloud-based sandboxes (e.g. Prime Sandboxes, [Modal Sandboxes](https://modal.com/docs/guide/sandboxes)) to run code generated by the RLM, ensuring completely isolation from the host process. Environments can be added, but we natively support the following: `local` (default), `modal`, `prime`. +We support two types of REPL environments -- isolated, and non-isolated. Non-isolated environments (default) run code execution on the same machine as the RLM (e.g. through `exec`), which is pretty reasonable for some local low-risk tasks, like simple benchmarking, but can be problematic if the prompts or tool calls can interact with malicious users. Fully isolated environments use cloud-based sandboxes (e.g. Prime Sandboxes, [Modal Sandboxes](https://modal.com/docs/guide/sandboxes)) to run code generated by the RLM, ensuring complete isolation from the host process. Environments can be added, but we natively support the following: `local` (default), `docker`, `modal`, `prime`, `daytona`, `e2b`. ```python rlm = RLM( - environment="...", # "local", "docker", "modal", "prime" + environment="...", # "local", "docker", "modal", "prime", "daytona", "e2b" environment_kwargs={...}, ) ``` @@ -124,19 +124,19 @@ We currently support most major clients (OpenAI, Anthropic), as well as the rout If you use this code or repository in your research, please cite: ```bibtex -@misc{zhang2025recursivelanguagemodels, - title={Recursive Language Models}, +@misc{zhang2026recursivelanguagemodels, + title={Recursive Language Models}, author={Alex L. Zhang and Tim Kraska and Omar Khattab}, - year={2025}, + year={2026}, eprint={2512.24601}, archivePrefix={arXiv}, primaryClass={cs.AI}, - url={https://arxiv.org/abs/2512.24601}, + url={https://arxiv.org/abs/2512.24601}, } ``` ## Optional: Trajectory metadata and logging -`RLMChatCompletion` has an optional `metadata` field (default empty) that can hold the full trajectory (run config + all iterations and sub-calls) so you can reconstruct the run. Pass an `RLMLogger` to capture it: +`RLMChatCompletion` has an optional `metadata` field (default `None`) that holds the full trajectory (run config + all iterations and sub-calls) so you can reconstruct the run. Pass an `RLMLogger` to capture it: - **In-memory only** (trajectory on `completion.metadata`): `logger=RLMLogger()` (no `log_dir`). - **Also save to disk** (JSONL for the visualizer): `logger=RLMLogger(log_dir="./logs")`. diff --git a/docs/api/rlm.md b/docs/api/rlm.md index 14461d81..383f2cb2 100644 --- a/docs/api/rlm.md +++ b/docs/api/rlm.md @@ -28,7 +28,7 @@ from rlm import RLM rlm = RLM( backend="openai", - backend_kwargs={"model_name": "gpt-5"}, + backend_kwargs={"model_name": "gpt-5-nano"}, ) ``` @@ -45,11 +45,24 @@ RLM( depth: int = 0, max_depth: int = 1, max_iterations: int = 30, + max_budget: float | None = None, + max_timeout: float | None = None, + max_tokens: int | None = None, + max_errors: int | None = None, custom_system_prompt: str | None = None, other_backends: list[str] | None = None, other_backend_kwargs: list[dict] | None = None, logger: RLMLogger | None = None, verbose: bool = False, + persistent: bool = False, + custom_tools: dict[str, Any] | None = None, + custom_sub_tools: dict[str, Any] | None = None, + compaction: bool = False, + compaction_threshold_pct: float = 0.85, + on_subcall_start: Callable | None = None, + on_subcall_complete: Callable | None = None, + on_iteration_start: Callable | None = None, + on_iteration_complete: Callable | None = None, ) ``` @@ -58,7 +71,7 @@ RLM( #### `backend` {: .no_toc } -**Type:** `Literal["openai", "portkey", "openrouter", "vllm", "litellm", "anthropic"]` +**Type:** `Literal["openai", "portkey", "openrouter", "vllm", "litellm", "anthropic"]` **Default:** `"openai"` The LM provider backend to use for the root model. @@ -79,7 +92,7 @@ rlm = RLM(backend="vllm", ...) #### `backend_kwargs` {: .no_toc } -**Type:** `dict[str, Any] | None` +**Type:** `dict[str, Any] | None` **Default:** `None` Configuration passed to the LM client. Required fields vary by backend: @@ -106,23 +119,26 @@ backend_kwargs = { #### `environment` {: .no_toc } -**Type:** `Literal["local", "modal", "docker"]` +**Type:** `Literal["local", "docker", "modal", "prime", "daytona", "e2b"]` **Default:** `"local"` The execution environment for running generated code. | Environment | Description | |:------------|:------------| -| `local` | Same-process execution with sandboxed builtins | +| `local` | Same-process execution with sandboxed builtins (default) | | `docker` | Containerized execution in Docker | | `modal` | Cloud sandbox via Modal | +| `prime` | Cloud sandbox via Prime Intellect | +| `daytona` | Cloud sandbox via Daytona | +| `e2b` | Cloud sandbox via E2B | --- #### `environment_kwargs` {: .no_toc } -**Type:** `dict[str, Any] | None` +**Type:** `dict[str, Any] | None` **Default:** `None` Configuration for the execution environment: @@ -155,19 +171,24 @@ environment_kwargs = { #### `max_depth` {: .no_toc } -**Type:** `int` +**Type:** `int` **Default:** `1` -Maximum recursion depth for nested RLM calls. Currently only depth 1 is fully supported. +Maximum recursion depth for nested RLM calls. When `max_depth > 1`, the REPL provides `rlm_query()` and `rlm_query_batched()` functions that spawn child RLMs with their own REPL environments. -When `depth >= max_depth`, the RLM falls back to a regular LM completion. +When `depth >= max_depth`, `rlm_query()` falls back to a plain `llm_query()` call (no REPL, no iteration). + +```python +# Enable one level of recursive sub-calls +rlm = RLM(..., max_depth=2) +``` --- #### `max_iterations` {: .no_toc } -**Type:** `int` +**Type:** `int` **Default:** `30` Maximum number of REPL iterations before forcing a final answer. @@ -179,34 +200,69 @@ Each iteration consists of: ```python # For complex tasks, allow more iterations -rlm = RLM( - ..., - max_iterations=50, -) +rlm = RLM(..., max_iterations=50) ``` --- +#### `max_budget` +{: .no_toc } + +**Type:** `float | None` +**Default:** `None` + +Maximum total USD cost for a completion. If exceeded, raises `BudgetExceededError`. Requires a backend that reports cost. + +--- + +#### `max_timeout` +{: .no_toc } + +**Type:** `float | None` +**Default:** `None` + +Maximum wall-clock seconds for a completion. If exceeded, raises `TimeoutExceededError`. The partial answer (if any) is available on the exception. + +--- + +#### `max_tokens` +{: .no_toc } + +**Type:** `int | None` +**Default:** `None` + +Maximum total tokens (input + output) for a completion. If exceeded, raises `TokenLimitExceededError`. + +--- + +#### `max_errors` +{: .no_toc } + +**Type:** `int | None` +**Default:** `None` + +Maximum consecutive REPL errors before aborting. The error counter resets on a successful execution. If exceeded, raises `ErrorThresholdExceededError`. + +--- + #### `custom_system_prompt` {: .no_toc } -**Type:** `str | None` +**Type:** `str | None` **Default:** `None` Override the default RLM system prompt. The default prompt instructs the LM on: - How to use the `context` variable -- How to call `llm_query()` and `llm_query_batched()` -- How to signal completion with `FINAL()` +- How to call `llm_query()` / `llm_query_batched()` for plain LM calls +- How to call `rlm_query()` / `rlm_query_batched()` for recursive sub-calls +- How to signal completion with `FINAL()` or `FINAL_VAR()` ```python custom_prompt = """You are a data analysis expert. Use the REPL to analyze the context variable. When done, output FINAL(your answer).""" -rlm = RLM( - ..., - custom_system_prompt=custom_prompt, -) +rlm = RLM(..., custom_system_prompt=custom_prompt) ``` --- @@ -214,26 +270,24 @@ rlm = RLM( #### `other_backends` / `other_backend_kwargs` {: .no_toc } -**Type:** `list[str] | None` / `list[dict] | None` +**Type:** `list[str] | None` / `list[dict] | None` **Default:** `None` -Register additional LM backends available for sub-calls via `llm_query()`. +Register additional LM backends. The first `other_backend` is used as the default for depth-routed sub-calls (e.g. `llm_query()` calls from code at depth > 0 are routed to the other backend). Additional backends are registered by model name and can be selected explicitly. ```python rlm = RLM( backend="openai", backend_kwargs={"model_name": "gpt-4o"}, - other_backends=["anthropic", "openai"], + other_backends=["anthropic"], other_backend_kwargs=[ {"model_name": "claude-sonnet-4-20250514"}, - {"model_name": "gpt-4o-mini"}, ], ) # Inside REPL, code can call: -# llm_query(prompt) # Uses default (gpt-4o) -# llm_query(prompt, model="claude-sonnet-4-20250514") # Uses Claude -# llm_query(prompt, model="gpt-4o-mini") # Uses GPT-4o-mini +# llm_query(prompt) # Routed to other_backend (Claude) at depth > 0 +# llm_query(prompt, model="gpt-4o") # Explicit model override ``` --- @@ -241,15 +295,20 @@ rlm = RLM( #### `logger` {: .no_toc } -**Type:** `RLMLogger | None` +**Type:** `RLMLogger | None` **Default:** `None` -Logger for saving iteration trajectories to disk. +Logger for capturing trajectory metadata. When provided, the returned `RLMChatCompletion.metadata` field contains the full trajectory (iterations, code blocks, sub-calls). ```python from rlm.logger import RLMLogger +# In-memory only (trajectory on result.metadata) +logger = RLMLogger() + +# Also save to disk (JSONL for the visualizer) logger = RLMLogger(log_dir="./logs") + rlm = RLM(..., logger=logger) ``` @@ -258,7 +317,7 @@ rlm = RLM(..., logger=logger) #### `verbose` {: .no_toc } -**Type:** `bool` +**Type:** `bool` **Default:** `False` Enable rich console output showing: @@ -269,6 +328,98 @@ Enable rich console output showing: --- +#### `persistent` +{: .no_toc } + +**Type:** `bool` +**Default:** `False` + +When enabled, reuses the same environment across multiple `completion()` calls. This enables multi-turn conversations where each call adds a new context and the model retains all previous variables and state. + +Contexts are versioned (`context_0`, `context_1`, ...) with `context` always aliasing `context_0`. Conversation histories from previous calls are available as `history_0`, `history_1`, etc. + +Supports the context manager protocol for automatic cleanup: + +```python +with RLM(..., persistent=True) as rlm: + result1 = rlm.completion("First context") + result2 = rlm.completion("Second context") # Can access context_0 and context_1 +``` + +--- + +#### `custom_tools` +{: .no_toc } + +**Type:** `dict[str, Any] | None` +**Default:** `None` + +Custom functions and data available in the REPL environment. Callable values are added to globals (callable by the model), non-callable values are added to locals (accessible as variables). + +Two formats are supported: + +```python +custom_tools = { + # Plain value + "fetch_data": my_fetch_function, + "API_KEY": "sk-...", + + # With description (shown in system prompt) + "calculator": { + "tool": calc_function, + "description": "Performs arithmetic calculations", + }, +} +``` + +Reserved names (`llm_query`, `rlm_query`, `context`, `history`, `FINAL_VAR`, `SHOW_VARS`, and their batched variants) cannot be used as tool names. + +--- + +#### `custom_sub_tools` +{: .no_toc } + +**Type:** `dict[str, Any] | None` +**Default:** `None` + +Separate set of custom tools for child RLMs spawned via `rlm_query()`. If `None`, children inherit the parent's `custom_tools`. Pass an empty dict `{}` to disable custom tools for children. + +--- + +#### `compaction` +{: .no_toc } + +**Type:** `bool` +**Default:** `False` + +When enabled, automatically summarizes the conversation history when token usage exceeds `compaction_threshold_pct` of the model's context window. The full history (including summaries) is available in the REPL as the `history` variable. + +--- + +#### `compaction_threshold_pct` +{: .no_toc } + +**Type:** `float` +**Default:** `0.85` + +Fraction of the model's context window that triggers compaction. Only used when `compaction=True`. + +--- + +#### Event Callbacks +{: .no_toc } + +Optional callbacks for monitoring execution progress: + +| Callback | Signature | Triggered when | +|:---------|:----------|:---------------| +| `on_iteration_start` | `(depth: int, iteration_num: int)` | An iteration begins | +| `on_iteration_complete` | `(depth: int, iteration_num: int, duration: float)` | An iteration completes | +| `on_subcall_start` | `(depth: int, model: str, prompt_preview: str)` | A child RLM is spawned | +| `on_subcall_complete` | `(depth: int, model: str, duration: float, error: str \| None)` | A child RLM finishes | + +--- + ## Methods ### `completion()` @@ -307,10 +458,10 @@ result = rlm.completion(["doc1", "doc2", "doc3"]) **`root_prompt`** {: .no_toc } -Optional short prompt shown to the root LM. Useful for Q&A tasks where the question should be visible throughout. +Optional short prompt shown to the root LM on every iteration. Useful for Q&A tasks where the question should be visible throughout. ```python -# The context is the document, but the LM sees the question +# The context is the document, but the LM sees the question each iteration result = rlm.completion( prompt=long_document, root_prompt="What is the main theme of this document?" @@ -324,11 +475,12 @@ result = rlm.completion( ```python @dataclass class RLMChatCompletion: - root_model: str # Model name used - prompt: str | dict # Original input - response: str # Final answer + root_model: str # Model name used + prompt: str | dict # Original input + response: str # Final answer usage_summary: UsageSummary # Token usage - execution_time: float # Total seconds + execution_time: float # Total seconds + metadata: dict | None # Full trajectory when logger is provided ``` #### Example @@ -340,10 +492,15 @@ result = rlm.completion( print(result.response) # "158" print(result.execution_time) # 12.34 +print(result.metadata) # Trajectory dict (if logger provided), else None print(result.usage_summary.to_dict()) # {'model_usage_summaries': {'gpt-4o': {'total_calls': 5, ...}}} ``` +### `close()` + +Clean up persistent environment resources. Called automatically when using the context manager protocol (`with RLM(...) as rlm:`). + --- ## Response Types @@ -360,6 +517,7 @@ result.prompt # Original input result.response # Final answer string result.execution_time # Total time in seconds result.usage_summary # UsageSummary object +result.metadata # Full trajectory dict (if logger provided) ``` ### `UsageSummary` @@ -382,16 +540,29 @@ usage.to_dict() --- +## REPL Functions + +The following functions are available to model-generated code inside the REPL: + +| Function | Description | +|:---------|:------------| +| `llm_query(prompt, model=None)` | Single plain LM completion. Fast, no REPL or iteration. | +| `llm_query_batched(prompts, model=None)` | Multiple plain LM completions concurrently. | +| `rlm_query(prompt, model=None)` | Spawn a child RLM with its own REPL for deeper thinking. Falls back to `llm_query` at max depth. | +| `rlm_query_batched(prompts, model=None)` | Spawn multiple child RLMs. Falls back to `llm_query_batched` at max depth. | +| `FINAL_VAR(variable_name)` | Return a REPL variable as the final answer. | +| `SHOW_VARS()` | List all user-created variables in the REPL. | +| `print(...)` | Print output visible to the model in the next iteration. | + +--- + ## Error Handling RLM follows a "fail fast" philosophy: ```python # Missing required argument -rlm = RLM( - backend="vllm", - backend_kwargs={"model_name": "llama"}, -) +rlm = RLM(backend="vllm", backend_kwargs={"model_name": "llama"}) # Raises: AssertionError: base_url is required for vLLM # Unknown backend @@ -399,7 +570,30 @@ rlm = RLM(backend="unknown") # Raises: ValueError: Unknown backend: unknown ``` -If the RLM exhausts `max_iterations` without finding a `FINAL()` answer, it prompts the LM one more time to provide a final answer based on the conversation history. +If the RLM exhausts `max_iterations` without finding a `FINAL()` / `FINAL_VAR()` answer, it prompts the LM one more time to provide a final answer based on the conversation history. + +RLM raises explicit exceptions when limits are exceeded: + +| Exception | Raised when | Key attributes | +|:----------|:------------|:---------------| +| `BudgetExceededError` | `max_budget` exceeded | `spent`, `budget` | +| `TimeoutExceededError` | `max_timeout` exceeded | `elapsed`, `timeout`, `partial_answer` | +| `TokenLimitExceededError` | `max_tokens` exceeded | `tokens_used`, `token_limit`, `partial_answer` | +| `ErrorThresholdExceededError` | `max_errors` consecutive errors | `error_count`, `threshold`, `last_error`, `partial_answer` | +| `CancellationError` | `KeyboardInterrupt` during completion | `partial_answer` | + +All exceptions are importable from the top-level package: + +```python +from rlm import RLM, TimeoutExceededError, CancellationError + +try: + result = rlm.completion(prompt) +except TimeoutExceededError as e: + print(f"Timed out after {e.elapsed:.1f}s, partial: {e.partial_answer}") +except CancellationError as e: + print(f"Cancelled, partial: {e.partial_answer}") +``` --- @@ -407,7 +601,7 @@ If the RLM exhausts `max_iterations` without finding a `FINAL()` answer, it prom Each `completion()` call: 1. Spawns its own `LMHandler` socket server -2. Creates a fresh environment instance +2. Creates a fresh environment instance (unless persistent) 3. Cleans up both when done This makes `completion()` calls independent, but the `RLM` instance itself should not be shared across threads without external synchronization. @@ -421,7 +615,7 @@ import os from rlm import RLM from rlm.logger import RLMLogger -logger = RLMLogger(log_dir="./logs", file_name="analysis") +logger = RLMLogger(log_dir="./logs") rlm = RLM( # Primary model @@ -430,24 +624,35 @@ rlm = RLM( "api_key": os.getenv("ANTHROPIC_API_KEY"), "model_name": "claude-sonnet-4-20250514", }, - + # Execution environment - environment="docker", - environment_kwargs={ - "image": "python:3.11-slim", - }, - - # Additional models for sub-calls + environment="local", + + # Additional model for sub-calls (routed at depth > 0) other_backends=["openai"], other_backend_kwargs=[{ "api_key": os.getenv("OPENAI_API_KEY"), "model_name": "gpt-4o-mini", }], - - # Behavior + + # Recursion: allow one level of child RLMs via rlm_query() + max_depth=2, max_iterations=40, - max_depth=1, - + + # Limits + max_timeout=120.0, + max_budget=1.0, + max_errors=5, + + # Custom tools available in the REPL + custom_tools={ + "fetch_data": {"tool": my_fetch_fn, "description": "Fetch data from API"}, + }, + + # Compaction for long conversations + compaction=True, + compaction_threshold_pct=0.85, + # Debugging logger=logger, verbose=True, @@ -455,7 +660,9 @@ rlm = RLM( result = rlm.completion( prompt=massive_document, - root_prompt="Summarize the key findings" + root_prompt="Summarize the key findings", ) -``` +print(result.response) +print(result.metadata) # Full trajectory (iterations, sub-calls, etc.) +``` diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 00000000..7ed305ac --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,359 @@ +--- +layout: default +title: Architecture +nav_order: 3 +--- + +# Architecture +{: .no_toc } + +How the RLM runtime, LM handler, code execution, and recursive sub-calls fit together. +{: .fs-6 .fw-300 } + +## Table of Contents +{: .no_toc .text-delta } + +1. TOC +{:toc} + +--- + +## Overview + +An RLM completion involves three cooperating pieces: + +1. **RLM** (`rlm/core/rlm.py`) — the main loop that drives iteration. +2. **LMHandler** (`rlm/core/lm_handler.py`) — a per-completion TCP server that routes LM API calls. +3. **LocalREPL** (`rlm/environments/local_repl.py`) — the Python execution environment where model-generated code runs. + +``` +┌────────────────────────────────────────────────────────────────┐ +│ RLM.completion(prompt) │ +│ │ +│ 1. Spawn LMHandler (TCP server on localhost, auto port) │ +│ 2. Create LocalREPL (in-process exec() namespace) │ +│ 3. Iterate: │ +│ a. Send message history → LM backend → get response │ +│ b. Extract ```repl``` code blocks from response │ +│ c. Execute code in LocalREPL │ +│ d. Append stdout/stderr to message history │ +│ e. Repeat until FINAL_VAR / FINAL or limits exceeded │ +│ 4. Tear down handler and environment │ +└────────────────────────────────────────────────────────────────┘ +``` + +--- + +## LM Handler + +### What it is + +The LMHandler is a **multi-threaded TCP socket server** that sits between +the execution environment and the actual LM API backends. Every call to +`RLM.completion()` spins up a fresh handler (unless the environment is +persistent and already has one). + +### Why a socket server? + +The handler exists so that code running inside the execution environment +can make LM calls back to the host process without directly importing or +calling the LM client. This is essential for isolated environments (Docker, +Modal) that run in separate processes or machines — they communicate with +the handler over TCP. The local environment uses the same protocol for +consistency, even though it runs in-process. + +### Lifecycle + +```python +# Inside RLM._spawn_completion_context(): +client = get_client(backend, backend_kwargs) # 1. Create LM client +lm_handler = LMHandler(client, other_backend_client=…) # 2. Wrap in handler +lm_handler.start() # 3. Start TCP server (daemon thread) +# … run completion loop … +lm_handler.stop() # 4. Shut down server +``` + +- The server binds to `127.0.0.1` with port `0` (OS auto-assigns an available port). +- It runs in a **daemon thread** so it doesn't block process exit. +- Each incoming connection is handled by a new thread (`ThreadingTCPServer`). + +### Wire protocol + +All messages use a simple framing: **4-byte big-endian length prefix + UTF-8 JSON payload**. + +``` +┌──────────┬─────────────────────────┐ +│ 4 bytes │ N bytes │ +│ len (BE) │ JSON payload (UTF-8) │ +└──────────┴─────────────────────────┘ +``` + +Implemented in `socket_send()` / `socket_recv()` in `rlm/core/comms_utils.py`. + +### Client routing + +The handler can hold multiple LM clients and routes requests based on the +`model` and `depth` fields in the request: + +```python +def get_client(self, model=None, depth=0): + if model and model in self.clients: + return self.clients[model] # Explicit model override + if depth == 1 and self.other_backend_client: + return self.other_backend_client # Depth-based routing + return self.default_client # Fallback +``` + +This lets you use a different (e.g. cheaper/faster) model for sub-LM calls +by specifying `other_backends` / `other_backend_kwargs` in the RLM constructor. + +--- + +## Code Execution Environment (LocalREPL) + +### In-process `exec()` — not a subprocess + +LocalREPL executes model-generated code **in the same Python interpreter +process** as the RLM, using Python's built-in `exec()`. There is no +subprocess, no fork, and no IPC for code execution. + +```python +# Simplified from LocalREPL.execute_code(): +combined = {**self.globals, **self.locals} +exec(code, combined, combined) +``` + +### What this means in practice + +- **Fast**: No process spawn overhead. Code execution is as fast as native Python. +- **Persistent namespace**: Variables created in one code block are visible in the next. The `self.locals` dict accumulates state across iterations. +- **Shared memory**: Helper functions like `llm_query()` and `rlm_query()` are plain Python closures in `self.globals`. When model code calls `llm_query("...")`, it's a direct function call within the same process. +- **Limited sandbox**: Dangerous builtins (`eval`, `exec`, `compile`, `input`) are removed from the namespace. This is a soft sandbox — it prevents accidental misuse but is not a security boundary. + +### Namespace layout + +``` +globals (shared across all executions): +├── __builtins__ → _SAFE_BUILTINS (eval/exec/input removed) +├── llm_query() → plain LM call via handler +├── llm_query_batched() → batched plain LM calls +├── rlm_query() → recursive RLM sub-call (or fallback to llm_query) +├── rlm_query_batched() → batched recursive sub-calls +├── FINAL_VAR() → mark a variable as the final answer +├── SHOW_VARS() → list user-created variables +└── → user-provided callable tools + +locals (accumulates user variables): +├── context → alias for context_0 +├── context_0 → first context payload +├── context_1, … → additional contexts (persistent mode) +├── history → conversation history (persistent/compaction mode) +└── → anything created by model code +``` + +### Scaffold restoration + +After each `exec()`, LocalREPL restores all reserved names to prevent model +code from corrupting the environment. If the model writes +`llm_query = "oops"` or `context = None`, the next execution will still +have the real functions and data. See `_restore_scaffold()`. + +--- + +## How `llm_query()` and `rlm_query()` Work + +These are the two functions available to model-generated code for making LM calls. +They have very different behaviors: + +### `llm_query(prompt, model=None)` — Plain LM call + +Always makes a single, direct LM completion. No REPL, no iteration — just +prompt in, text out. Fast and lightweight. + +``` +Model code: answer = llm_query("Summarize this text: ...") + │ + ▼ +LocalREPL._llm_query() + │ Creates LMRequest(prompt=..., depth=self.depth) + │ Opens TCP socket to handler + ▼ +LMHandler (TCP server) + │ get_client(model, depth) → selects backend + │ client.completion(prompt) → calls LM API + ▼ +Response flows back over socket + │ + ▼ +Returns response string to model code +``` + +### `rlm_query(prompt, model=None)` — Recursive RLM sub-call + +Spawns a **child RLM** that gets its own REPL and can reason iteratively +over the prompt — just like the parent. Use this when the subtask needs +multi-step reasoning, code execution, or its own iterative problem-solving. + +Falls back to `llm_query` when recursion is not available (i.e. the current +depth has reached `max_depth`). + +``` +Model code: answer = rlm_query("Solve this complex problem: ...") + │ + ▼ +LocalREPL._rlm_query() + │ if self.subcall_fn is not None: ← set when max_depth > 1 + │ calls self.subcall_fn(prompt, model) + │ else: + │ falls back to _llm_query() + │ + ▼ (when subcall_fn exists) +RLM._subcall(prompt, model) + │ next_depth = self.depth + 1 + │ if next_depth >= max_depth: + │ → plain client.completion() (leaf call, no REPL) + │ else: + │ → create child RLM(depth=next_depth, ...) + │ → child.completion(prompt) ← full RLM with its own handler + REPL + │ + ▼ +Returns RLMChatCompletion to parent +``` + +### `llm_query_batched` / `rlm_query_batched` + +Same semantics as above, but for multiple prompts. `llm_query_batched` sends +all prompts as a single batched request to the handler, which processes them +concurrently with `asyncio.gather`. `rlm_query_batched` calls `subcall_fn` +sequentially for each prompt (each child RLM is a blocking call). + +--- + +## Recursive Sub-Calls (Depth > 1) + +### How depth works + +``` +max_depth=3 + +RLM (depth=0) + └─ rlm_query() → child RLM (depth=1) + └─ rlm_query() → child RLM (depth=2) + └─ rlm_query() → plain LM call (depth=3 >= max_depth, no REPL) +``` + +- `depth=0` is the root RLM that the user calls. +- Each child increments depth by 1. +- When `next_depth >= max_depth`, `_subcall()` does a plain `client.completion()` instead of creating a child RLM. This is the leaf case — no REPL, no iteration. +- `llm_query()` always does a plain LM call regardless of depth. Only `rlm_query()` triggers recursion. + +### Each child gets its own handler and environment + +When a child RLM is created via `_subcall()`, its `completion()` method calls +`_spawn_completion_context()` which creates: + +1. **A new `LMHandler`** listening on a **different auto-assigned port**. +2. **A new `LocalREPL`** with its own isolated namespace. + +``` +Parent RLM (depth=0) +├── LMHandler #1 on port 52301 +├── LocalREPL #1 (depth=1) +│ ├── globals: {llm_query, rlm_query, ...} +│ ├── locals: {context: "parent prompt", ...} +│ └── subcall_fn = RLM._subcall ← enables rlm_query() +│ +└── When model code calls rlm_query("subtask"): + │ + └── Child RLM (depth=1) + ├── LMHandler #2 on port 52302 ← NEW handler, NEW port + ├── LocalREPL #2 (depth=2) ← NEW namespace + │ ├── locals: {context: "subtask", ...} + │ └── subcall_fn = child._subcall (or None if at max_depth-1) + │ + └── Runs its own iteration loop, returns RLMChatCompletion +``` + +The child's handler and environment are torn down when the child's `completion()` finishes. + +### Resource limits propagate + +The parent passes **remaining** budget/timeout/tokens to the child, not the +original totals. This prevents a child from consuming all of the parent's resources: + +```python +# In _subcall(): +remaining_timeout = self.max_timeout - elapsed # not self.max_timeout +remaining_budget = self.max_budget - spent # not self.max_budget +child = RLM(..., max_timeout=remaining_timeout, max_budget=remaining_budget) +``` + +### Metadata flows back + +Each child RLM can have its own `RLMLogger`. When the child completes, its +full trajectory metadata (iterations, code blocks, sub-calls) is captured in +the returned `RLMChatCompletion.metadata` dict. The parent's logger records +this as part of the REPL result's `rlm_calls` list, creating a nested +metadata tree. + +--- + +## Putting It All Together + +Here's the complete request flow for a depth-2 RLM call: + +``` +User: rlm.completion("Analyze this data") + │ + ▼ +RLM (depth=0) + ├─ _spawn_completion_context() + │ ├─ LMHandler #1 starts on port 52301 + │ └─ LocalREPL #1 created with context="Analyze this data" + │ + ├─ Iteration 1: LM generates code + │ │ ```repl + │ │ answer = rlm_query("What patterns exist in: " + context[:5000]) + │ │ ``` + │ │ + │ └─ LocalREPL.execute_code() runs the code via exec() + │ │ + │ ├─ rlm_query() → _rlm_query() → subcall_fn() + │ │ │ + │ │ └─ RLM._subcall("What patterns exist in: ...") + │ │ │ + │ │ ├─ depth=1 < max_depth=2, so create child RLM + │ │ │ + │ │ └─ Child RLM (depth=1) + │ │ ├─ LMHandler #2 on port 52302 + │ │ ├─ LocalREPL #2 with context="What patterns..." + │ │ │ + │ │ ├─ Child iteration 1: LM generates code + │ │ │ │ result = llm_query("Extract key metrics: " + context) + │ │ │ │ + │ │ │ └─ llm_query() → TCP to Handler #2 → LM API → response + │ │ │ + │ │ ├─ Child iteration 2: LM calls FINAL_VAR(result) + │ │ │ + │ │ └─ Returns RLMChatCompletion to parent + │ │ + │ └─ answer = child_completion.response + │ + ├─ Iteration 2: LM uses answer, calls FINAL_VAR(final) + │ + ├─ LMHandler #1 stops + └─ Returns RLMChatCompletion to user +``` + +### Key takeaways + +| Aspect | Detail | +|:-------|:-------| +| Code execution | In-process `exec()` in the same Python interpreter. Not a subprocess. | +| LM calls from code | Go through a local TCP socket server (LMHandler), even for in-process execution. | +| Handler per completion | Each `completion()` call gets its own handler on an auto-assigned port. | +| Child RLMs | Created by `_subcall()`, each with its own handler + LocalREPL. Fully independent. | +| `llm_query` vs `rlm_query` | `llm_query` = always plain LM call. `rlm_query` = recursive child RLM (or fallback). | +| Depth limit | At `max_depth`, `rlm_query` falls back to `llm_query`. No further recursion. | +| Resource isolation | Children get remaining budget/timeout, not the full amount. | +| Namespace isolation | Each LocalREPL has its own `globals`/`locals`. No shared state between parent and child. | diff --git a/docs/getting-started.md b/docs/getting-started.md index 814e4ebb..a288eb1a 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -131,14 +131,22 @@ This will display: | `backend_kwargs` | `dict` | `None` | Backend-specific configuration | | `environment` | `str` | `"local"` | Execution environment type | | `environment_kwargs` | `dict` | `None` | Environment configuration | -| `max_depth` | `int` | `1` | Maximum recursion depth | +| `max_depth` | `int` | `1` | Maximum recursion depth for `rlm_query()` | | `max_iterations` | `int` | `30` | Max REPL iterations per call | +| `max_budget` | `float` | `None` | Max total USD cost (if provider reports cost) | +| `max_timeout` | `float` | `None` | Max wall-clock seconds per completion | +| `max_tokens` | `int` | `None` | Max total tokens (input + output) per completion | +| `max_errors` | `int` | `None` | Max consecutive REPL errors before abort | | `custom_system_prompt` | `str` | `None` | Override default system prompt | | `other_backends` | `list` | `None` | Additional backends for sub-calls | | `other_backend_kwargs` | `list` | `None` | Configs for additional backends | -| `logger` | `RLMLogger` | `None` | Logger for trajectory tracking | +| `logger` | `RLMLogger` | `None` | Logger for trajectory tracking and `metadata` capture | | `verbose` | `bool` | `False` | Enable console output | +| `persistent` | `bool` | `False` | Reuse environment across `completion()` calls | | `custom_tools` | `dict` | `None` | Custom functions/data available in REPL | +| `custom_sub_tools` | `dict` | `None` | Custom tools for child RLMs (defaults to `custom_tools`) | +| `compaction` | `bool` | `False` | Auto-summarize history when context fills up | +| `compaction_threshold_pct` | `float` | `0.85` | Context usage fraction that triggers compaction | ### The `completion()` Method @@ -159,12 +167,35 @@ result = rlm.completion( - `execution_time`: Total time in seconds - `root_model`: Model name used - `prompt`: Original input +- `metadata`: Full trajectory dict (if `logger` was provided, else `None`) + +### Depth>1 Recursion + +Depth>1 recursive subcalls are supported. The REPL provides two LM call functions: + +- **`llm_query(prompt)`** — Always makes a plain, single LM completion. Fast and lightweight. Use for simple extraction, summarization, or Q&A. +- **`rlm_query(prompt)`** — Spawns a child RLM with its own REPL and iterative reasoning. Use for subtasks that need multi-step thinking or code execution. Falls back to `llm_query` when `depth >= max_depth`. + +Both have batched variants (`llm_query_batched`, `rlm_query_batched`) for processing multiple prompts concurrently. + +See [Architecture](architecture.md) for details on how handlers, environments, and recursive sub-calls work. + +### Limits and Exceptions + +RLM raises explicit exceptions when limits are exceeded: +- `BudgetExceededError` +- `TimeoutExceededError` +- `TokenLimitExceededError` +- `ErrorThresholdExceededError` +- `CancellationError` (on `KeyboardInterrupt`) + +All exceptions are importable from the top-level package (`from rlm import TimeoutExceededError, ...`). --- ## Choosing an Environment -RLM supports three execution environments: +RLM supports several execution environments: ### Local (Default) @@ -423,6 +454,6 @@ Upload `.jsonl` log files to visualize: ## Next Steps - [API Reference](api/rlm.md) - Complete RLM class documentation +- [Architecture](architecture.md) - How the handler, REPL, and recursive sub-calls work - [Environments](environments/) - Deep dive into each environment - [Backends](backends.md) - Detailed backend configuration - diff --git a/examples/compaction_example.py b/examples/compaction_example.py index 69e3cecf..d1ae38d0 100644 --- a/examples/compaction_example.py +++ b/examples/compaction_example.py @@ -1,11 +1,17 @@ """ Compaction example: run RLM with compaction enabled and a low threshold -to trigger summarization on gpt-5-nano. +to trigger summarization. -Uses a low compaction_threshold_pct so compaction runs after several iterations. -The task forces many separate REPL turns with substantial output so the root -context grows and compaction definitely triggers. The REPL variable `history` -holds trajectory segments and summaries. +The task uses random data so earlier results are impossible to recompute — +the model *must* recover them from `history` after compaction fires. + +Phase 1: Generate random datasets and compute statistics (fills context). +Phase 2: Compaction fires, summarizing the conversation. +Phase 3: Combine the earlier statistics into a final answer. Since the data + was random, the model cannot just re-derive the numbers. + +Usage: + PORTKEY_API_KEY=... python examples/compaction_example.py """ import os @@ -17,9 +23,9 @@ load_dotenv() -# Low threshold so compaction triggers after a few iterations (~2% of context). +# Low threshold so compaction triggers after a few iterations. # Use 0.85 in production. -COMPACTION_THRESHOLD_PCT = 0.03 +COMPACTION_THRESHOLD_PCT = 0.02 logger = RLMLogger() rlm = RLM( @@ -31,35 +37,32 @@ environment="local", environment_kwargs={}, max_depth=1, - max_iterations=10, + max_iterations=12, compaction=True, compaction_threshold_pct=COMPACTION_THRESHOLD_PCT, verbose=True, logger=logger, ) -# Hard task that forces many iterations: find the 50th prime using at least 8 -# separate REPL blocks, one per turn. Each block must produce visible output. -# This grows message history (long reasoning + code + execution results) so -# compaction triggers. prompt = ( - "Find the 50th prime number. You MUST use at least 8 separate REPL blocks, " - "each in its own response (one block per message). Do NOT combine steps.\n\n" - "Required structure: " - "Block 1: Define a function is_prime(n) and test it on a few numbers; print the results. " - "Block 2: Write a loop that counts primes and print the first 10 primes. " - "Block 3: Extend to count up to 20 primes and print them. " - "Block 4: Count up to 30 primes and print the 25th. " - "Block 5: Count up to 40 primes and print the 35th. " - "Block 6: Count up to 50 primes and print the 45th. " - "Block 7: Print the full list of the first 50 primes (so we see all 50). " - "Block 8: Set answer = (the 50th prime) and call FINAL_VAR(answer).\n\n" - "Each block must run alone. Show your reasoning briefly before each block. " - "After each block, wait for the execution result before writing the next block." + "Complete the following steps, each in its own REPL block (one block per message). " + "Do NOT combine steps.\n\n" + "Step 1: Use `import random; random.seed(42)` then generate a list of 50 random " + "integers between 1 and 1000. Print ALL of them. Store as `data_a`.\n\n" + "Step 2: Generate another 50 random integers (continuing the same RNG stream) " + "between 1 and 1000. Print ALL of them. Store as `data_b`.\n\n" + "Step 3: Compute and print the mean, median, min, max, and standard deviation " + "of `data_a`. Store the mean as `mean_a`.\n\n" + "Step 4: Compute and print the mean, median, min, max, and standard deviation " + "of `data_b`. Store the mean as `mean_b`.\n\n" + "Step 5: You previously computed `mean_a` and `mean_b`. " + "Compute final_answer = round(mean_a + mean_b, 2) and call FINAL_VAR(final_answer)." ) result = rlm.completion(prompt, root_prompt=prompt) -print("Response:", result.response) -print("Execution time:", result.execution_time) -print("Metadata:", result.metadata) +print("\n" + "=" * 60) +print("RESULT") +print("=" * 60) +print(f"Response: {result.response}") +print(f"Execution time: {result.execution_time:.2f}s") diff --git a/examples/compaction_history_retrieval_example.py b/examples/compaction_history_retrieval_example.py new file mode 100644 index 00000000..51b959bb --- /dev/null +++ b/examples/compaction_history_retrieval_example.py @@ -0,0 +1,65 @@ +""" +Compaction + history retrieval example. + +Like compaction_example.py but with more intermediate results to recover. +Uses random data across 4 groups so none of the statistics are re-derivable +after compaction — the model must look in `history` to find them. + +Usage: + PORTKEY_API_KEY=... python examples/compaction_history_retrieval_example.py +""" + +import os + +from dotenv import load_dotenv + +from rlm import RLM +from rlm.logger import RLMLogger + +load_dotenv() + +# Very low threshold so compaction fires after the first few steps. +COMPACTION_THRESHOLD_PCT = 0.02 + +logger = RLMLogger() +rlm = RLM( + backend="portkey", + backend_kwargs={ + "model_name": "@openai/gpt-5-nano", + "api_key": os.getenv("PORTKEY_API_KEY"), + }, + environment="local", + environment_kwargs={}, + max_depth=1, + max_iterations=15, + compaction=True, + compaction_threshold_pct=COMPACTION_THRESHOLD_PCT, + verbose=True, + logger=logger, +) + +prompt = ( + "Complete the following steps, each in its own REPL block (one block per message). " + "Do NOT combine steps.\n\n" + "Step 1: `import random; random.seed(99)`. Generate 40 random floats in [0, 100). " + "Print all of them. Store as `group_1`.\n\n" + "Step 2: Generate another 40 random floats (same RNG stream). Print all. " + "Store as `group_2`.\n\n" + "Step 3: Generate another 40 random floats. Print all. Store as `group_3`.\n\n" + "Step 4: Generate another 40 random floats. Print all. Store as `group_4`.\n\n" + "Step 5: Compute the mean of each group. Print all four means. " + "Store as `m1`, `m2`, `m3`, `m4`.\n\n" + "Step 6: Compute the standard deviation of each group. Print all four. " + "Store as `s1`, `s2`, `s3`, `s4`.\n\n" + "Step 7: You previously computed four means and four standard deviations. " + "Compute final_answer = round((m1 + m2 + m3 + m4) / (s1 + s2 + s3 + s4), 4). " + "Print the full equation with values and call FINAL_VAR(final_answer)." +) + +result = rlm.completion(prompt, root_prompt=prompt) + +print("\n" + "=" * 60) +print("RESULT") +print("=" * 60) +print(f"Response: {result.response}") +print(f"Execution time: {result.execution_time:.2f}s") diff --git a/examples/depth_metadata_example.py b/examples/depth_metadata_example.py new file mode 100644 index 00000000..f283a3e0 --- /dev/null +++ b/examples/depth_metadata_example.py @@ -0,0 +1,178 @@ +""" +Example: depth>1 RLM with metadata inspection. + +Demonstrates that when a parent RLM spawns child RLMs via rlm_query(), +each child captures its own trajectory metadata. The metadata flows back +through RLMChatCompletion.metadata on the rlm_calls recorded in each +code block's REPLResult. + +Usage: + PORTKEY_API_KEY=... python examples/depth_metadata_example.py + +Prints a structured tree of the metadata at all depth levels. +""" + +import json +import os +import sys + +from dotenv import load_dotenv + +from rlm import RLM +from rlm.logger import RLMLogger + +load_dotenv() + + +def print_separator(char="─", width=80): + print(char * width) + + +def print_metadata_tree(result, depth=0): + """Recursively print metadata from an RLMChatCompletion and its sub-calls.""" + indent = " " * depth + prefix = f"{'└─ ' if depth > 0 else ''}" + + print( + f"{indent}{prefix}[Depth {depth}] model={result.root_model} " + f"time={result.execution_time:.2f}s response_len={len(result.response)}" + ) + + # Print usage + usage = result.usage_summary + if usage and usage.model_usage_summaries: + for _model, summary in usage.model_usage_summaries.items(): + print( + f"{indent} tokens: in={summary.total_input_tokens} out={summary.total_output_tokens} " + f"calls={summary.total_calls}" + + (f" cost=${summary.total_cost:.6f}" if summary.total_cost else "") + ) + + # Print trajectory metadata + if result.metadata: + traj = result.metadata + n_iters = len(traj.get("iterations", [])) + print(f"{indent} metadata: {n_iters} iteration(s) captured") + + # Dig into iterations to find sub-calls with their own metadata + for i, iteration in enumerate(traj.get("iterations", [])): + for cb in iteration.get("code_blocks", []): + repl_result = cb.get("result", {}) + for j, sub_call in enumerate(repl_result.get("rlm_calls", [])): + sub_response = sub_call.get("response", "")[:80] + print( + f"{indent} iter {i + 1} sub-call {j + 1}: " + f"model={sub_call.get('root_model', '?')} " + f"response={sub_response!r}..." + ) + if sub_call.get("metadata"): + sub_n = len(sub_call["metadata"].get("iterations", [])) + print(f"{indent} ^ has nested metadata: {sub_n} iteration(s)") + else: + print(f"{indent} metadata: (none)") + + print() + + +def main(): + api_key = os.environ.get("PORTKEY_API_KEY") + if not api_key: + print("Error: PORTKEY_API_KEY not set. Set it and re-run.") + sys.exit(1) + + model = "@openai/gpt-5-nano" + + print_separator("=") + print(" Depth>1 RLM Metadata Example") + print(f" Model: {model} | max_depth=2 | max_iterations=3") + print_separator("=") + print() + + logger = RLMLogger() + + rlm = RLM( + backend="portkey", + backend_kwargs={ + "model_name": model, + "api_key": api_key, + }, + environment="local", + max_depth=2, + max_iterations=3, + logger=logger, + verbose=True, + ) + + # This prompt forces the model to use rlm_query(), triggering a depth>1 subcall + prompt = ( + "Use rlm_query() to ask the sub-model: 'What are the first 5 prime numbers? " + "Reply with just the numbers separated by commas.' " + "Store the response in a variable called 'primes', then return it with FINAL_VAR('primes')." + ) + + print("Prompt:", prompt) + print() + print_separator() + + result = rlm.completion(prompt) + + print_separator("=") + print(" RESULT") + print_separator("=") + print(f"Response: {result.response}") + print(f"Execution time: {result.execution_time:.2f}s") + print() + + # ── Metadata tree ── + print_separator("=") + print(" METADATA TREE") + print_separator("=") + print_metadata_tree(result, depth=0) + + # ── Raw metadata JSON ── + print_separator("=") + print(" RAW METADATA (JSON)") + print_separator("=") + if result.metadata: + # Print compactly but readable + print(json.dumps(result.metadata, indent=2, default=str)[:3000]) + if len(json.dumps(result.metadata, default=str)) > 3000: + print("... (truncated)") + else: + print("(no metadata captured)") + + print() + print_separator("=") + print(" SUB-CALL METADATA DETAIL") + print_separator("=") + + # Walk through iterations and find sub-calls with metadata + if result.metadata: + found_subcalls = False + for i, iteration in enumerate(result.metadata.get("iterations", [])): + for cb in iteration.get("code_blocks", []): + repl_result = cb.get("result", {}) + for j, sub_call in enumerate(repl_result.get("rlm_calls", [])): + found_subcalls = True + print(f"\nIteration {i + 1}, Sub-call {j + 1}:") + print(f" Model: {sub_call.get('root_model', '?')}") + print(f" Response: {sub_call.get('response', '')[:200]}") + print(f" Execution time: {sub_call.get('execution_time', 0):.2f}s") + if sub_call.get("metadata"): + meta = sub_call["metadata"] + print(f" Trajectory: {len(meta.get('iterations', []))} iterations") + print( + f" Run metadata: {json.dumps(meta.get('run_metadata', {}), indent=4, default=str)}" + ) + else: + print(" Trajectory: (none — leaf LM call, no REPL)") + print() + + if not found_subcalls: + print("No sub-calls found in iterations (model may not have used llm_query).") + else: + print("No metadata to inspect.") + + +if __name__ == "__main__": + main() diff --git a/examples/rlm_query_batched_example.py b/examples/rlm_query_batched_example.py new file mode 100644 index 00000000..7c76f925 --- /dev/null +++ b/examples/rlm_query_batched_example.py @@ -0,0 +1,170 @@ +""" +Example: rlm_query_batched() with depth > 1. + +Demonstrates that rlm_query_batched() spawns multiple child RLMs, each +with its own REPL and iterative reasoning. The parent collects all +responses, and the metadata tree shows sub-calls from every child. + +Usage: + PORTKEY_API_KEY=... python examples/rlm_query_batched_example.py + +Prints the parent response, timing, and a metadata tree showing each +child RLM sub-call and its own trajectory. +""" + +import json +import os +import sys + +from dotenv import load_dotenv + +from rlm import RLM +from rlm.logger import RLMLogger + +load_dotenv() + + +def print_separator(char="─", width=80): + print(char * width) + + +def print_metadata_tree(result, depth=0): + """Recursively print metadata from an RLMChatCompletion and its sub-calls.""" + indent = " " * depth + prefix = f"{'└─ ' if depth > 0 else ''}" + + print( + f"{indent}{prefix}[Depth {depth}] model={result.root_model} " + f"time={result.execution_time:.2f}s response_len={len(result.response)}" + ) + + usage = result.usage_summary + if usage and usage.model_usage_summaries: + for _model, summary in usage.model_usage_summaries.items(): + print( + f"{indent} tokens: in={summary.total_input_tokens} " + f"out={summary.total_output_tokens} " + f"calls={summary.total_calls}" + + (f" cost=${summary.total_cost:.6f}" if summary.total_cost else "") + ) + + if result.metadata: + traj = result.metadata + n_iters = len(traj.get("iterations", [])) + print(f"{indent} metadata: {n_iters} iteration(s) captured") + + for i, iteration in enumerate(traj.get("iterations", [])): + for cb in iteration.get("code_blocks", []): + repl_result = cb.get("result", {}) + for j, sub_call in enumerate(repl_result.get("rlm_calls", [])): + sub_response = sub_call.get("response", "")[:80] + print( + f"{indent} iter {i + 1} sub-call {j + 1}: " + f"model={sub_call.get('root_model', '?')} " + f"response={sub_response!r}..." + ) + if sub_call.get("metadata"): + sub_n = len(sub_call["metadata"].get("iterations", [])) + print(f"{indent} ^ has nested metadata: {sub_n} iteration(s)") + else: + print(f"{indent} metadata: (none)") + print() + + +def main(): + api_key = os.environ.get("PORTKEY_API_KEY") + if not api_key: + print("Error: PORTKEY_API_KEY not set. Set it and re-run.") + sys.exit(1) + + model = "@openai/gpt-5-nano" + + print_separator("=") + print(" rlm_query_batched() Example (depth > 1)") + print(f" Model: {model} | max_depth=2 | max_iterations=3") + print_separator("=") + print() + + logger = RLMLogger() + + rlm = RLM( + backend="portkey", + backend_kwargs={ + "model_name": model, + "api_key": api_key, + }, + environment="local", + max_depth=2, + max_iterations=3, + logger=logger, + verbose=True, + ) + + # Prompt that forces the model to use rlm_query_batched() + prompt = ( + "Use rlm_query_batched() to ask THREE different questions in parallel:\n" + " 1. 'What are the first 5 prime numbers? Reply with just the numbers.'\n" + " 2. 'What are the first 5 even numbers? Reply with just the numbers.'\n" + " 3. 'What are the first 5 square numbers? Reply with just the numbers.'\n" + "Store the list of responses in a variable called 'answers', " + "then return it with FINAL_VAR(answers)." + ) + + print("Prompt:", prompt) + print() + print_separator() + + result = rlm.completion(prompt) + + print_separator("=") + print(" RESULT") + print_separator("=") + print(f"Response: {result.response}") + print(f"Execution time: {result.execution_time:.2f}s") + print() + + # ── Metadata tree ── + print_separator("=") + print(" METADATA TREE") + print_separator("=") + print_metadata_tree(result, depth=0) + + # ── Sub-call detail ── + print_separator("=") + print(" SUB-CALL DETAIL") + print_separator("=") + + if result.metadata: + found_subcalls = 0 + for i, iteration in enumerate(result.metadata.get("iterations", [])): + for cb in iteration.get("code_blocks", []): + repl_result = cb.get("result", {}) + for j, sub_call in enumerate(repl_result.get("rlm_calls", [])): + found_subcalls += 1 + print(f"\nIteration {i + 1}, Sub-call {j + 1}:") + print(f" Model: {sub_call.get('root_model', '?')}") + print(f" Response: {sub_call.get('response', '')[:200]}") + print(f" Execution time: {sub_call.get('execution_time', 0):.2f}s") + if sub_call.get("metadata"): + meta = sub_call["metadata"] + print(f" Trajectory: {len(meta.get('iterations', []))} iterations") + print( + f" Run metadata: " + f"{json.dumps(meta.get('run_metadata', {}), indent=4, default=str)}" + ) + else: + print(" Trajectory: (none — leaf LM call, no REPL)") + print() + + print(f"Total sub-calls found: {found_subcalls}") + if found_subcalls < 3: + print( + "NOTE: Expected at least 3 sub-calls from rlm_query_batched(). " + "The model may not have followed instructions exactly." + ) + else: + print("No metadata to inspect.") + + +if __name__ == "__main__": + main() diff --git a/rlm/__init__.py b/rlm/__init__.py index 37a4c71e..b29d3cb0 100644 --- a/rlm/__init__.py +++ b/rlm/__init__.py @@ -1,3 +1,17 @@ from rlm.core.rlm import RLM +from rlm.utils.exceptions import ( + BudgetExceededError, + CancellationError, + ErrorThresholdExceededError, + TimeoutExceededError, + TokenLimitExceededError, +) -__all__ = ["RLM"] +__all__ = [ + "RLM", + "BudgetExceededError", + "TimeoutExceededError", + "TokenLimitExceededError", + "ErrorThresholdExceededError", + "CancellationError", +] diff --git a/rlm/clients/openai.py b/rlm/clients/openai.py index b235d4b7..1bfc549f 100644 --- a/rlm/clients/openai.py +++ b/rlm/clients/openai.py @@ -45,12 +45,14 @@ def __init__( api_key=api_key, base_url=base_url, timeout=self.timeout ) self.model_name = model_name + self.base_url = base_url # Track for cost extraction # Per-model usage tracking self.model_call_counts: dict[str, int] = defaultdict(int) self.model_input_tokens: dict[str, int] = defaultdict(int) self.model_output_tokens: dict[str, int] = defaultdict(int) self.model_total_tokens: dict[str, int] = defaultdict(int) + self.model_costs: dict[str, float] = defaultdict(float) # Cost in USD def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str: if isinstance(prompt, str): @@ -113,13 +115,37 @@ def _track_cost(self, response: openai.ChatCompletion, model: str): self.last_prompt_tokens = usage.prompt_tokens self.last_completion_tokens = usage.completion_tokens + # Extract cost from OpenRouter responses (cost is in USD) + # OpenRouter returns cost in usage.model_extra for pydantic models + self.last_cost: float | None = None + cost = None + + # Try direct attribute first + if hasattr(usage, "cost") and usage.cost: + cost = usage.cost + # Then try model_extra (OpenRouter uses this) + elif hasattr(usage, "model_extra") and usage.model_extra: + extra = usage.model_extra + # Primary cost field (may be 0 for BYOK) + if extra.get("cost"): + cost = extra["cost"] + # Fallback to upstream cost details + elif extra.get("cost_details", {}).get("upstream_inference_cost"): + cost = extra["cost_details"]["upstream_inference_cost"] + + if cost is not None and cost > 0: + self.last_cost = float(cost) + self.model_costs[model] += self.last_cost + def get_usage_summary(self) -> UsageSummary: model_summaries = {} for model in self.model_call_counts: + cost = self.model_costs.get(model) model_summaries[model] = ModelUsageSummary( total_calls=self.model_call_counts[model], total_input_tokens=self.model_input_tokens[model], total_output_tokens=self.model_output_tokens[model], + total_cost=cost if cost else None, ) return UsageSummary(model_usage_summaries=model_summaries) @@ -128,4 +154,5 @@ def get_last_usage(self) -> ModelUsageSummary: total_calls=1, total_input_tokens=self.last_prompt_tokens, total_output_tokens=self.last_completion_tokens, + total_cost=getattr(self, "last_cost", None), ) diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py index f236f45b..ce247ecd 100644 --- a/rlm/core/rlm.py +++ b/rlm/core/rlm.py @@ -1,4 +1,5 @@ import time +from collections.abc import Callable from contextlib import contextmanager from typing import Any @@ -12,9 +13,17 @@ RLMChatCompletion, RLMIteration, RLMMetadata, + UsageSummary, ) from rlm.environments import BaseEnv, SupportsPersistence, get_environment from rlm.logger import RLMLogger, VerbosePrinter +from rlm.utils.exceptions import ( + BudgetExceededError, + CancellationError, + ErrorThresholdExceededError, + TimeoutExceededError, + TokenLimitExceededError, +) from rlm.utils.parsing import ( find_code_blocks, find_final_answer, @@ -47,6 +56,10 @@ def __init__( depth: int = 0, max_depth: int = 1, max_iterations: int = 30, + max_budget: float | None = None, + max_timeout: float | None = None, + max_tokens: int | None = None, + max_errors: int | None = None, custom_system_prompt: str | None = None, other_backends: list[ClientBackend] | None = None, other_backend_kwargs: list[dict[str, Any]] | None = None, @@ -57,6 +70,10 @@ def __init__( custom_sub_tools: dict[str, Any] | None = None, compaction: bool = False, compaction_threshold_pct: float = 0.85, + on_subcall_start: Callable[[int, str, str], None] | None = None, + on_subcall_complete: Callable[[int, str, float, str | None], None] | None = None, + on_iteration_start: Callable[[int, int], None] | None = None, + on_iteration_complete: Callable[[int, int, float], None] | None = None, ): """ Args: @@ -65,8 +82,12 @@ def __init__( environment: The environment to use for the RLM. environment_kwargs: The kwargs to pass to the environment. depth: The current depth of the RLM (0-indexed). - max_depth: The maximum depth of the RLM. Currently, only depth 1 is supported. + max_depth: The maximum depth of recursion. When depth >= max_depth, falls back to plain LM completion. max_iterations: The maximum number of iterations of the RLM. + max_budget: Maximum budget in USD. Execution stops if exceeded. Requires cost-tracking backend (e.g., OpenRouter). + max_timeout: Maximum execution time in seconds. Execution stops if exceeded, returning best answer if available. + max_tokens: Maximum total tokens (input + output). Execution stops if exceeded, returning best answer if available. + max_errors: Maximum consecutive errors before stopping. Execution stops if exceeded, returning best answer if available. custom_system_prompt: The custom system prompt to use for the RLM. other_backends: A list of other client backends that the environments can use to make sub-calls. other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends). @@ -81,6 +102,10 @@ def __init__( when root context reaches compaction_threshold_pct of the model's context limit. compaction_threshold_pct: When compaction is on, trigger summarization when root message token count reaches this fraction of the model context limit (default 0.85). + on_subcall_start: Callback fired when a child RLM starts. Args: (depth, model, prompt_preview). + on_subcall_complete: Callback fired when a child RLM completes. Args: (depth, model, duration, error_or_none). + on_iteration_start: Callback fired when an iteration starts. Args: (depth, iteration_num). + on_iteration_complete: Callback fired when an iteration completes. Args: (depth, iteration_num, duration). """ # Store config for spawning per-completion self.backend = backend @@ -111,10 +136,27 @@ def __init__( self.depth = depth self.max_depth = max_depth self.max_iterations = max_iterations + self.max_budget = max_budget + self.max_timeout = max_timeout + self.max_tokens = max_tokens + self.max_errors = max_errors self.system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT self.logger = logger self.verbose = VerbosePrinter(enabled=verbose) + # Event callbacks for live tree display + self.on_subcall_start = on_subcall_start + self.on_subcall_complete = on_subcall_complete + self.on_iteration_start = on_iteration_start + self.on_iteration_complete = on_iteration_complete + + # Tracking (cumulative across all calls including children) + self._cumulative_cost: float = 0.0 + self._consecutive_errors: int = 0 + self._last_error: str | None = None + self._best_partial_answer: str | None = None + self._completion_start_time: float | None = None # Set when completion() starts + # Persistence support self.persistent = persistent self._persistent_env: SupportsPersistence | None = None @@ -186,6 +228,9 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]): env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port) env_kwargs["context_payload"] = prompt env_kwargs["depth"] = self.depth + 1 # Environment depth is RLM depth + 1 + # For local environment with max_depth > 1, pass subcall callback for recursive RLM calls + if self.environment_type == "local" and self.max_depth > 1: + env_kwargs["subcall_fn"] = self._subcall # Pass custom tools to the environment if self.custom_tools is not None: env_kwargs["custom_tools"] = self.custom_tools @@ -240,7 +285,12 @@ def completion( A final answer as a string. """ time_start = time.perf_counter() + self._completion_start_time = time_start + # Reset tracking state for this completion + self._consecutive_errors = 0 + self._last_error = None + self._best_partial_answer = None # If we're at max depth, the RLM is an LM, so we fallback to the regular LM. if self.depth >= self.max_depth: return self._fallback_answer(prompt) @@ -251,85 +301,110 @@ def completion( with self._spawn_completion_context(prompt) as (lm_handler, environment): message_history = self._setup_prompt(prompt) - for i in range(self.max_iterations): - if self.compaction and hasattr(environment, "append_compaction_entry"): - current_tokens, threshold_tokens, max_tokens = self._get_compaction_status( - message_history + compaction_count = 0 + try: + for i in range(self.max_iterations): + # Check timeout before each iteration + self._check_timeout(i, time_start) + + # Compaction: check if context needs summarization + if self.compaction and hasattr(environment, "append_compaction_entry"): + current_tokens, threshold_tokens, max_tokens = self._get_compaction_status( + message_history + ) + self.verbose.print_compaction_status( + current_tokens, threshold_tokens, max_tokens + ) + if current_tokens >= threshold_tokens: + compaction_count += 1 + self.verbose.print_compaction() + message_history = self._compact_history( + lm_handler, environment, message_history, compaction_count + ) + + # Current prompt = message history + additional prompt suffix + context_count = ( + environment.get_context_count() + if isinstance(environment, SupportsPersistence) + else 1 ) - self.verbose.print_compaction_status( - current_tokens, threshold_tokens, max_tokens + history_count = ( + environment.get_history_count() + if isinstance(environment, SupportsPersistence) + else 0 + ) + current_prompt = message_history + [ + build_user_prompt(root_prompt, i, context_count, history_count) + ] + + iteration: RLMIteration = self._completion_turn( + prompt=current_prompt, + lm_handler=lm_handler, + environment=environment, ) - if current_tokens >= threshold_tokens: - self.verbose.print_compaction() - message_history = self._compact_history( - lm_handler, environment, message_history - ) - # Current prompt = message history + additional prompt suffix - context_count = ( - environment.get_context_count() - if isinstance(environment, SupportsPersistence) - else 1 - ) - history_count = ( - environment.get_history_count() - if isinstance(environment, SupportsPersistence) - else 0 - ) - current_prompt = message_history + [ - build_user_prompt(root_prompt, i, context_count, history_count) - ] - iteration: RLMIteration = self._completion_turn( - prompt=current_prompt, - lm_handler=lm_handler, - environment=environment, - ) + # Check error/budget/token limits after each iteration + self._check_iteration_limits(iteration, i, lm_handler) + + # Check if RLM is done and has a final answer. + # Prefer FINAL_VAR result from REPL execution. + final_answer = None + for block in iteration.code_blocks: + if getattr(block.result, "final_answer", None): + final_answer = block.result.final_answer + break + if final_answer is None: + final_answer = find_final_answer( + iteration.response, environment=environment + ) + iteration.final_answer = final_answer + + # Store as best partial answer (most recent response with content) + if iteration.response and iteration.response.strip(): + self._best_partial_answer = iteration.response + + # If logger is used, log the iteration. + if self.logger: + self.logger.log(iteration) + + # Verbose output for this iteration + self.verbose.print_iteration(iteration, i + 1) + + if final_answer is not None: + time_end = time.perf_counter() + usage = lm_handler.get_usage_summary() + self.verbose.print_final_answer(final_answer) + self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict()) + + # Store message history in persistent environment + if self.persistent and isinstance(environment, SupportsPersistence): + environment.add_history(message_history) + + return RLMChatCompletion( + root_model=self.backend_kwargs.get("model_name", "unknown") + if self.backend_kwargs + else "unknown", + prompt=prompt, + response=final_answer, + usage_summary=usage, + execution_time=time_end - time_start, + metadata=self.logger.get_trajectory() if self.logger else None, + ) - # Check if RLM is done and has a final answer. Prefer FINAL_VAR result from REPL execution. - final_answer = None - for block in iteration.code_blocks: - if getattr(block.result, "final_answer", None): - final_answer = block.result.final_answer - break - if final_answer is None: - final_answer = find_final_answer(iteration.response, environment=environment) - iteration.final_answer = final_answer - - # If logger is used, log the iteration. - if self.logger: - self.logger.log(iteration) - - # Verbose output for this iteration - self.verbose.print_iteration(iteration, i + 1) - - if final_answer is not None: - time_end = time.perf_counter() - usage = lm_handler.get_usage_summary() - self.verbose.print_final_answer(final_answer) - self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict()) - - # Store message history in persistent environment - if self.persistent and isinstance(environment, SupportsPersistence): - environment.add_history(message_history) - - return RLMChatCompletion( - root_model=self.backend_kwargs.get("model_name", "unknown") - if self.backend_kwargs - else "unknown", - prompt=prompt, - response=final_answer, - usage_summary=usage, - execution_time=time_end - time_start, - metadata=self.logger.get_trajectory() if self.logger else None, - ) + # Format the iteration for the next prompt. + new_messages = format_iteration(iteration) - # Format the iteration for the next prompt. - new_messages = format_iteration(iteration) + # Update message history with the new messages. + message_history.extend(new_messages) + if self.compaction and hasattr(environment, "append_compaction_entry"): + environment.append_compaction_entry(new_messages) - # Update message history with the new messages. - message_history.extend(new_messages) - if self.compaction and hasattr(environment, "append_compaction_entry"): - environment.append_compaction_entry(new_messages) + except KeyboardInterrupt: + self.verbose.print_limit_exceeded("cancelled", "User interrupted execution") + raise CancellationError( + partial_answer=self._best_partial_answer, + message="Execution cancelled by user (Ctrl+C)", + ) from None # Default behavior: we run out of iterations, provide one final answer time_end = time.perf_counter() @@ -353,6 +428,101 @@ def completion( metadata=self.logger.get_trajectory() if self.logger else None, ) + def _check_timeout(self, iteration: int, time_start: float) -> None: + """Raise TimeoutExceededError if the timeout has been exceeded.""" + if self.max_timeout is None: + return + elapsed = time.perf_counter() - time_start + if elapsed > self.max_timeout: + self.verbose.print_limit_exceeded( + "timeout", + f"{elapsed:.1f}s of {self.max_timeout:.1f}s", + ) + raise TimeoutExceededError( + elapsed=elapsed, + timeout=self.max_timeout, + partial_answer=self._best_partial_answer, + message=( + f"Timeout exceeded after iteration {iteration}: " + f"{elapsed:.1f}s of {self.max_timeout:.1f}s limit" + ), + ) + + def _check_iteration_limits( + self, iteration: RLMIteration, iteration_num: int, lm_handler: LMHandler + ) -> None: + """Check error tracking, budget, and token limits after an iteration. + + Raises ErrorThresholdExceededError, BudgetExceededError, or TokenLimitExceededError + if the respective limits are exceeded. + """ + # Track errors from code execution (check stderr for errors) + iteration_had_error = False + for code_block in iteration.code_blocks: + if code_block.result and code_block.result.stderr: + iteration_had_error = True + self._last_error = code_block.result.stderr + break + + if iteration_had_error: + self._consecutive_errors += 1 + else: + self._consecutive_errors = 0 # Reset on success + + # Check error threshold + if self.max_errors is not None and self._consecutive_errors >= self.max_errors: + self.verbose.print_limit_exceeded( + "errors", + f"{self._consecutive_errors} consecutive errors (limit: {self.max_errors})", + ) + raise ErrorThresholdExceededError( + error_count=self._consecutive_errors, + threshold=self.max_errors, + last_error=self._last_error, + partial_answer=self._best_partial_answer, + message=( + "Error threshold exceeded: " + f"{self._consecutive_errors} consecutive errors " + f"(limit: {self.max_errors})" + ), + ) + + # Check budget + if self.max_budget is not None: + current_usage = lm_handler.get_usage_summary() + current_cost = current_usage.total_cost or 0.0 + self._cumulative_cost = current_cost + if self._cumulative_cost > self.max_budget: + self.verbose.print_budget_exceeded(self._cumulative_cost, self.max_budget) + raise BudgetExceededError( + spent=self._cumulative_cost, + budget=self.max_budget, + message=( + f"Budget exceeded after iteration {iteration_num + 1}: " + f"spent ${self._cumulative_cost:.6f} " + f"of ${self.max_budget:.6f} budget" + ), + ) + + # Check token limit + if self.max_tokens is not None: + current_usage = lm_handler.get_usage_summary() + total_tokens = current_usage.total_input_tokens + current_usage.total_output_tokens + if total_tokens > self.max_tokens: + self.verbose.print_limit_exceeded( + "tokens", + f"{total_tokens:,} of {self.max_tokens:,} tokens", + ) + raise TokenLimitExceededError( + tokens_used=total_tokens, + token_limit=self.max_tokens, + partial_answer=self._best_partial_answer, + message=( + f"Token limit exceeded after iteration {iteration_num + 1}: " + f"{total_tokens:,} of {self.max_tokens:,} tokens" + ), + ) + def _get_compaction_status(self, message_history: list[dict[str, Any]]) -> tuple[int, int, int]: """Return (current_tokens, threshold_tokens, max_tokens) for compaction.""" model_name = ( @@ -373,6 +543,7 @@ def _compact_history( lm_handler: LMHandler, environment: BaseEnv, message_history: list[dict[str, Any]], + compaction_count: int = 1, ) -> list[dict[str, Any]]: """ Summarize current trajectory, append summary to REPL history, and return @@ -381,7 +552,15 @@ def _compact_history( summary_prompt = message_history + [ { "role": "user", - "content": "Very concisely summarize what you have been doing so far in 1–3 short paragraphs. Be extremely brief. This summary will be used to continue the conversation.", + "content": ( + "Summarize your progress so far. Include:\n" + "1. Which steps/sub-tasks you have completed and which remain.\n" + "2. Any concrete intermediate results (numbers, values, variable names) " + "you computed — preserve these exactly.\n" + "3. What your next action should be.\n" + "Be concise (1–3 paragraphs) but preserve all key results and your " + "current position in the task." + ), } ] summary = lm_handler.completion(summary_prompt) @@ -392,7 +571,13 @@ def _compact_history( {"role": "assistant", "content": summary}, { "role": "user", - "content": "Continue from the above summary. The full history (including this summary) is in the REPL variable `history`. Your next action:", + "content": ( + f"Your conversation has been compacted {compaction_count} time(s). " + "Continue from the above summary. Do NOT repeat work you have already " + "completed. Use SHOW_VARS() to check which REPL variables exist, " + "and check `history` for full context. " + "Your next action:" + ), }, ] return new_history @@ -457,6 +642,170 @@ def _fallback_answer(self, message: str | dict[str, Any]) -> str: response = client.completion(message) return response + def _subcall(self, prompt: str, model: str | None = None) -> RLMChatCompletion: + """ + Handle a subcall from the environment, potentially spawning a child RLM. + + This method is passed as a callback to LocalREPL to enable recursive RLM calls. + When depth allows, it spawns a child RLM with its own REPL. At max depth, + it falls back to a plain LM completion. + + Args: + prompt: The prompt to process. + model: Optional model name. If specified, the child RLM will use this model + instead of inheriting the parent's default backend. + + Returns: + The full RLMChatCompletion from either a child RLM or plain LM completion. + On error, returns a completion with the error message as the response. + """ + next_depth = self.depth + 1 + + # Determine which backend/kwargs to use (model override or parent's default) + if model is not None: + child_backend_kwargs = (self.backend_kwargs or {}).copy() + child_backend_kwargs["model_name"] = model + else: + child_backend_kwargs = self.backend_kwargs + resolved_model = model or (child_backend_kwargs or {}).get("model_name", "unknown") + + # If we'd hit/exceed the cap, do a normal LM completion (no REPL) + if next_depth >= self.max_depth: + # Use other_backend if available, otherwise use main backend + if self.other_backends and self.other_backend_kwargs: + client = get_client(self.other_backends[0], self.other_backend_kwargs[0]) + else: + client = get_client(self.backend, child_backend_kwargs or {}) + root_model = model or client.model_name + start_time = time.perf_counter() + try: + response = client.completion(prompt) + end_time = time.perf_counter() + model_usage = client.get_last_usage() + usage_summary = UsageSummary(model_usage_summaries={root_model: model_usage}) + return RLMChatCompletion( + root_model=root_model, + prompt=prompt, + response=response, + usage_summary=usage_summary, + execution_time=end_time - start_time, + ) + except Exception as e: + end_time = time.perf_counter() + return RLMChatCompletion( + root_model=root_model, + prompt=prompt, + response=f"Error: LM query failed at max depth - {e}", + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=end_time - start_time, + ) + + # Calculate remaining budget for child (if budget tracking enabled) + remaining_budget = None + if self.max_budget is not None: + remaining_budget = self.max_budget - self._cumulative_cost + if remaining_budget <= 0: + return RLMChatCompletion( + root_model=resolved_model, + prompt=prompt, + response=( + "Error: Budget exhausted " + f"(spent ${self._cumulative_cost:.6f} of ${self.max_budget:.6f})" + ), + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=0.0, + ) + + # Calculate remaining timeout for child (if timeout tracking enabled) + remaining_timeout = None + if self.max_timeout is not None and self._completion_start_time is not None: + elapsed = time.perf_counter() - self._completion_start_time + remaining_timeout = self.max_timeout - elapsed + if remaining_timeout <= 0: + return RLMChatCompletion( + root_model=resolved_model, + prompt=prompt, + response=f"Error: Timeout exhausted ({elapsed:.1f}s of {self.max_timeout:.1f}s)", + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=0.0, + ) + + # Resolve the model name for callbacks + prompt_preview = prompt[:80] if len(prompt) > 80 else prompt + + # Fire subcall start callback + if self.on_subcall_start: + try: + self.on_subcall_start(next_depth, str(resolved_model), prompt_preview) + except Exception: + pass # Don't let callback errors break execution + + subcall_start = time.perf_counter() + error_msg: str | None = None + + # Spawn a child RLM with its own LocalREPL + child = RLM( + backend=self.backend, + backend_kwargs=child_backend_kwargs, + environment=self.environment_type, + environment_kwargs=self.environment_kwargs, + depth=next_depth, + max_depth=self.max_depth, + max_iterations=self.max_iterations, + max_budget=remaining_budget, + max_timeout=remaining_timeout, + max_tokens=self.max_tokens, + max_errors=self.max_errors, + custom_system_prompt=self.system_prompt, + other_backends=self.other_backends, + other_backend_kwargs=self.other_backend_kwargs, + # Give child its own logger so its trajectory is captured in metadata + logger=RLMLogger() if self.logger else None, + verbose=False, + # Propagate custom tools to children (sub_tools become the child's tools) + custom_tools=self.custom_sub_tools, + custom_sub_tools=self.custom_sub_tools, + # Propagate callbacks to children for nested tracking + on_subcall_start=self.on_subcall_start, + on_subcall_complete=self.on_subcall_complete, + ) + try: + result = child.completion(prompt, root_prompt=None) + # Track child's cost in parent's cumulative cost + if result.usage_summary and result.usage_summary.total_cost: + self._cumulative_cost += result.usage_summary.total_cost + return result + except BudgetExceededError as e: + # Propagate child's spending to parent + self._cumulative_cost += e.spent + error_msg = f"Budget exceeded - {e}" + return RLMChatCompletion( + root_model=resolved_model, + prompt=prompt, + response=f"Error: Child RLM budget exceeded - {e}", + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=time.perf_counter() - subcall_start, + ) + except Exception as e: + error_msg = str(e) + return RLMChatCompletion( + root_model=resolved_model, + prompt=prompt, + response=f"Error: Child RLM completion failed - {e}", + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=time.perf_counter() - subcall_start, + ) + finally: + # Ensure child resources are cleaned up + child.close() + # Fire subcall complete callback + if self.on_subcall_complete: + try: + duration = time.perf_counter() - subcall_start + self.on_subcall_complete(next_depth, str(resolved_model), duration, error_msg) + except Exception: + pass # Don't let callback errors break execution + def _validate_persistent_environment_support(self) -> None: """ Validate that the configured environment type supports persistent mode. diff --git a/rlm/core/types.py b/rlm/core/types.py index 239e092c..bd9e0d0a 100644 --- a/rlm/core/types.py +++ b/rlm/core/types.py @@ -45,13 +45,17 @@ class ModelUsageSummary: total_calls: int total_input_tokens: int total_output_tokens: int + total_cost: float | None = None # Cost in USD, if available from provider def to_dict(self): - return { + result = { "total_calls": self.total_calls, "total_input_tokens": self.total_input_tokens, "total_output_tokens": self.total_output_tokens, } + if self.total_cost is not None: + result["total_cost"] = self.total_cost + return result @classmethod def from_dict(cls, data: dict) -> "ModelUsageSummary": @@ -59,6 +63,7 @@ def from_dict(cls, data: dict) -> "ModelUsageSummary": total_calls=data.get("total_calls"), total_input_tokens=data.get("total_input_tokens"), total_output_tokens=data.get("total_output_tokens"), + total_cost=data.get("total_cost"), ) @@ -66,13 +71,36 @@ def from_dict(cls, data: dict) -> "ModelUsageSummary": class UsageSummary: model_usage_summaries: dict[str, ModelUsageSummary] + @property + def total_cost(self) -> float | None: + """Aggregate cost across all models. Returns None if no cost data available.""" + costs = [ + summary.total_cost + for summary in self.model_usage_summaries.values() + if summary.total_cost is not None + ] + return sum(costs) if costs else None + + @property + def total_input_tokens(self) -> int: + """Aggregate input tokens across all models.""" + return sum(summary.total_input_tokens for summary in self.model_usage_summaries.values()) + + @property + def total_output_tokens(self) -> int: + """Aggregate output tokens across all models.""" + return sum(summary.total_output_tokens for summary in self.model_usage_summaries.values()) + def to_dict(self): - return { + result = { "model_usage_summaries": { model: usage_summary.to_dict() for model, usage_summary in self.model_usage_summaries.items() }, } + if self.total_cost is not None: + result["total_cost"] = self.total_cost + return result @classmethod def from_dict(cls, data: dict) -> "UsageSummary": diff --git a/rlm/environments/base_env.py b/rlm/environments/base_env.py index 9b9e4ba5..afd6387f 100644 --- a/rlm/environments/base_env.py +++ b/rlm/environments/base_env.py @@ -14,6 +14,8 @@ { "llm_query", "llm_query_batched", + "rlm_query", + "rlm_query_batched", "FINAL_VAR", "SHOW_VARS", "context", @@ -173,6 +175,7 @@ class SupportsCustomTools(Protocol): RESERVED NAMES: The following names cannot be used as custom tool names: - llm_query, llm_query_batched: Single LM completion functions (no tool access) + - rlm_query, rlm_query_batched: Recursive RLM calls for deeper thinking subtasks - FINAL_VAR, SHOW_VARS: Built-in helper functions - context, history: The input context and conversation history variables diff --git a/rlm/environments/local_repl.py b/rlm/environments/local_repl.py index d2ba13fd..afdde3b3 100644 --- a/rlm/environments/local_repl.py +++ b/rlm/environments/local_repl.py @@ -8,6 +8,7 @@ import threading import time import uuid +from collections.abc import Callable from contextlib import contextmanager from typing import Any @@ -130,6 +131,7 @@ def __init__( setup_code: str | None = None, persistent: bool = False, depth: int = 1, + subcall_fn: Callable[[str, str | None], RLMChatCompletion] | None = None, custom_tools: dict[str, Any] | None = None, custom_sub_tools: dict[str, Any] | None = None, compaction: bool = False, @@ -138,6 +140,7 @@ def __init__( super().__init__(persistent=persistent, depth=depth, **kwargs) self.lm_handler_address = lm_handler_address + self.subcall_fn = subcall_fn # Callback for recursive RLM calls (depth > 1 support) self.original_cwd = os.getcwd() self.temp_dir = tempfile.mkdtemp(prefix=f"repl_env_{uuid.uuid4()}_") self._lock = threading.Lock() @@ -189,6 +192,8 @@ def setup(self): self.globals["SHOW_VARS"] = self._show_vars self.globals["llm_query"] = self._llm_query self.globals["llm_query_batched"] = self._llm_query_batched + self.globals["rlm_query"] = self._rlm_query + self.globals["rlm_query_batched"] = self._rlm_query_batched # Add custom tools to globals # Tools can be either plain values or (value, description) tuples @@ -234,7 +239,9 @@ def _show_vars(self) -> str: return f"Available variables: {available}" def _llm_query(self, prompt: str, model: str | None = None) -> str: - """Query the LM via socket connection to the handler. + """Query the LM with a single plain completion (no REPL, no recursion). + + This always makes a direct LM call via the handler, regardless of depth. Args: prompt: The prompt to send to the LM. @@ -250,17 +257,15 @@ def _llm_query(self, prompt: str, model: str | None = None) -> str: if not response.success: return f"Error: {response.error}" - # Track this LLM call - self._pending_llm_calls.append( - response.chat_completion, - ) - + self._pending_llm_calls.append(response.chat_completion) return response.chat_completion.response except Exception as e: return f"Error: LM query failed - {e}" def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]: - """Query the LM with multiple prompts concurrently. + """Query the LM with multiple prompts concurrently (no REPL, no recursion). + + This always makes direct LM calls via the handler, regardless of depth. Args: prompts: List of prompts to send to the LM. @@ -271,7 +276,6 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> li """ if not self.lm_handler_address: return ["Error: No LM handler configured"] * len(prompts) - try: responses = send_lm_request_batched( self.lm_handler_address, prompts, model=model, depth=self.depth @@ -282,7 +286,6 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> li if not response.success: results.append(f"Error: {response.error}") else: - # Track this LLM call in list of all calls -- we may want to do this hierarchically self._pending_llm_calls.append(response.chat_completion) results.append(response.chat_completion.response) @@ -290,6 +293,55 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> li except Exception as e: return [f"Error: LM query failed - {e}"] * len(prompts) + def _rlm_query(self, prompt: str, model: str | None = None) -> str: + """Spawn a recursive RLM sub-call for deeper thinking on a subtask. + + When a subcall callback is available (max_depth > 1), this spawns a child + RLM with its own REPL that can reason over the prompt iteratively. + Falls back to a plain llm_query if no recursive capability is configured. + + Args: + prompt: The prompt to send to the child RLM. + model: Optional model name override for the child. + """ + if self.subcall_fn is not None: + try: + completion = self.subcall_fn(prompt, model) + self._pending_llm_calls.append(completion) + return completion.response + except Exception as e: + return f"Error: RLM query failed - {e}" + + # Fall back to plain LM call if no recursive capability + return self._llm_query(prompt, model) + + def _rlm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]: + """Spawn recursive RLM sub-calls for multiple prompts. + + Each prompt gets its own child RLM for deeper thinking. + Falls back to llm_query_batched if no recursive capability is configured. + + Args: + prompts: List of prompts for child RLMs. + model: Optional model name override for the children. + + Returns: + List of responses in the same order as input prompts. + """ + if self.subcall_fn is not None: + results = [] + for prompt in prompts: + try: + completion = self.subcall_fn(prompt, model) + self._pending_llm_calls.append(completion) + results.append(completion.response) + except Exception as e: + results.append(f"Error: RLM query failed - {e}") + return results + + # Fall back to plain batched LM call if no recursive capability + return self._llm_query_batched(prompts, model) + def load_context(self, context_payload: dict | list | str): """Load context into the environment as context_0 (and 'context' alias).""" self.add_context(context_payload, 0) @@ -412,6 +464,10 @@ def _restore_scaffold(self) -> None: self.globals["llm_query"] = self._llm_query elif name == "llm_query_batched": self.globals["llm_query_batched"] = self._llm_query_batched + elif name == "rlm_query": + self.globals["rlm_query"] = self._rlm_query + elif name == "rlm_query_batched": + self.globals["rlm_query_batched"] = self._rlm_query_batched elif name == "FINAL_VAR": self.globals["FINAL_VAR"] = self._final_var elif name == "SHOW_VARS": diff --git a/rlm/logger/verbose.py b/rlm/logger/verbose.py index 9f88b0a3..7a9d7571 100644 --- a/rlm/logger/verbose.py +++ b/rlm/logger/verbose.py @@ -268,31 +268,69 @@ def print_subcall( prompt_preview: str, response_preview: str, execution_time: float | None = None, + metadata: dict | None = None, ) -> None: - """Print a sub-call to another model.""" + """Print a sub-call to another model. + + Args: + model: The model name used for the sub-call. + prompt_preview: Preview of the prompt sent. + response_preview: Preview of the response received. + execution_time: Time taken for the sub-call. + metadata: If present, this was a recursive RLM call (rlm_query). + Contains "iterations" and "run_metadata" keys. + """ if not self.enabled: return + is_rlm_call = metadata is not None + # Header header = Text() - header.append(" ↳ ", style=STYLE_SECONDARY) - header.append("Sub-call: ", style=STYLE_SECONDARY) + if is_rlm_call: + header.append(" ↳ ", style=STYLE_SECONDARY) + header.append("RLM Sub-call: ", style=STYLE_SECONDARY) + else: + header.append(" ↳ ", style=STYLE_MUTED) + header.append("LLM Sub-call: ", style=STYLE_MUTED) header.append(_to_str(model), style=STYLE_ACCENT) if execution_time: header.append(f" ({execution_time:.2f}s)", style=STYLE_MUTED) # Content content = Text() + + # Show child RLM summary when metadata is available + if is_rlm_call: + iterations = metadata.get("iterations", []) + iteration_count = len(iterations) + content.append(f"Iterations: {iteration_count}", style=STYLE_WARNING) + run_meta = metadata.get("run_metadata", {}) + depth = run_meta.get("depth") + if depth is not None: + content.append(f" | Depth: {depth}", style=STYLE_MUTED) + content.append("\n") + + # Truncate previews for readability + max_preview = 200 + prompt_str = _to_str(prompt_preview) + response_str = _to_str(response_preview) + if len(prompt_str) > max_preview: + prompt_str = prompt_str[:max_preview] + "…" + if len(response_str) > max_preview: + response_str = response_str[:max_preview] + "…" + content.append("Prompt: ", style=STYLE_MUTED) - content.append(_to_str(prompt_preview), style=STYLE_TEXT) + content.append(prompt_str, style=STYLE_TEXT) content.append("\nResponse: ", style=STYLE_MUTED) - content.append(_to_str(response_preview), style=STYLE_TEXT) + content.append(response_str, style=STYLE_TEXT) + border = COLORS["secondary"] if is_rlm_call else COLORS["muted"] panel = Panel( content, title=header, title_align="left", - border_style=COLORS["secondary"], + border_style=border, padding=(0, 1), ) self.console.print(panel) @@ -322,8 +360,69 @@ def print_iteration(self, iteration: RLMIteration, iteration_num: int) -> None: prompt_preview=_to_str(call.prompt) if call.prompt else "", response_preview=_to_str(call.response) if call.response else "", execution_time=call.execution_time, + metadata=call.metadata, ) + def print_budget_exceeded(self, spent: float, budget: float) -> None: + """Print a budget exceeded warning.""" + if not self.enabled: + return + + # Title + title = Text() + title.append("⚠ ", style=STYLE_ERROR) + title.append("Budget Exceeded", style=Style(color=COLORS["error"], bold=True)) + + # Content + content = Text() + content.append(f"Spent: ${spent:.6f}\n", style=STYLE_ERROR) + content.append(f"Budget: ${budget:.6f}", style=STYLE_MUTED) + + panel = Panel( + content, + title=title, + title_align="left", + border_style=COLORS["error"], + padding=(0, 2), + ) + + self.console.print() + self.console.print(panel) + self.console.print() + + def print_limit_exceeded(self, limit_type: str, details: str) -> None: + """Print a limit exceeded warning (timeout, tokens, errors, cancellation).""" + if not self.enabled: + return + + # Map limit type to display name + limit_names = { + "timeout": "Timeout Exceeded", + "tokens": "Token Limit Exceeded", + "errors": "Error Threshold Exceeded", + "cancelled": "Execution Cancelled", + } + display_name = limit_names.get(limit_type, f"{limit_type.title()} Limit Exceeded") + + # Title + title = Text() + title.append("⚠ ", style=STYLE_ERROR) + title.append(display_name, style=Style(color=COLORS["error"], bold=True)) + + # Content + content = Text(details, style=STYLE_ERROR) + + panel = Panel( + content, + title=title, + title_align="left", + border_style=COLORS["error"], + padding=(0, 2), + ) + + self.console.print() + self.console.print(panel) + def print_compaction_status( self, current_tokens: int, @@ -424,9 +523,12 @@ def print_summary( m.get("total_output_tokens", 0) for m in usage_summary.get("model_usage_summaries", {}).values() ) + total_cost = usage_summary.get("total_cost") if total_input or total_output: summary_table.add_row("Input Tokens", f"{total_input:,}") summary_table.add_row("Output Tokens", f"{total_output:,}") + if total_cost is not None: + summary_table.add_row("Total Cost", f"${total_cost:.6f}") # Wrap in rule self.console.print() diff --git a/rlm/utils/exceptions.py b/rlm/utils/exceptions.py new file mode 100644 index 00000000..809944c6 --- /dev/null +++ b/rlm/utils/exceptions.py @@ -0,0 +1,73 @@ +"""Custom exceptions for RLM execution limits and cancellation.""" + + +class BudgetExceededError(Exception): + """Raised when the RLM execution exceeds the maximum budget.""" + + def __init__(self, spent: float, budget: float, message: str | None = None): + self.spent = spent + self.budget = budget + super().__init__(message or f"Budget exceeded: spent ${spent:.6f} of ${budget:.6f} budget") + + +class TimeoutExceededError(Exception): + """Raised when the RLM execution exceeds the maximum timeout.""" + + def __init__( + self, + elapsed: float, + timeout: float, + partial_answer: str | None = None, + message: str | None = None, + ): + self.elapsed = elapsed + self.timeout = timeout + self.partial_answer = partial_answer + super().__init__(message or f"Timeout exceeded: {elapsed:.1f}s of {timeout:.1f}s limit") + + +class TokenLimitExceededError(Exception): + """Raised when the RLM execution exceeds the maximum token limit.""" + + def __init__( + self, + tokens_used: int, + token_limit: int, + partial_answer: str | None = None, + message: str | None = None, + ): + self.tokens_used = tokens_used + self.token_limit = token_limit + self.partial_answer = partial_answer + super().__init__( + message or f"Token limit exceeded: {tokens_used:,} of {token_limit:,} tokens" + ) + + +class ErrorThresholdExceededError(Exception): + """Raised when the RLM encounters too many consecutive errors.""" + + def __init__( + self, + error_count: int, + threshold: int, + last_error: str | None = None, + partial_answer: str | None = None, + message: str | None = None, + ): + self.error_count = error_count + self.threshold = threshold + self.last_error = last_error + self.partial_answer = partial_answer + super().__init__( + message + or f"Error threshold exceeded: {error_count} consecutive errors (limit: {threshold})" + ) + + +class CancellationError(Exception): + """Raised when the RLM execution is cancelled by the user.""" + + def __init__(self, partial_answer: str | None = None, message: str | None = None): + self.partial_answer = partial_answer + super().__init__(message or "Execution cancelled by user") diff --git a/rlm/utils/prompts.py b/rlm/utils/prompts.py index 1cfb372b..fe8230d8 100644 --- a/rlm/utils/prompts.py +++ b/rlm/utils/prompts.py @@ -9,14 +9,22 @@ The REPL environment is initialized with: 1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query. -2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment. -3. A `llm_query_batched` function that allows you to query multiple prompts concurrently: `llm_query_batched(prompts: List[str]) -> List[str]`. This is much faster than sequential `llm_query` calls when you have multiple independent queries. Results are returned in the same order as the input prompts. -4. A `SHOW_VARS()` function that returns all variables you have created in the REPL. Use this to check what variables exist before using FINAL_VAR. -5. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. +2. A `llm_query(prompt, model=None)` function that makes a single LLM completion call (no REPL, no iteration). Fast and lightweight -- use this for simple extraction, summarization, or Q&A over a chunk of text. The sub-LLM can handle around 500K chars. +3. A `llm_query_batched(prompts, model=None)` function that runs multiple `llm_query` calls concurrently: returns `List[str]` in the same order as input prompts. Much faster than sequential `llm_query` calls for independent queries. +4. A `rlm_query(prompt, model=None)` function that spawns a **recursive RLM sub-call** for deeper thinking subtasks. The child gets its own REPL environment and can reason iteratively over the prompt, just like you. Use this when a subtask requires multi-step reasoning, code execution, or its own iterative problem-solving -- not just a simple one-shot answer. Falls back to `llm_query` if recursion is not available. +5. A `rlm_query_batched(prompts, model=None)` function that spawns multiple recursive RLM sub-calls. Each prompt gets its own child RLM. Falls back to `llm_query_batched` if recursion is not available. +6. A `SHOW_VARS()` function that returns all variables you have created in the REPL. Use this to check what variables exist before using FINAL_VAR. +7. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. {custom_tools_section} +**When to use `llm_query` vs `rlm_query`:** +- Use `llm_query` for simple, one-shot tasks: extracting info from a chunk, summarizing text, answering a factual question, classifying content. These are fast single LLM calls. +- Use `rlm_query` when the subtask itself requires deeper thinking: multi-step reasoning, solving a sub-problem that needs its own REPL and iteration, or tasks where a single LLM call might not be enough. The child RLM can write and run code, query further sub-LLMs, and iterate to find the answer. + +**Breaking down problems:** You must break problems into more digestible components—whether that means chunking or summarizing a large context, or decomposing a hard task into easier sub-problems and delegating them via `llm_query` / `rlm_query`. Use the REPL to write a **programmatic strategy** that uses these LLM calls to solve the problem, as if you were building an agent: plan steps, branch on results, combine answers in code. + You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer. -Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer. +Make sure to explicitly look through the entire context in REPL before answering your query. Break the context and the problem into digestible pieces: e.g. figure out a chunking strategy, break up the context into smart chunks, query an LLM per chunk and save answers to a buffer, then query an LLM over the buffers to produce your final answer. You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls! @@ -60,18 +68,28 @@ final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers)) ``` -As a final example, after analyzing the context and realizing its separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it: +For subtasks that require deeper reasoning (e.g. solving a complex sub-problem), use `rlm_query` instead. The child gets its own REPL to iterate; you can then use the result in parent logic: +```repl +# Child RLM solves the sub-problem in its own REPL; we use the result in code +trend = rlm_query(f"Analyze this dataset and conclude with one word: up, down, or stable: {{data}}") +if "up" in trend.lower(): + recommendation = "Consider increasing exposure." +elif "down" in trend.lower(): + recommendation = "Consider hedging." +else: + recommendation = "Hold position." +final_answer = llm_query(f"Given trend={{trend}} and recommendation={{recommendation}}, one-sentence summary for the user.") +``` + +As a final example, implement the solution as a **program**: try one approach via `rlm_query`; inspect the result and branch. If it suffices, use it. If not, break into one easier subproblem and delegate that only. More branches, one path runs—don't load the model. Example: prove √2 irrational. ```repl -# After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer -import re -sections = re.split(r'### (.+)', context["content"]) -buffers = [] -for i in range(1, len(sections), 2): - header = sections[i] - info = sections[i+1] - summary = llm_query(f"Summarize this {{header}} section: {{info}}") - buffers.append(f"{{header}}: {{summary}}") -final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers)) +r = rlm_query("Prove √2 irrational. Give a 1-2 sentence proof, or reply only: USE_LEMMA or USE_CONTRADICTION.") +if "USE_LEMMA" in r.upper(): + final_answer = rlm_query("Prove 'n^2 even => n even' then use it to show √2 irrational. Two sentences.") +elif "USE_CONTRADICTION" in r.upper(): + final_answer = rlm_query("Prove √2 irrational by contradiction. Two sentences.") +else: + final_answer = r ``` In the next step, we can return FINAL_VAR(final_answer). @@ -136,7 +154,7 @@ def build_rlm_system_prompt( return [ {"role": "system", "content": final_system_prompt}, - {"role": "assistant", "content": metadata_prompt}, + {"role": "user", "content": metadata_prompt}, ] diff --git a/rlm/utils/token_utils.py b/rlm/utils/token_utils.py index 79febbe1..44772ce9 100644 --- a/rlm/utils/token_utils.py +++ b/rlm/utils/token_utils.py @@ -18,7 +18,7 @@ # Longest matching key wins. MODEL_CONTEXT_LIMITS: dict[str, int] = { # OpenAI (GPT-5: 272k input, 128k reasoning+output) - "gpt-5-nano": 128_000, + "gpt-5-nano": 272_000, "gpt-5": 272_000, "gpt-4o-mini": 128_000, "gpt-4o-2024": 128_000, diff --git a/tests/test_depth_metadata.py b/tests/test_depth_metadata.py new file mode 100644 index 00000000..def2dc6a --- /dev/null +++ b/tests/test_depth_metadata.py @@ -0,0 +1,558 @@ +"""Tests verifying depth=1 functionality is preserved and depth>1 metadata propagation works. + +1. depth=1: completion loop, limit checks, logger metadata all work as before +2. depth>1: child RLM gets its own logger, metadata flows back through RLMChatCompletion +""" + +from unittest.mock import Mock, patch + +import pytest + +import rlm.core.rlm as rlm_module +from rlm import RLM +from rlm.core.types import ModelUsageSummary, UsageSummary +from rlm.logger import RLMLogger +from rlm.utils.exceptions import ( + BudgetExceededError, + ErrorThresholdExceededError, + TimeoutExceededError, + TokenLimitExceededError, +) + + +def create_mock_lm(responses: list[str], model_name: str = "mock-model") -> Mock: + """Create a mock LM that returns responses in order.""" + mock = Mock() + mock.model_name = model_name + mock.completion.side_effect = list(responses) + mock.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + model_name: ModelUsageSummary( + total_calls=1, total_input_tokens=100, total_output_tokens=50 + ) + } + ) + mock.get_last_usage.return_value = mock.get_usage_summary.return_value + return mock + + +# ======================================================================== +# depth=1 tests: verify existing behavior is preserved +# ======================================================================== + + +class TestDepth1CompletionLoop: + """Verify depth=1 completion loop works identically to before refactoring.""" + + def test_basic_completion_with_final_answer(self): + """depth=1 RLM should complete normally with FINAL() answer.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(42)"]) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + ) + result = rlm.completion("What is the answer?") + assert result.response == "42" + assert result.root_model == "test-model" + + def test_multi_iteration_before_final(self): + """depth=1 should iterate multiple times before finding FINAL().""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm( + [ + "Let me think...\n```repl\nx = 1 + 1\nprint(x)\n```", + "Now I know.\n```repl\ny = x * 2\nprint(y)\n```", + "FINAL(4)", + ] + ) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + ) + result = rlm.completion("Compute 2*2") + assert result.response == "4" + + def test_no_subcall_fn_at_depth_1(self): + """depth=1 (max_depth=1) should NOT pass subcall_fn to environment.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(done)"]) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + ) + + # Patch get_environment to capture kwargs without running full loop + with patch.object(rlm_module, "get_environment") as mock_get_env: + # Make get_environment raise to short-circuit after capturing args + mock_get_env.side_effect = lambda env_type, kwargs: (_ for _ in ()).throw( + RuntimeError("captured") + ) + try: + rlm.completion("test") + except RuntimeError: + pass + + call_args = mock_get_env.call_args + env_kwargs = call_args[0][1] + assert "subcall_fn" not in env_kwargs + + def test_subcall_fn_passed_at_depth_gt_1(self): + """max_depth>1 SHOULD pass subcall_fn to environment.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(done)"]) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=2, + ) + + with patch.object(rlm_module, "get_environment") as mock_get_env: + mock_get_env.side_effect = lambda env_type, kwargs: (_ for _ in ()).throw( + RuntimeError("captured") + ) + try: + rlm.completion("test") + except RuntimeError: + pass + + call_args = mock_get_env.call_args + env_kwargs = call_args[0][1] + assert "subcall_fn" in env_kwargs + assert env_kwargs["subcall_fn"] is not None + + +class TestDepth1LimitChecks: + """Verify limit checks work correctly in the refactored helpers.""" + + def test_timeout_check_raises(self): + """_check_timeout should raise TimeoutExceededError when exceeded.""" + import time + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_timeout=10.0, + ) + # Simulate start time 15 seconds ago + fake_start = time.perf_counter() - 15.0 + + with pytest.raises(TimeoutExceededError) as exc_info: + rlm._check_timeout(0, fake_start) + assert exc_info.value.elapsed > 10.0 + assert exc_info.value.timeout == 10.0 + + def test_timeout_check_no_raise_within_limit(self): + """_check_timeout should not raise when within limit.""" + import time + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_timeout=100.0, + ) + fake_start = time.perf_counter() - 1.0 + # Should not raise + rlm._check_timeout(0, fake_start) + + def test_timeout_check_noop_when_none(self): + """_check_timeout should be a no-op when max_timeout is None.""" + import time + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_timeout=None, + ) + # Even with a very old start time, should not raise + rlm._check_timeout(0, time.perf_counter() - 99999) + + def test_error_threshold_check(self): + """_check_iteration_limits should raise on consecutive errors.""" + from rlm.core.types import CodeBlock, REPLResult, RLMIteration + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_errors=2, + ) + + mock_handler = Mock() + mock_handler.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + "test": ModelUsageSummary( + total_calls=1, total_input_tokens=10, total_output_tokens=10 + ) + } + ) + + error_result = REPLResult(stdout="", stderr="SyntaxError: bad", locals={}, rlm_calls=[]) + error_iteration = RLMIteration( + prompt="test", response="code", code_blocks=[CodeBlock(code="bad", result=error_result)] + ) + + # First error + rlm._check_iteration_limits(error_iteration, 0, mock_handler) + assert rlm._consecutive_errors == 1 + + # Second error should raise + with pytest.raises(ErrorThresholdExceededError) as exc_info: + rlm._check_iteration_limits(error_iteration, 1, mock_handler) + assert exc_info.value.error_count == 2 + assert exc_info.value.threshold == 2 + + def test_error_count_resets_on_success(self): + """Consecutive error count should reset on a successful iteration.""" + from rlm.core.types import CodeBlock, REPLResult, RLMIteration + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_errors=3, + ) + + mock_handler = Mock() + mock_handler.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + "test": ModelUsageSummary( + total_calls=1, total_input_tokens=10, total_output_tokens=10 + ) + } + ) + + error_result = REPLResult(stdout="", stderr="Error!", locals={}, rlm_calls=[]) + error_iter = RLMIteration( + prompt="test", response="code", code_blocks=[CodeBlock(code="bad", result=error_result)] + ) + + ok_result = REPLResult(stdout="ok", stderr="", locals={}, rlm_calls=[]) + ok_iter = RLMIteration( + prompt="test", response="code", code_blocks=[CodeBlock(code="good", result=ok_result)] + ) + + # Two errors + rlm._check_iteration_limits(error_iter, 0, mock_handler) + rlm._check_iteration_limits(error_iter, 1, mock_handler) + assert rlm._consecutive_errors == 2 + + # Success resets + rlm._check_iteration_limits(ok_iter, 2, mock_handler) + assert rlm._consecutive_errors == 0 + + def test_budget_check_raises(self): + """_check_iteration_limits should raise BudgetExceededError when budget exceeded.""" + from rlm.core.types import RLMIteration + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_budget=0.01, + ) + + mock_handler = Mock() + mock_handler.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + "test": ModelUsageSummary( + total_calls=10, + total_input_tokens=10000, + total_output_tokens=10000, + total_cost=0.05, + ) + } + ) + + iteration = RLMIteration(prompt="test", response="code", code_blocks=[]) + + with pytest.raises(BudgetExceededError) as exc_info: + rlm._check_iteration_limits(iteration, 0, mock_handler) + assert exc_info.value.spent > 0.01 + assert exc_info.value.budget == 0.01 + + def test_token_limit_check_raises(self): + """_check_iteration_limits should raise TokenLimitExceededError when tokens exceeded.""" + from rlm.core.types import RLMIteration + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test"}, + max_tokens=100, + ) + + mock_handler = Mock() + mock_handler.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + "test": ModelUsageSummary( + total_calls=1, + total_input_tokens=80, + total_output_tokens=80, + ) + } + ) + + iteration = RLMIteration(prompt="test", response="code", code_blocks=[]) + + with pytest.raises(TokenLimitExceededError) as exc_info: + rlm._check_iteration_limits(iteration, 0, mock_handler) + assert exc_info.value.tokens_used == 160 + assert exc_info.value.token_limit == 100 + + +class TestDepth1LoggerMetadata: + """Verify depth=1 logger metadata is captured correctly.""" + + def test_completion_returns_metadata_with_logger(self): + """When logger is provided, completion result should have metadata.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(42)"]) + mock_get_client.return_value = mock_lm + + logger = RLMLogger() + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + logger=logger, + ) + result = rlm.completion("What is the answer?") + + assert result.metadata is not None + assert "run_metadata" in result.metadata + assert "iterations" in result.metadata + assert len(result.metadata["iterations"]) == 1 + assert result.metadata["run_metadata"]["root_model"] == "test-model" + + def test_completion_returns_no_metadata_without_logger(self): + """When no logger is provided, metadata should be None.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(42)"]) + mock_get_client.return_value = mock_lm + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + ) + result = rlm.completion("What is the answer?") + assert result.metadata is None + + def test_metadata_has_multiple_iterations(self): + """Logger should capture all iterations.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm( + [ + "Let me compute.\n```repl\nx = 1\n```", + "More work.\n```repl\ny = 2\n```", + "FINAL(done)", + ] + ) + mock_get_client.return_value = mock_lm + + logger = RLMLogger() + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=1, + logger=logger, + ) + result = rlm.completion("compute") + + assert result.metadata is not None + assert len(result.metadata["iterations"]) == 3 + + +# ======================================================================== +# depth>1 tests: verify subcall metadata propagation +# ======================================================================== + + +class TestSubcallLoggerPropagation: + """Verify child RLM gets a logger when parent has one, and metadata flows back.""" + + def test_child_gets_logger_when_parent_has_logger(self): + """When parent has a logger, child RLM should also get a logger.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + logger = RLMLogger() + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + logger=logger, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + # Child should have received a logger + child_logger = captured_child_params.get("logger") + assert child_logger is not None + assert isinstance(child_logger, RLMLogger) + # But it should be a DIFFERENT instance from parent's logger + assert child_logger is not logger + + parent.close() + + def test_child_gets_no_logger_when_parent_has_none(self): + """When parent has no logger, child should also get None.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + logger=None, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("logger") is None + + parent.close() + + def test_leaf_subcall_returns_no_metadata(self): + """At max_depth (leaf), subcall returns plain LM completion with no metadata.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["leaf response"] * 3) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + depth=1, + max_depth=2, # next_depth=2 >= max_depth=2 → leaf + logger=RLMLogger(), + ) + + result = parent._subcall("test prompt") + + # Leaf completions don't use RLM, so no metadata + assert result.metadata is None + assert result.response == "leaf response" + + parent.close() + + def test_subcall_metadata_has_trajectory(self): + """When child RLM completes with a logger, the returned RLMChatCompletion should have metadata.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + # Need enough responses: parent init + child init + child completion + mock_lm = create_mock_lm(["FINAL(child answer)"] * 5, model_name="test-model") + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + depth=0, + max_depth=3, # Allows child at depth=1 to have its own REPL + logger=RLMLogger(), + ) + + result = parent._subcall("What is 2+2?") + + # Child should have returned metadata with trajectory + assert result.metadata is not None + assert "run_metadata" in result.metadata + assert "iterations" in result.metadata + assert len(result.metadata["iterations"]) >= 1 + + parent.close() + + +class TestSubcallCustomToolsPropagation: + """Verify custom_tools propagation to child RLM in _subcall.""" + + def test_sub_tools_propagated_to_child(self): + """Child should receive parent's custom_sub_tools as its custom_tools.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + my_tool = lambda x: x * 2 # noqa: E731 + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=3, + custom_tools={"double": my_tool}, + custom_sub_tools={"double": my_tool}, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert "double" in captured_child_params.get("custom_tools", {}) + assert "double" in captured_child_params.get("custom_sub_tools", {}) + + parent.close() + + def test_empty_sub_tools_propagated(self): + """When custom_sub_tools is empty dict, child should get empty dict (no tools).""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "test-model"}, + max_depth=3, + custom_tools={"tool": lambda: 1}, + custom_sub_tools={}, # Explicitly no tools for children + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("custom_tools") == {} + assert captured_child_params.get("custom_sub_tools") == {} + + parent.close() diff --git a/tests/test_e2e_depth.py b/tests/test_e2e_depth.py new file mode 100644 index 00000000..f951218d --- /dev/null +++ b/tests/test_e2e_depth.py @@ -0,0 +1,25 @@ +"""E2E test for depth>1 with real LLM via OpenRouter.""" + +import os + +import pytest + +from rlm import RLM + + +@pytest.mark.skipif( + not os.environ.get("OPENROUTER_API_KEY"), + reason="OPENROUTER_API_KEY not set", +) +def test_depth_2_real_llm(): + """Test depth=2 recursion with google/gemini-3-flash-preview.""" + rlm = RLM( + backend="openrouter", + backend_kwargs={"model_name": "google/gemini-3-flash-preview"}, + max_iterations=2, + max_depth=2, + ) + result = rlm.completion("What is 2+2? Answer with just the number.") + assert result.response is not None + assert len(result.response) > 0 + print(f"Response: {result.response}") diff --git a/tests/test_rlm_query.py b/tests/test_rlm_query.py new file mode 100644 index 00000000..850a87ff --- /dev/null +++ b/tests/test_rlm_query.py @@ -0,0 +1,210 @@ +"""Tests for rlm_query and rlm_query_batched in LocalREPL.""" + +from unittest.mock import MagicMock + +from rlm.core.types import RLMChatCompletion, UsageSummary +from rlm.environments.local_repl import LocalREPL + + +def _make_completion(response: str) -> RLMChatCompletion: + """Create a minimal RLMChatCompletion for testing.""" + return RLMChatCompletion( + root_model="test-model", + prompt="test", + response=response, + usage_summary=UsageSummary(model_usage_summaries={}), + execution_time=0.1, + ) + + +class TestRlmQueryWithSubcallFn: + """Tests for rlm_query when subcall_fn is provided (depth > 1).""" + + def test_rlm_query_uses_subcall_fn(self): + """rlm_query should use subcall_fn when available.""" + subcall_fn = MagicMock(return_value=_make_completion("child response")) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code("response = rlm_query('hello')") + assert result.stderr == "" + assert repl.locals["response"] == "child response" + subcall_fn.assert_called_once_with("hello", None) + repl.cleanup() + + def test_rlm_query_with_model_override(self): + """rlm_query should pass model to subcall_fn.""" + subcall_fn = MagicMock(return_value=_make_completion("override response")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("response = rlm_query('hello', model='gpt-4')") + assert repl.locals["response"] == "override response" + subcall_fn.assert_called_once_with("hello", "gpt-4") + repl.cleanup() + + def test_rlm_query_tracks_pending_calls(self): + """rlm_query should append completion to _pending_llm_calls.""" + completion = _make_completion("tracked") + subcall_fn = MagicMock(return_value=completion) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code("rlm_query('test')") + assert len(result.rlm_calls) == 1 + assert result.rlm_calls[0].response == "tracked" + repl.cleanup() + + def test_rlm_query_error_handling(self): + """rlm_query should return error string if subcall_fn raises.""" + subcall_fn = MagicMock(side_effect=RuntimeError("subcall failed")) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code("response = rlm_query('hello')") + assert result.stderr == "" + assert "Error" in repl.locals["response"] + assert "subcall failed" in repl.locals["response"] + repl.cleanup() + + +class TestRlmQueryWithoutSubcallFn: + """Tests for rlm_query when no subcall_fn (depth == 1 or max_depth reached).""" + + def test_rlm_query_falls_back_to_llm_query(self): + """Without subcall_fn, rlm_query should fall back to llm_query (which returns error without handler).""" + repl = LocalREPL() + repl.execute_code("response = rlm_query('test')") + assert "Error" in repl.locals["response"] + repl.cleanup() + + +class TestRlmQueryBatchedWithSubcallFn: + """Tests for rlm_query_batched when subcall_fn is provided.""" + + def test_batched_calls_subcall_fn_per_prompt(self): + """rlm_query_batched should call subcall_fn once per prompt.""" + completions = [ + _make_completion("answer 1"), + _make_completion("answer 2"), + _make_completion("answer 3"), + ] + subcall_fn = MagicMock(side_effect=completions) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code( + "answers = rlm_query_batched(['q1', 'q2', 'q3'])\nprint(len(answers))" + ) + assert result.stderr == "" + assert "3" in result.stdout + assert repl.locals["answers"] == ["answer 1", "answer 2", "answer 3"] + assert subcall_fn.call_count == 3 + repl.cleanup() + + def test_batched_tracks_all_pending_calls(self): + """rlm_query_batched should track all completions in rlm_calls.""" + completions = [_make_completion(f"resp {i}") for i in range(3)] + subcall_fn = MagicMock(side_effect=completions) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code("rlm_query_batched(['a', 'b', 'c'])") + assert len(result.rlm_calls) == 3 + assert [c.response for c in result.rlm_calls] == ["resp 0", "resp 1", "resp 2"] + repl.cleanup() + + def test_batched_with_model_override(self): + """rlm_query_batched should pass model to each subcall_fn call.""" + subcall_fn = MagicMock(return_value=_make_completion("ok")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("rlm_query_batched(['q1', 'q2'], model='custom-model')") + assert subcall_fn.call_count == 2 + for call in subcall_fn.call_args_list: + assert call[0][1] == "custom-model" + repl.cleanup() + + def test_batched_partial_failure(self): + """If one subcall_fn call fails, others should still succeed.""" + subcall_fn = MagicMock( + side_effect=[ + _make_completion("ok 1"), + RuntimeError("boom"), + _make_completion("ok 3"), + ] + ) + repl = LocalREPL(subcall_fn=subcall_fn) + result = repl.execute_code("answers = rlm_query_batched(['a', 'b', 'c'])") + assert result.stderr == "" + answers = repl.locals["answers"] + assert answers[0] == "ok 1" + assert "Error" in answers[1] + assert "boom" in answers[1] + assert answers[2] == "ok 3" + repl.cleanup() + + def test_batched_empty_prompts(self): + """rlm_query_batched with empty list should return empty list.""" + subcall_fn = MagicMock() + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("answers = rlm_query_batched([])") + assert repl.locals["answers"] == [] + subcall_fn.assert_not_called() + repl.cleanup() + + def test_batched_single_prompt(self): + """rlm_query_batched with single prompt should work.""" + subcall_fn = MagicMock(return_value=_make_completion("single")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("answers = rlm_query_batched(['only one'])") + assert repl.locals["answers"] == ["single"] + subcall_fn.assert_called_once_with("only one", None) + repl.cleanup() + + +class TestRlmQueryBatchedWithoutSubcallFn: + """Tests for rlm_query_batched when no subcall_fn.""" + + def test_batched_falls_back_to_llm_query_batched(self): + """Without subcall_fn, should fall back to llm_query_batched (error without handler).""" + repl = LocalREPL() + repl.execute_code("answers = rlm_query_batched(['q1', 'q2'])") + answers = repl.locals["answers"] + assert len(answers) == 2 + assert all("Error" in a for a in answers) + repl.cleanup() + + +class TestLlmQueryDoesNotUseSubcallFn: + """Verify that llm_query never uses subcall_fn even when one is present.""" + + def test_llm_query_ignores_subcall_fn(self): + """llm_query should always do a plain LM call, never use subcall_fn.""" + subcall_fn = MagicMock(return_value=_make_completion("should not see this")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("response = llm_query('test')") + # Without a handler, llm_query returns an error — importantly, subcall_fn is NOT called + assert "Error" in repl.locals["response"] + subcall_fn.assert_not_called() + repl.cleanup() + + def test_llm_query_batched_ignores_subcall_fn(self): + """llm_query_batched should never use subcall_fn.""" + subcall_fn = MagicMock(return_value=_make_completion("nope")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("answers = llm_query_batched(['q1', 'q2'])") + assert all("Error" in a for a in repl.locals["answers"]) + subcall_fn.assert_not_called() + repl.cleanup() + + +class TestRlmQueryScaffoldRestoration: + """Test that rlm_query and rlm_query_batched are restored after overwrite.""" + + def test_rlm_query_restored_after_overwrite(self): + """If model overwrites rlm_query, the next execution should have the real one.""" + subcall_fn = MagicMock(return_value=_make_completion("real")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("rlm_query = lambda x: 'hijacked'") + # After restoration, rlm_query should work normally + repl.execute_code("response = rlm_query('test')") + assert repl.locals["response"] == "real" + subcall_fn.assert_called_once() + repl.cleanup() + + def test_rlm_query_batched_restored_after_overwrite(self): + """If model overwrites rlm_query_batched, the next execution should have the real one.""" + subcall_fn = MagicMock(return_value=_make_completion("real")) + repl = LocalREPL(subcall_fn=subcall_fn) + repl.execute_code("rlm_query_batched = 'garbage'") + repl.execute_code("answers = rlm_query_batched(['q1'])") + assert repl.locals["answers"] == ["real"] + repl.cleanup() diff --git a/tests/test_subcall.py b/tests/test_subcall.py new file mode 100644 index 00000000..807fe7f8 --- /dev/null +++ b/tests/test_subcall.py @@ -0,0 +1,472 @@ +"""Unit tests for RLM._subcall() method. + +Tests for the parameter propagation to child RLM instances: +1. max_timeout (remaining time) is passed to child +2. max_tokens is passed to child +3. max_errors is passed to child +4. model= parameter overrides child's backend model +""" + +import time +from unittest.mock import Mock, patch + +import rlm.core.rlm as rlm_module +from rlm import RLM +from rlm.core.types import ModelUsageSummary, UsageSummary + + +def create_mock_lm(responses: list[str], model_name: str = "mock-model") -> Mock: + """Create a mock LM that returns responses in order.""" + mock = Mock() + mock.model_name = model_name + mock.completion.side_effect = list(responses) + mock.get_usage_summary.return_value = UsageSummary( + model_usage_summaries={ + model_name: ModelUsageSummary( + total_calls=1, total_input_tokens=100, total_output_tokens=50 + ) + } + ) + mock.get_last_usage.return_value = mock.get_usage_summary.return_value + return mock + + +class TestSubcallTimeoutPropagation: + """Tests for max_timeout propagation to child RLM.""" + + def test_child_receives_remaining_timeout(self): + """When parent has max_timeout=60 and 10s have elapsed, child should get max_timeout approx 50.""" + captured_child_params = {} + + # Create a fake child RLM class to capture initialization params + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + # Capture the kwargs before calling parent + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + # Create parent RLM with max_timeout + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, # Need depth > 1 to allow child spawning + max_timeout=60.0, + ) + + # Simulate that 10 seconds have elapsed since completion started + parent._completion_start_time = time.perf_counter() - 10.0 + + # Patch RLM class to capture child creation + with patch.object(rlm_module, "RLM", CapturingRLM): + # Call _subcall which should spawn a child RLM + parent._subcall("test prompt") + + # Verify child received remaining timeout (approximately 50 seconds) + assert "max_timeout" in captured_child_params + remaining = captured_child_params["max_timeout"] + # Allow some tolerance for test execution time + assert 45.0 < remaining < 55.0, f"Expected ~50s remaining, got {remaining}" + + parent.close() + + def test_child_receives_none_timeout_when_parent_has_none(self): + """When parent has no max_timeout, child should also have None.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_timeout=None, # No timeout + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("max_timeout") is None + + parent.close() + + def test_subcall_returns_error_when_timeout_exhausted(self): + """When timeout is already exhausted, _subcall should return error message.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_timeout=10.0, + ) + + # Simulate that more time has elapsed than the timeout + parent._completion_start_time = time.perf_counter() - 15.0 + + result = parent._subcall("test prompt") + + assert "Error: Timeout exhausted" in result.response + + parent.close() + + +class TestSubcallTokensPropagation: + """Tests for max_tokens propagation to child RLM.""" + + def test_child_receives_max_tokens(self): + """Child RLM should get same max_tokens as parent.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_tokens=50000, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("max_tokens") == 50000 + + parent.close() + + def test_child_receives_none_tokens_when_parent_has_none(self): + """When parent has no max_tokens, child should also have None.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_tokens=None, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("max_tokens") is None + + parent.close() + + +class TestSubcallErrorsPropagation: + """Tests for max_errors propagation to child RLM.""" + + def test_child_receives_max_errors(self): + """Child RLM should get same max_errors as parent.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_errors=5, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("max_errors") == 5 + + parent.close() + + def test_child_receives_none_errors_when_parent_has_none(self): + """When parent has no max_errors, child should also have None.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + max_errors=None, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt") + + assert captured_child_params.get("max_errors") is None + + parent.close() + + +class TestSubcallModelOverride: + """Tests for model= parameter override in _subcall.""" + + def test_model_override_sets_child_backend_kwargs(self): + """When llm_query(prompt, model='test-model') is called, child's backend_kwargs should have model_name='test-model'.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model", "api_key": "test-key"}, + max_depth=3, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + # Call _subcall with model override + parent._subcall("test prompt", model="override-model") + + # Verify child received overridden model in backend_kwargs + child_backend_kwargs = captured_child_params.get("backend_kwargs", {}) + assert child_backend_kwargs.get("model_name") == "override-model" + # Original kwargs should be preserved + assert child_backend_kwargs.get("api_key") == "test-key" + + parent.close() + + def test_model_override_does_not_mutate_parent_kwargs(self): + """Model override should not mutate parent's backend_kwargs.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + ) + + original_model = parent.backend_kwargs["model_name"] + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt", model="override-model") + + # Parent's backend_kwargs should be unchanged + assert parent.backend_kwargs["model_name"] == original_model + + parent.close() + + def test_no_model_override_uses_parent_kwargs(self): + """When no model override is provided, child uses parent's backend_kwargs.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + max_depth=3, + ) + + with patch.object(rlm_module, "RLM", CapturingRLM): + # Call _subcall without model override + parent._subcall("test prompt") + + # Child should use parent's backend_kwargs + child_backend_kwargs = captured_child_params.get("backend_kwargs", {}) + assert child_backend_kwargs.get("model_name") == "parent-model" + + parent.close() + + +class TestSubcallModelOverrideAtLeafDepth: + """Tests for model override at max_depth (leaf LM completion).""" + + def test_model_override_at_leaf_depth_uses_overridden_model(self): + """When at max_depth, the leaf LM completion should use the overridden model.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["leaf response"]) + mock_get_client.return_value = mock_lm + + # Parent at depth 1, max_depth 2 means next depth (2) will be at max_depth + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + depth=1, + max_depth=2, + ) + + # Call _subcall with model override - should trigger leaf LM completion + result = parent._subcall("test prompt", model="leaf-override-model") + + # Verify get_client was called with overridden model in backend_kwargs + # The call should be: get_client("openai", {"model_name": "leaf-override-model"}) + call_args = mock_get_client.call_args_list + # Find the call that has the overridden model + found_override_call = False + for call in call_args: + args, kwargs = call + if len(args) >= 2: + backend_kwargs = args[1] + if ( + isinstance(backend_kwargs, dict) + and backend_kwargs.get("model_name") == "leaf-override-model" + ): + found_override_call = True + break + + assert found_override_call, ( + f"Expected get_client to be called with model_name='leaf-override-model', got calls: {call_args}" + ) + assert result.response == "leaf response" + + parent.close() + + def test_leaf_depth_without_model_override_uses_parent_model(self): + """When at max_depth without model override, uses parent's model.""" + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"] * 2 + ["leaf response"]) + mock_get_client.return_value = mock_lm + + # Parent at depth 1, max_depth 2 means next depth (2) will be at max_depth + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model"}, + depth=1, + max_depth=2, + ) + + # Call _subcall without model override + parent._subcall("test prompt") + + # Verify get_client was called with parent's model + # The last call should use the parent's backend_kwargs + call_args = mock_get_client.call_args_list + # Check the most recent call (for leaf completion) + last_call = call_args[-1] + args, _ = last_call + if len(args) >= 2: + backend_kwargs = args[1] + assert backend_kwargs.get("model_name") == "parent-model" + + parent.close() + + +class TestSubcallCombinedParameters: + """Tests for combined parameter propagation.""" + + def test_all_parameters_propagate_together(self): + """All parameters (timeout, tokens, errors, model) should propagate correctly together.""" + captured_child_params = {} + + original_rlm_class = rlm_module.RLM + + class CapturingRLM(original_rlm_class): + def __init__(self, *args, **kwargs): + captured_child_params.update(kwargs) + super().__init__(*args, **kwargs) + + with patch.object(rlm_module, "get_client") as mock_get_client: + mock_lm = create_mock_lm(["FINAL(answer)"]) + mock_get_client.return_value = mock_lm + + parent = RLM( + backend="openai", + backend_kwargs={"model_name": "parent-model", "api_key": "test-key"}, + max_depth=3, + max_timeout=120.0, + max_tokens=100000, + max_errors=10, + ) + + # Simulate 30 seconds elapsed + parent._completion_start_time = time.perf_counter() - 30.0 + + with patch.object(rlm_module, "RLM", CapturingRLM): + parent._subcall("test prompt", model="override-model") + + # Verify all parameters + assert captured_child_params.get("max_tokens") == 100000 + assert captured_child_params.get("max_errors") == 10 + + # Remaining timeout should be around 90 seconds + remaining_timeout = captured_child_params.get("max_timeout") + assert 85.0 < remaining_timeout < 95.0 + + # Model should be overridden + child_backend_kwargs = captured_child_params.get("backend_kwargs", {}) + assert child_backend_kwargs.get("model_name") == "override-model" + assert child_backend_kwargs.get("api_key") == "test-key" + + parent.close()