diff --git a/CLAUDE.md b/CLAUDE.md
index 0156976d7..e514973c4 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -173,9 +173,9 @@ class Container(Protocol):
 #### 4. LLM Clients (`src/ares/llms/`)
 
 **Core Abstractions:**
-- `LLMRequest` - Dataclass with messages and optional temperature
-- `LLMResponse` - Dataclass with ChatCompletion and cost tracking
-- `LLMClient` Protocol - `async def __call__(request: LLMRequest) -> LLMResponse`
+- `lft.OpenResponsesRequest` - Canonical Open Responses request (from linguafranca) used for observations and client inputs
+- `InferenceResult` - Dataclass wrapping `lft.OpenResponsesResponse` with cost tracking
+- `LLMClient` Protocol - `async def __call__(request: lft.OpenResponsesRequest) -> InferenceResult`
 
 **Key Pattern: Queue-Mediated LLM Client (`queue_mediated_client.py`):**
 
@@ -281,12 +281,13 @@ Follow Google-style isort configuration:
 - **Always import modules, not classes or functions**
 - **External consumers** (examples, docs):
   - ✅ Good: `import ares` → use `ares.make(...)`
-  - ✅ Good: `from ares import llms` → use `llms.LLMRequest`, `llms.TextData`
-  - ❌ Avoid: `from ares.llms import LLMRequest, TextData`
+  - ✅ Good: `from ares.llms import open_responses` → use `open_responses.make_request(...)`
+  - ✅ Good: `from ares import llms` → use `llms.TextData`, `llms.Usage`
+  - ❌ Avoid: `from ares.llms import OpenResponsesRequest, TextData`
 - **Internal code**:
-  - ✅ Good: `from ares.llms import request` → use `request.LLMRequest`
+  - ✅ Good: `from ares.llms import open_responses` → use `open_responses.make_request(...)`
   - ✅ Good: `from ares.llms import response` → use `response.TextData`, `response.Usage`
-  - ❌ Avoid: `from ares.llms.request import LLMRequest`
+  - ❌ Avoid: `from ares.llms.open_responses import Request`
   - ❌ Avoid: `from ares.llms.response import TextData, Usage`
 - Rationale: Makes code more readable and explicit about where objects come from
 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7f1565038..7e6e8832e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -103,21 +103,21 @@ Follow **Google-style imports**: always import modules, not individual classes o
 ```python
 # Good ✅
 import ares
-from ares import llms
+from ares.llms import open_responses
 
-request = llms.LLMRequest(messages=[...])
+request = open_responses.make_request([open_responses.user_message("Hello")])
 env = ares.make("sbv-mswea")
 
 # Good for internal code ✅
-from ares.llms import request
+from ares.llms import open_responses
 from ares.llms import response
 
-req = request.LLMRequest(messages=[...])
-resp = response.LLMResponse(data=[...], cost=0.0, usage=...)
+req = open_responses.make_request([open_responses.user_message("Hello")])
+resp = response.InferenceResult(response=response.make_response("Hello!"), cost=0.0)
 
 # Avoid ❌
-from ares.llms import LLMRequest, TextData
-from ares.llms.request import LLMRequest
+from ares.llms import OpenResponsesRequest, TextData
+from ares.llms.open_responses import Request
 ```
 
 **Rationale:** Makes code more readable and explicit about where objects come from.
diff --git a/README.md b/README.md
index 7c794efeb..bbb1a71c2 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ ARES is an RL-first framework for training and evaluating LLM agents, especially
 
 It is a modern [gym](https://github.com/Farama-Foundation/Gymnasium): the environment layer powering RL research.
 
-ARES treats LLMRequests as observations and LLMResponses as actions within the environment, so you can focus on training just the LLM - not the Code Agent surrounding it. The interface is entirely async, and supports scaling up to hundreds or thousands of parallel environments easily - check out [example 3](https://github.com/withmartian/ares/tree/main/examples/03_parallel_eval_with_api.py) to run this yourself.
+ARES treats Open Responses requests as observations and LLMResponses as actions within the environment, so you can focus on training just the LLM - not the Code Agent surrounding it. The interface is entirely async, and supports scaling up to hundreds or thousands of parallel environments easily - check out [example 3](https://github.com/withmartian/ares/tree/main/examples/03_parallel_eval_with_api.py) to run this yourself.
 
 ## Quick Start
 
diff --git a/docs/source/core-concepts.rst b/docs/source/core-concepts.rst
index 819d74e87..2102408f9 100644
--- a/docs/source/core-concepts.rst
+++ b/docs/source/core-concepts.rst
@@ -11,7 +11,7 @@ It's important to understand two different concepts in ARES:
   The orchestration logic that uses a Container and LLM to solve tasks (e.g., MiniSWECodeAgent). This is **part of the environment** and remains fixed during training. Think of it as the scaffold that defines how an LLM interacts with code.
 
 * Agent/Policy (Trained)
-  The component you're actually training - a function that maps ``LLMRequest → LLMResponse``. This could be a fine-tuned LLM, a prompt optimizer, or any policy that produces better responses. This is what improves through reinforcement learning.
+  The component you're actually training - a function that maps ``OpenResponsesRequest → InferenceResult``. This could be a fine-tuned LLM, a prompt optimizer, or any policy that produces better responses. This is what improves through reinforcement learning.
 
System Architecture ------------------- @@ -30,13 +30,13 @@ Here's how the components fit together: | generates response | │ │ └──────────┬─────────────┘ │ ┌────────────────────────────────┐ │ ^ │ │ │ QueueMediatedLLMClient │ │ - | │ LLMResponse (action) │ │ │ │ + | │ InferenceResult (action) │ │ │ │ | └──────────────────────────┼─>│ Intercepts LLM calls │ │ | │ │ from code agent via │ │ └─────────────────────────────────┼──│ QueueMediatedLLMClient │ │ - LLMRequest (observation) │ └──────────────────┬─────────────┘ │ + Open Responses observation │ └──────────────────┬─────────────┘ │ │ ^ │ │ - │ LLMRequest │ │ LLMResponse │ + │ Open Responses │ │ InferenceResult│ │ │ v │ │ ┌──────────────└─────────────────┐ │ │ │ CodeAgent │ │ @@ -87,7 +87,7 @@ The key abstraction is ``CodeEnvironment``, which: * **Exposes LLM requests as observations** - Intercepts calls from the code agent * **Treats LLM responses as actions** - Your trainable agent/policy provides responses -Crucially, the **CodeAgent is part of the environment**, not what you're training. Your training loop optimizes an agent/policy that produces better ``LLMResponse`` outputs given ``LLMRequest`` observations. +Crucially, the **CodeAgent is part of the environment**, not what you're training. Your training loop optimizes an agent/policy that produces better ``InferenceResult`` outputs given canonical Open Responses observations. 
Standard RL Loop ~~~~~~~~~~~~~~~~ @@ -101,10 +101,10 @@ Every environment follows the standard RL pattern: timestep = await env.reset() while not timestep.last(): - # timestep.observation is an LLMRequest from the code agent + # timestep.observation is an Open Responses request from the code agent action = await your_policy(timestep.observation) - # action is an LLMResponse that continues the agent's execution + # action is an InferenceResult that continues the agent's execution timestep = await env.step(action) # timestep.reward contains the reward for the final step @@ -116,7 +116,7 @@ TimeStep Structure Each call to ``reset()`` or ``step()`` returns a ``TimeStep`` with: * ``step_type``: One of ``"FIRST"``, ``"MID"``, or ``"LAST"`` -* ``observation``: An ``LLMRequest`` object (or ``None`` on termination) +* ``observation``: An Open Responses request object (or ``None`` on termination) * ``reward``: A float reward for each step * ``discount``: A float discount factor for RL algorithms @@ -160,7 +160,7 @@ Example structure: async def run(self, task: str) -> None: while not self.is_done(): # Ask LLM what to do next - request = LLMRequest(messages=[...]) + request = open_responses.make_request([open_responses.user_message(...)]) response = await self._llm_client(request) # Parse and execute commands from LLM response @@ -234,8 +234,8 @@ Which you will need to rewrite into something like: # Decide what to ask LLM next ... llm_response = await self.llm_client( - LLMRequest( - messages=[...], + open_responses.make_request( + [open_responses.user_message(...)], ... # Other request params ) ) @@ -293,30 +293,27 @@ Core Interface .. code-block:: python + from linguafranca import types as lft + class LLMClient(Protocol): - async def __call__(self, request: LLMRequest) -> LLMResponse: + async def __call__(self, request: lft.OpenResponsesRequest) -> InferenceResult: ... 
@dataclass(frozen=True) - class LLMRequest: - messages: Iterable[ChatCompletionMessageParam] - temperature: float | None = None - - @dataclass(frozen=True) - class LLMResponse: - chat_completion_response: ChatCompletion + class InferenceResult: + response: lft.OpenResponsesResponse cost: float -This simple interface wraps OpenAI-style chat completion APIs. The ``messages`` field follows the OpenAI format with ``role`` (system/user/assistant) and ``content``. +ARES uses linguafranca's ``OpenResponsesRequest`` as the canonical request type for observations and client inputs. Edge adapters convert to Chat/Responses/Anthropic formats only when needed. Why LLMClient? ~~~~~~~~~~~~~~ The ``LLMClient`` abstraction serves two purposes: -1. **Observations = LLM Requests**: In the RL loop, ``timestep.observation`` is an ``LLMRequest`` containing the messages the code agent wants to send to the LLM. This is the "state" your policy observes. +1. **Observations = Open Responses requests**: In the RL loop, ``timestep.observation`` is a canonical Open Responses request containing what the code agent wants to send to the LLM. This is the "state" your policy observes. -2. **Actions = LLM Responses**: In the RL loop, the ``action`` you pass to ``env.step()`` is an ``LLMResponse`` containing the LLM's reply. This is how your policy controls the agent's behavior. +2. **Actions = LLM Responses**: In the RL loop, the ``action`` you pass to ``env.step()`` is an ``InferenceResult`` containing the LLM's reply. This is how your policy controls the agent's behavior. This framing makes it natural to think about code agent training as an RL problem: you're learning a policy that maps agent requests to helpful responses. 
diff --git a/docs/source/how-it-works.rst b/docs/source/how-it-works.rst index 49e0c8144..377c22fa3 100644 --- a/docs/source/how-it-works.rst +++ b/docs/source/how-it-works.rst @@ -28,7 +28,7 @@ The ``QueueMediatedLLMClient`` implements the ``LLMClient`` protocol, but instea Meanwhile, the environment: -1. **Watches the queue**: Extracts ``LLMRequest`` objects as they arrive +1. **Watches the queue**: Extracts canonical Open Responses requests as they arrive 2. **Exposes them as observations**: Returns them from ``reset()`` and ``step()`` 3. **Provides responses**: When you call ``step(action)``, sets the Future's result @@ -39,12 +39,14 @@ The core implementation is simple: .. code-block:: python + from linguafranca import types as lft + @dataclass(frozen=True) class QueueMediatedLLMClient(LLMClient): - q: asyncio.Queue[ValueAndFuture[LLMRequest, LLMResponse]] + q: asyncio.Queue[ValueAndFuture[lft.OpenResponsesRequest, InferenceResult]] - async def __call__(self, request: LLMRequest) -> LLMResponse: - future = asyncio.Future[LLMResponse]() + async def __call__(self, request: lft.OpenResponsesRequest) -> InferenceResult: + future = asyncio.Future[InferenceResult]() await self.q.put(ValueAndFuture(value=request, future=future)) return await future # Blocks until env provides response @@ -65,7 +67,7 @@ The environment side: self._llm_req_future = value_and_future.future return TimeStep(step_type="MID", observation=value_and_future.value, ...) 
- async def step(self, action: LLMResponse) -> TimeStep: + async def step(self, action: InferenceResult) -> TimeStep: # Unblock the code agent by providing response self._llm_req_future.set_result(action) return await self._get_time_step() diff --git a/docs/source/index.rst b/docs/source/index.rst index 2d5aa938f..104892606 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,10 +20,10 @@ See the main `README `_ for installation in Key Features ------------ -* **RL-First Design**: Built around the reinforcement learning loop with observations (LLM requests) and actions (LLM responses) +* **RL-First Design**: Built around the reinforcement learning loop with observations (Open Responses requests) and actions (LLM responses) * **LLM-Level Optimization**: Train the LLM within code agents, not just the agent as a whole * **Distributed Workloads**: Support for high-volume, distributed training and evaluation -* **Mechanistic Interpretability**: Raw access to LLM requests and responses for deep analysis +* **Mechanistic Interpretability**: Raw access to canonical LLM requests and responses for deep analysis * **Async Gym/dm_env like Spec**: Close to Gym/dm_env spec, but incorporating async methods for performance Indices and tables diff --git a/examples/04_rl_training_with_skyrl.py b/examples/04_rl_training_with_skyrl.py index 6ddef091d..caa05a38d 100644 --- a/examples/04_rl_training_with_skyrl.py +++ b/examples/04_rl_training_with_skyrl.py @@ -48,7 +48,9 @@ import ares from ares import llms +from ares.llms import open_responses import hydra +from linguafranca import types as lft import omegaconf import ray import skyrl_gym @@ -91,7 +93,7 @@ def __init__(self, env_config: dict | None = None, extras: dict | None = None, * self.preset_name = extras.get("preset_name", kwargs.get("preset_name")) if not self.preset_name: raise ValueError("preset_name must be provided in extras or kwargs") - self.env: ares.Environment[llms.LLMResponse, llms.LLMRequest, float, 
float] | None = None + self.env: ares.Environment[llms.InferenceResult, lft.OpenResponsesRequest, float, float] | None = None async def init( self, prompt: base_text_env.ConversationType @@ -104,7 +106,8 @@ async def init( await self.env.__aenter__() ts = await self.env.reset() - return ts.observation.messages, {} # type: ignore + assert ts.observation is not None + return open_responses.to_chat_messages(ts.observation, strict=True), {} async def step(self, action: str) -> base_text_env.BaseTextEnvStepOutput: """Runs one environment step. @@ -119,10 +122,9 @@ async def step(self, action: str) -> base_text_env.BaseTextEnvStepOutput: """ assert self.env is not None - llm_resp = llms.LLMResponse( - data=[llms.TextData(content=action)], + llm_resp = llms.InferenceResult( + response=llms.make_response(action), cost=0.0, - usage=llms.Usage(prompt_tokens=-1, generated_tokens=-1), ) ts = await self.env.step(llm_resp) @@ -130,7 +132,7 @@ async def step(self, action: str) -> base_text_env.BaseTextEnvStepOutput: # Hack to approximate a context manager await self.env.__aexit__(None, None, None) - msgs = [] if ts.last() else ts.observation.messages + msgs = [] if ts.last() else open_responses.to_chat_messages(ts.observation, strict=True) return base_text_env.BaseTextEnvStepOutput( observations=msgs, reward=ts.reward or 0.0, diff --git a/examples/05_tinker_train.py b/examples/05_tinker_train.py index 97bee3d61..2a481501b 100644 --- a/examples/05_tinker_train.py +++ b/examples/05_tinker_train.py @@ -49,8 +49,10 @@ import ares from ares import containers from ares import llms +from ares.llms import open_responses import chz import frozendict +from linguafranca import types as lft import numpy as np import tinker from tinker_cookbook import cli_utils @@ -109,8 +111,8 @@ class TinkerCompatibleEnv(tinker_types.Env): """Adapter wrapping ARES environments to work with Tinker's RL training loop. 
Handles bidirectional conversion: - - ARES LLMRequest -> Tinker ModelInput (tokenized prompts) - - Tinker Action (text) -> ARES LLMResponse + - ARES Open Responses request -> Tinker ModelInput (tokenized prompts) + - Tinker Action (text) -> ARES InferenceResult - ARES TimeStep -> Tinker StepResult This enables using any ARES environment with Tinker's training infrastructure. @@ -121,7 +123,7 @@ class TinkerCompatibleEnv(tinker_types.Env): def __init__( self, - env: ares.Environment[llms.LLMResponse, llms.LLMRequest, float, float], + env: ares.Environment[llms.InferenceResult, lft.OpenResponsesRequest, float, float], renderer: renderers.Renderer, convo_prefix: list[renderers.Message] | None, max_tokens: int, @@ -132,14 +134,14 @@ def __init__( self.max_tokens = max_tokens def _get_tinker_observation( - self, ts: ares.TimeStep[llms.LLMRequest | None, float, float] + self, ts: ares.TimeStep[lft.OpenResponsesRequest | None, float, float] ) -> tinker_types.Observation: if ts.observation is None: return tinker.ModelInput.empty() messages = self.convo_prefix + [ renderers.Message(role=message["role"], content=message["content"]) # type: ignore - for message in ts.observation.messages + for message in open_responses.to_chat_messages(ts.observation, strict=True) ] model_input = self.renderer.build_generation_prompt(messages) @@ -149,15 +151,14 @@ def _get_tinker_observation( return model_input - def _get_ares_action(self, action: tinker_types.Action) -> llms.LLMResponse: + def _get_ares_action(self, action: tinker_types.Action) -> llms.InferenceResult: message, parse_success = self.renderer.parse_response(action) if not parse_success: _LOGGER.warning("Failed to parse response: %s", message) - return llms.LLMResponse( - data=[llms.TextData(content=_get_text_content(message))], + return llms.InferenceResult( + response=llms.make_response(_get_text_content(message)), cost=0.0, - usage=llms.Usage(prompt_tokens=-1, generated_tokens=-1), ) @property diff --git a/examples/utils.py 
b/examples/utils.py index f6e9376bd..64500b582 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -7,6 +7,8 @@ import ares from ares import llms +from ares.llms import open_responses +from linguafranca import types as lft import tqdm _LOGGER = logging.getLogger(__name__) @@ -14,8 +16,8 @@ def print_step( step_count: int, - observation: llms.LLMRequest | None, - action: llms.LLMResponse, + observation: lft.OpenResponsesRequest | None, + action: llms.InferenceResult, ) -> None: """Print a step in the RL loop. @@ -28,19 +30,19 @@ def print_step( print("-" * 80) if observation is not None: - messages = list(observation.messages) + messages = open_responses.to_chat_messages(observation, strict=False) if len(messages) > 0: observation_content = messages[-1].get("content", "") else: - observation_content = str(observation.system_prompt) or "(no messages)" + observation_content = str(observation.instructions) or "(no messages)" observation_preview = str(observation_content)[:200] if len(str(observation_content)) > 200: observation_preview += "..." print(f"Observation (from environment): {observation_preview}") - action_content = action.data[0].content - action_preview = str(action_content)[:200] + action_content = llms.extract_text_content(action.response) + action_preview = action_content[:200] if len(action_content) > 200: action_preview += "..." print(f"Action (from LLM): {action_preview}") diff --git a/pyproject.toml b/pyproject.toml index 076175e76..37cf75645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "harbor>=0.1.32", "httpx>=0.28.1", "jinja2>=3.1.6", + "martian-linguafranca>=0.1.5", "mini-swe-agent>=1.17.3", "numpy>=2.3.5", # TODO: Fix this hard requirement. diff --git a/src/ares/__init__.py b/src/ares/__init__.py index a026f9d00..6487e1c3e 100644 --- a/src/ares/__init__.py +++ b/src/ares/__init__.py @@ -1,8 +1,8 @@ """ARES: Agentic Research and Evaluation Suite. 
ARES is an RL-first framework for training and evaluating code agents. It implements -an async version of DeepMind's dm_env specification, treating LLM requests as -observations and LLM responses as actions within a standard RL loop. +an async version of DeepMind's dm_env specification, treating canonical Open Responses +requests as observations and LLM responses as actions within a standard RL loop. The primary way to create environments is via the registry system: diff --git a/src/ares/code_agents/code_agent_base.py b/src/ares/code_agents/code_agent_base.py index 41d2cc5e1..c59963ed5 100644 --- a/src/ares/code_agents/code_agent_base.py +++ b/src/ares/code_agents/code_agent_base.py @@ -4,7 +4,7 @@ from ares.containers import containers from ares.llms import llm_clients -from ares.llms import request +from ares.llms import open_responses class CodeAgent(Protocol): @@ -31,7 +31,7 @@ async def run(self, problem_statement: str) -> None: del problem_statement # Unused. for _ in range(50): - await self._llm_client(request.LLMRequest(messages=[{"role": "user", "content": "Print 'Yes' only."}])) + await self._llm_client(open_responses.make_request([open_responses.user_message("Print 'Yes' only.")])) await self._container.exec_run("sleep 1") return diff --git a/src/ares/code_agents/mini_swe_agent.py b/src/ares/code_agents/mini_swe_agent.py index da2238f4e..ebc9f4bf9 100644 --- a/src/ares/code_agents/mini_swe_agent.py +++ b/src/ares/code_agents/mini_swe_agent.py @@ -19,14 +19,15 @@ from typing import Literal, assert_never import jinja2 +from linguafranca import types as lft import yaml from ares.code_agents import code_agent_base from ares.containers import containers from ares.experiment_tracking import stat_tracker from ares.llms import llm_clients -from ares.llms import request -from ares.llms import response +from ares.llms import open_responses +from ares.llms import response as response_lib # Ensure that MSWEA doesn't log its startup message on import. 
os.environ["MSWEA_SILENT_STARTUP"] = "1" @@ -142,14 +143,14 @@ def __post_init__(self): self._cost_limit = self._agent_config.get("cost_limit", 0.0) self._system_prompt = _render_system_template(self._agent_config["system_template"]) - self._messages: list[request.Message] = [] + self._messages: list[lft.InputItemMessage] = [] _LOGGER.debug("[%d] Initialized MiniSWECodeAgent.", id(self)) def _add_message(self, role: Literal["user", "assistant"], content: str) -> None: if role == "user": - self._messages.append(request.UserMessage(role="user", content=content)) + self._messages.append(open_responses.user_message(content)) elif role == "assistant": - self._messages.append(request.AssistantMessage(role="assistant", content=content)) + self._messages.append(open_responses.assistant_message(content)) else: assert_never(role) @@ -195,7 +196,7 @@ async def step(self) -> None: llm_response = await self.query() await self.execute_action(llm_response) - async def query(self) -> response.LLMResponse: + async def query(self) -> response_lib.InferenceResult: """Query the model and return the response.""" # Check step limit before making LLM call if 0 < self._step_limit <= self._n_calls: @@ -207,30 +208,28 @@ async def query(self) -> response.LLMResponse: _LOGGER.debug("[%d] Querying LLM.", id(self)) with self.tracker.timeit("mswea/llm_request"): - response = await self.llm_client( - request.LLMRequest( - messages=self._messages, - system_prompt=self._system_prompt, + llm_response = await self.llm_client( + open_responses.make_request( + self._messages, + instructions=self._system_prompt, temperature=0.0, ) ) _LOGGER.debug("[%d] LLM response received.", id(self)) self._n_calls += 1 - self._total_cost += response.cost + self._total_cost += llm_response.cost - message_content = response.data[0].content - assert message_content is not None + message_content = response_lib.extract_text_content(llm_response.response) self._add_message("assistant", message_content) - return response + 
return llm_response - async def execute_action(self, response: response.LLMResponse) -> None: + async def execute_action(self, llm_response: response_lib.InferenceResult) -> None: """Execute the action and return the observation.""" _LOGGER.debug("[%d] Executing action.", id(self)) - response_text = response.data[0].content - assert response_text is not None + response_text = response_lib.extract_text_content(llm_response.response) action = self.parse_action(response_text) diff --git a/src/ares/code_agents/terminus2/terminus2_agent.py b/src/ares/code_agents/terminus2/terminus2_agent.py index 5bd76f497..bf7654caf 100644 --- a/src/ares/code_agents/terminus2/terminus2_agent.py +++ b/src/ares/code_agents/terminus2/terminus2_agent.py @@ -26,13 +26,15 @@ import shlex from typing import Literal, cast +from linguafranca import types as lft + from ares.code_agents import code_agent_base from ares.code_agents.terminus2 import json_parser from ares.code_agents.terminus2 import xml_parser from ares.containers import containers from ares.experiment_tracking import stat_tracker from ares.llms import llm_clients -from ares.llms import request +from ares.llms import open_responses from ares.llms import response _LOGGER = logging.getLogger(__name__) @@ -174,7 +176,7 @@ def __post_init__(self): self._summarize_template = _load_template(template_dir / "summarize.txt") # Conversation history - self._messages: list[request.Message] = [] + self._messages: list[lft.InputItemMessage] = [] self._system_prompt: str | None = None # Set during run() # State tracking @@ -527,10 +529,8 @@ async def run(self, task: str) -> None: try: # Query the LLM try: - response = await self._query_llm() - assert len(response.data) == 1 - assistant_message = response.data[0].content - assert assistant_message is not None + llm_response = await self._query_llm() + assistant_message = response.extract_text_content(llm_response.response) self._add_message("assistant", assistant_message) @@ -646,9 +646,11 @@ 
async def run(self, task: str) -> None: def _get_total_chars(self) -> int: """Get the total number of characters in the conversation history and system prompt.""" - return sum(len(str(msg.get("content", ""))) for msg in self._messages) + len(self._system_prompt or "") + return sum(len(open_responses.message_text(msg, strict=False)) for msg in self._messages) + len( + self._system_prompt or "" + ) - async def _query_llm(self) -> response.LLMResponse: + async def _query_llm(self) -> response.InferenceResult: """Query the LLM with the current conversation history. Returns: @@ -681,7 +683,7 @@ async def _query_llm(self) -> response.LLMResponse: try: response = await self.llm_client( - request.LLMRequest(messages=self._messages, system_prompt=self._system_prompt) + open_responses.make_request(self._messages, instructions=self._system_prompt) ) _LOGGER.debug("[%d] Received LLM response", id(self)) return response @@ -714,7 +716,7 @@ async def _query_llm(self) -> response.LLMResponse: # Retry the query with summarized context response = await self.llm_client( - request.LLMRequest(messages=self._messages, system_prompt=self._system_prompt) + open_responses.make_request(self._messages, instructions=self._system_prompt) ) _LOGGER.debug("[%d] Received LLM response after summarization", id(self)) return response @@ -871,26 +873,27 @@ def _add_message(self, role: Literal["user", "assistant"], content: str) -> None # Sanitize content before adding sanitized = _sanitize_content(content) assert role in ("user", "assistant") - self._messages.append(cast(request.Message, {"role": role, "content": sanitized})) + if role == "user": + self._messages.append(open_responses.user_message(sanitized)) + else: + self._messages.append(open_responses.assistant_message(sanitized)) - def _unwrap_single_response(self, llm_response: response.LLMResponse) -> str: - """Unwrap a single-item LLMResponse and update metrics. 
+ def _unwrap_single_response(self, llm_response: response.InferenceResult) -> str: + """Unwrap an InferenceResult text content and update metrics. Args: - llm_response: The LLM response with a single data item. + llm_response: The LLM response. Returns: The content string from the response. """ - assert len(llm_response.data) == 1 - content = llm_response.data[0].content - assert content is not None + content = response.extract_text_content(llm_response.response) # Track subagent metrics usage = llm_response.usage if usage: - self._subagent_metrics.total_prompt_tokens += usage.prompt_tokens or 0 - self._subagent_metrics.total_completion_tokens += usage.generated_tokens or 0 + self._subagent_metrics.total_prompt_tokens += usage.input_tokens or 0 + self._subagent_metrics.total_completion_tokens += usage.output_tokens or 0 return content @@ -914,10 +917,10 @@ async def _summarize(self) -> str: try: # Use the conversation history for context (same as Harbor's implementation) summary_messages = list(self._messages) - summary_messages.append({"role": "user", "content": summary_prompt}) + summary_messages.append(open_responses.user_message(summary_prompt)) summary_response = await self.llm_client( - request.LLMRequest(messages=summary_messages, system_prompt=self._system_prompt) + open_responses.make_request(summary_messages, instructions=self._system_prompt) ) summary_content = self._unwrap_single_response(summary_response) @@ -932,9 +935,9 @@ async def _summarize(self) -> str: _LOGGER.warning("[%d] Summarization hit context limit, using last 20 messages only", id(self)) try: summary_messages = [self._messages[0], *self._messages[-20:]] - summary_messages.append({"role": "user", "content": summary_prompt}) + summary_messages.append(open_responses.user_message(summary_prompt)) summary_response = await self.llm_client( - request.LLMRequest(messages=summary_messages, system_prompt=self._system_prompt) + open_responses.make_request(summary_messages, 
instructions=self._system_prompt) ) summary_content = self._unwrap_single_response(summary_response) _LOGGER.info("[%d] Step 1/3: Summary generated (with fallback)", id(self)) @@ -957,7 +960,7 @@ async def _summarize(self) -> str: try: questions_response = await self.llm_client( - request.LLMRequest(messages=[{"role": "user", "content": question_prompt}]) + open_responses.make_request([open_responses.user_message(question_prompt)]) ) model_questions = self._unwrap_single_response(questions_response) @@ -977,10 +980,10 @@ async def _summarize(self) -> str: try: # Use the conversation history for context (same as Harbor's implementation) answer_messages = list(self._messages) - answer_messages.append({"role": "user", "content": answer_request_prompt}) + answer_messages.append(open_responses.user_message(answer_request_prompt)) answers_response = await self.llm_client( - request.LLMRequest(messages=answer_messages, system_prompt=self._system_prompt) + open_responses.make_request(answer_messages, instructions=self._system_prompt) ) answers_content = self._unwrap_single_response(answers_response) @@ -994,9 +997,9 @@ async def _summarize(self) -> str: _LOGGER.warning("[%d] Answer generation hit context limit, using last 20 messages only", id(self)) try: answer_messages = [*self._messages[-20:]] - answer_messages.append({"role": "user", "content": answer_request_prompt}) + answer_messages.append(open_responses.user_message(answer_request_prompt)) answers_response = await self.llm_client( - request.LLMRequest(messages=answer_messages, system_prompt=self._system_prompt) + open_responses.make_request(answer_messages, instructions=self._system_prompt) ) answers_content = self._unwrap_single_response(answers_response) _LOGGER.info("[%d] Step 3/3: Answers provided (with fallback)", id(self)) diff --git a/src/ares/code_agents/terminus2/terminus2_agent_test.py b/src/ares/code_agents/terminus2/terminus2_agent_test.py index 693923b26..a9a02e569 100644 --- 
a/src/ares/code_agents/terminus2/terminus2_agent_test.py +++ b/src/ares/code_agents/terminus2/terminus2_agent_test.py @@ -16,6 +16,39 @@ from ares.llms import response from ares.testing.mock_container import MockContainer +# Test response JSON payloads +_RESPONSE_EXECUTE_COMMAND = """ +{ + "analysis": "Need to list files", + "plan": "Run ls command", + "commands": [ + { + "keystrokes": "ls -la\\n", + "duration": 0.1 + } + ], + "task_complete": false +} +""".strip() + +_RESPONSE_MARK_COMPLETE = """ +{ + "analysis": "Files listed", + "plan": "Task is done", + "commands": [], + "task_complete": true +} +""".strip() + +_RESPONSE_CONFIRM_COMPLETE = """ +{ + "analysis": "Confirming completion", + "plan": "Done", + "commands": [], + "task_complete": true +} +""".strip() + class TmuxSimulator: """Helper class to simulate tmux behavior for testing.""" @@ -222,56 +255,21 @@ async def test_simple_task_completion(self): llm_client = mock.AsyncMock(spec=llm_clients.LLMClient) # First response: execute a command - response1 = response.LLMResponse( - data=[ - response.TextData( - content="""{ - "analysis": "Need to list files", - "plan": "Run ls command", - "commands": [ - { - "keystrokes": "ls -la\\n", - "duration": 0.1 - } - ], - "task_complete": false -}""" - ) - ], + response1 = response.InferenceResult( + response=response.make_response(_RESPONSE_EXECUTE_COMMAND, input_tokens=100, output_tokens=50), cost=0.0, - usage=response.Usage(prompt_tokens=100, generated_tokens=50), ) # Second response: mark complete - response2 = response.LLMResponse( - data=[ - response.TextData( - content="""{ - "analysis": "Files listed", - "plan": "Task is done", - "commands": [], - "task_complete": true -}""" - ) - ], + response2 = response.InferenceResult( + response=response.make_response(_RESPONSE_MARK_COMPLETE, input_tokens=100, output_tokens=50), cost=0.0, - usage=response.Usage(prompt_tokens=100, generated_tokens=50), ) # Third response: confirm completion - response3 = 
response.LLMResponse( - data=[ - response.TextData( - content="""{ - "analysis": "Confirming completion", - "plan": "Done", - "commands": [], - "task_complete": true -}""" - ) - ], + response3 = response.InferenceResult( + response=response.make_response(_RESPONSE_CONFIRM_COMPLETE, input_tokens=100, output_tokens=50), cost=0.0, - usage=response.Usage(prompt_tokens=100, generated_tokens=50), ) llm_client.side_effect = [response1, response2, response3] diff --git a/src/ares/containers/docker.py b/src/ares/containers/docker.py index 6d1cae5d3..74fd7e605 100644 --- a/src/ares/containers/docker.py +++ b/src/ares/containers/docker.py @@ -101,6 +101,9 @@ async def exec_run( env: dict[str, str] | None = None, timeout_s: float | None = None, ) -> containers.ExecResult: + if self._container is None: + raise RuntimeError("Container not started. Call start() first.") + # Use default_workdir if workdir is not explicitly provided if workdir is None: workdir = self.default_workdir @@ -129,6 +132,8 @@ def stop_and_remove(self) -> None: async def upload_files(self, local_paths: list[pathlib.Path], remote_paths: list[str]) -> None: """Upload files to the container.""" + if self._container is None: + raise RuntimeError("Container not started. Call start() first.") if len(local_paths) != len(remote_paths): raise ValueError("local_paths and remote_paths must have the same length") @@ -157,6 +162,8 @@ async def upload_files(self, local_paths: list[pathlib.Path], remote_paths: list async def download_files(self, remote_paths: list[str], local_paths: list[pathlib.Path]) -> None: """Download files from the container.""" + if self._container is None: + raise RuntimeError("Container not started. 
Call start() first.") if len(remote_paths) != len(local_paths): raise ValueError("remote_paths and local_paths must have the same length") diff --git a/src/ares/contrib/eval_visualizer.py b/src/ares/contrib/eval_visualizer.py index be65b64a2..188a3c51d 100644 --- a/src/ares/contrib/eval_visualizer.py +++ b/src/ares/contrib/eval_visualizer.py @@ -26,6 +26,7 @@ import time from typing import ClassVar +from linguafranca import types as lft import rich.markup import rich.text from textual import app @@ -33,7 +34,6 @@ from textual import widgets from ares.environments import base -from ares.llms import request from ares.llms import response @@ -72,8 +72,8 @@ def duration(self) -> float | None: class TrackedEnvironment[RewardType: base.Scalar, DiscountType: base.Scalar]: """Wrapper around an ARES LLM environment that automatically tracks state and reports to dashboard. - This wrapper is specifically designed for environments that use LLMResponse as actions - and LLMRequest as observations. It automatically tracks LLM costs and step progress. + This wrapper is specifically designed for environments that use InferenceResult as actions + and Open Responses requests as observations. It automatically tracks LLM costs and step progress. This wrapper intercepts reset() and step() calls to automatically update the dashboard with progress information, eliminating the need for manual instrumentation. @@ -90,14 +90,14 @@ class TrackedEnvironment[RewardType: base.Scalar, DiscountType: base.Scalar]: def __init__( self, - env: base.Environment[response.LLMResponse, request.LLMRequest, RewardType, DiscountType], + env: base.Environment[response.InferenceResult, lft.OpenResponsesRequest, RewardType, DiscountType], task_id: int, dashboard: "EvaluationDashboard", ): """Initialize the tracked environment wrapper. Args: - env: The environment to wrap (must use LLMRequest/LLMResponse). + env: The environment to wrap (must use Open Responses requests and InferenceResult). 
task_id: The task ID for dashboard tracking. dashboard: The dashboard to report to. """ @@ -107,7 +107,7 @@ def __init__( self._step_count = 0 self._total_cost = 0.0 - async def reset(self) -> base.TimeStep[request.LLMRequest, RewardType, DiscountType]: + async def reset(self) -> base.TimeStep[lft.OpenResponsesRequest, RewardType, DiscountType]: """Reset the environment and update dashboard.""" self._dashboard.update_task(self._task_id, status=TaskStatus.RUNNING, log="Resetting environment") ts = await self._env.reset() @@ -116,7 +116,9 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, RewardType, DiscountT self._total_cost = 0.0 return ts - async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMRequest, RewardType, DiscountType]: + async def step( + self, action: response.InferenceResult + ) -> base.TimeStep[lft.OpenResponsesRequest, RewardType, DiscountType]: """Step the environment and update dashboard.""" self._step_count += 1 @@ -697,16 +699,16 @@ def update_task( def wrap[RewardType: base.Scalar, DiscountType: base.Scalar]( self, task_id: int, - env: base.Environment[response.LLMResponse, request.LLMRequest, RewardType, DiscountType], + env: base.Environment[response.InferenceResult, lft.OpenResponsesRequest, RewardType, DiscountType], ) -> TrackedEnvironment[RewardType, DiscountType]: """Wrap an ARES LLM environment with automatic dashboard tracking. - This method is specifically for environments that use LLMRequest as observations - and LLMResponse as actions (all ARES code agent environments). + This method is specifically for environments that use Open Responses requests as observations + and InferenceResult as actions (all ARES code agent environments). Args: task_id: The task ID for this environment. - env: The environment to wrap (must use LLMRequest/LLMResponse). + env: The environment to wrap (must use Open Responses requests and InferenceResult). 
Returns: A tracked environment that automatically updates the dashboard with diff --git a/src/ares/contrib/llama_cpp.py b/src/ares/contrib/llama_cpp.py index ecc248eac..3eee86f9e 100644 --- a/src/ares/contrib/llama_cpp.py +++ b/src/ares/contrib/llama_cpp.py @@ -14,7 +14,7 @@ Example usage: from ares.contrib import llama_cpp from ares.llms import llm_clients -from ares.llms import request + from ares.llms import open_responses # Initialize with a local GGUF model file client = llama_cpp.LlamaCppLLMClient( @@ -23,7 +23,7 @@ ) # Use like any other LLM client - request = request.LLMRequest(messages=[{"role": "user", "content": "Hello!"}]) + request = open_responses.make_request([open_responses.user_message("Hello!")]) response = await client(request) Note: Download GGUF models from HuggingFace. For example: @@ -35,12 +35,12 @@ import functools import logging +from linguafranca import types as lft import llama_cpp import openai.types.chat.chat_completion from ares.llms import llm_clients -from ares.llms import openai_chat_converter -from ares.llms import request +from ares.llms import open_responses from ares.llms import response _LOGGER = logging.getLogger(__name__) @@ -75,18 +75,19 @@ def _llm(self) -> llama_cpp.Llama: n_ctx=self.n_ctx, ) - async def __call__(self, req: request.LLMRequest) -> response.LLMResponse: + async def __call__(self, req: lft.OpenResponsesRequest) -> response.InferenceResult: """Generate a response using llama.cpp. Args: - request: The LLM request containing messages and optional temperature + request: The Open Responses request to run. 
Returns: - LLMResponse with the generated completion + InferenceResult with the generated completion """ _LOGGER.debug("[%d] Requesting LLM.", id(self)) - completion_kwargs = openai_chat_converter.to_external(req) + completion_kwargs = open_responses.to_chat_completions_kwargs(req, model=self.model_name, strict=True) + completion_kwargs.pop("model", None) # Since llama-cpp-python sets default temperature to 0.2, we explicitly # override it to 1.0 if it's not provided by the request. completion_kwargs.setdefault("temperature", 1.0) @@ -98,11 +99,14 @@ async def __call__(self, req: request.LLMRequest) -> response.LLMResponse: _LOGGER.debug("[%d] LLM response received.", id(self)) content = chat_completion.choices[0].message.content or "" - usage = response.Usage( - prompt_tokens=chat_completion.usage.prompt_tokens if chat_completion.usage else 0, - generated_tokens=chat_completion.usage.completion_tokens if chat_completion.usage else 0, + lf_response = response.make_response( + content, + model=self.model_name, + input_tokens=chat_completion.usage.prompt_tokens if chat_completion.usage else 0, + output_tokens=chat_completion.usage.completion_tokens if chat_completion.usage else 0, + response_id=chat_completion.id, ) - return response.LLMResponse(data=[response.TextData(content=content)], cost=0.0, usage=usage) + return response.InferenceResult(response=lf_response, cost=0.0) create_qwen2_0_5b_instruct_llama_cpp_client = functools.partial( diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py index 9cd35dc11..ba6a45231 100644 --- a/src/ares/contrib/mech_interp/hooked_transformer_client.py +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -4,10 +4,12 @@ import dataclasses from typing import Any +from linguafranca import types as lft import torch import transformer_lens from ares import llms +from ares.contrib import transformers_client @dataclasses.dataclass @@ -58,24 +60,24 
@@ def _default_format_messages(messages: Sequence[Any]) -> str: async def __call__( self, - request: llms.LLMRequest, + request: lft.OpenResponsesRequest, max_output_tokens: int | None = None, - ) -> llms.LLMResponse: + ) -> llms.InferenceResult: """Generate a completion using the HookedTransformer. Args: - request: LLM request containing messages and optional temperature. + request: Open Responses request containing the model input. Returns: - LLM response with chat completion and cost information. + Inference result with response and cost information. """ max_output_tokens = max_output_tokens or request.max_output_tokens or self.max_new_tokens - # Format messages into text - messages_list = [] - if request.system_prompt: - messages_list.append({"role": "system", "content": request.system_prompt}) - messages_list.extend(request.messages) + # Use the custom renderer instead of open_responses.to_chat_messages() because + # local model tokenizers (via apply_chat_template) generally don't handle OpenAI- + # format tool_calls arrays or role="tool" messages. The custom renderer flattens + # tool interactions into plain user/assistant text that any chat template can process. 
+ messages_list = transformers_client._render_request_to_chat_messages(request) # Tokenize input # TODO: Need to support various truncation methods @@ -121,11 +123,9 @@ async def __call__( output_text = self.model.to_string(output_ids) assert isinstance(output_text, str) # typing - return llms.LLMResponse( - data=[llms.TextData(content=output_text)], - cost=0.0, # Local inference has no cost - usage=llms.Usage( - prompt_tokens=num_input_tokens, - generated_tokens=num_output_tokens, - ), + lf_response = llms.make_response( + output_text, + input_tokens=num_input_tokens, + output_tokens=num_output_tokens, ) + return llms.InferenceResult(response=lf_response, cost=0.0) diff --git a/src/ares/contrib/transformers_client.py b/src/ares/contrib/transformers_client.py index 4f71ea9cd..8620887ac 100644 --- a/src/ares/contrib/transformers_client.py +++ b/src/ares/contrib/transformers_client.py @@ -10,7 +10,7 @@ Example usage: from ares.contrib import transformers_client - from ares.llms import request + from ares.llms import open_responses client = transformers_client.TransformersLLMClient( model_name="Qwen/Qwen2.5-0.5B-Instruct", @@ -18,7 +18,7 @@ max_batch_size=8, ) - req = request.LLMRequest(messages=[{"role": "user", "content": "Hello!"}]) + req = open_responses.make_request([open_responses.user_message("Hello!")]) response = await client(req) """ @@ -27,19 +27,21 @@ import contextlib import dataclasses import functools +import json import logging from typing import Literal, cast +from linguafranca import types as lft import torch import transformers from ares.async_utils import ValueAndFuture from ares.llms import llm_clients -from ares.llms import openai_chat_converter -from ares.llms import request +from ares.llms import open_responses from ares.llms import response _LOGGER = logging.getLogger(__name__) +_SUPPORTED_RENDER_FIELDS = frozenset({"input", "instructions", "max_output_tokens", "model", "temperature", "top_p"}) # This is defined in transformers, but not 
exposed. @@ -48,6 +50,123 @@ class _BaseModelWithGenerate(transformers.PreTrainedModel, transformers.Generati pass +def _render_content_to_text(content: object, *, context: str) -> str: + if isinstance(content, str): + return content + if content is None: + return "" + if not isinstance(content, list): + return str(content) + + text_parts: list[str] = [] + dropped_parts: list[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") in {"input_text", "text"}: + text_parts.append(str(part.get("text", ""))) + continue + dropped_parts.append( + str(part.get("type", type(part).__name__)) if isinstance(part, dict) else type(part).__name__ + ) + + if dropped_parts: + _LOGGER.warning( + "TransformersLLMClient dropped unsupported %s parts: %s", + context, + ", ".join(dropped_parts), + ) + + return "".join(text_parts) + + +def _render_value_to_text(value: object) -> str: + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def _render_request_to_chat_messages(request: lft.OpenResponsesRequest) -> list[dict[str, str]]: + """Convert an Open Responses request to simple ``{role, content}`` chat dicts. + + We use a custom renderer instead of ``open_responses.to_chat_messages()`` because + local model tokenizers (via ``apply_chat_template``) generally don't handle OpenAI- + format ``tool_calls`` arrays or ``role="tool"`` messages. This function flattens + tool interactions into plain user/assistant text that any chat template can process. 
+ """ + payload = open_responses.request_to_jsonable(request) + + dropped_fields = sorted( + field + for field, value in payload.items() + if field not in _SUPPORTED_RENDER_FIELDS and value not in (None, False, [], {}) + ) + if dropped_fields: + _LOGGER.warning( + "TransformersLLMClient dropped unsupported Open Responses fields: %s", + ", ".join(dropped_fields), + ) + + system_parts: list[str] = [] + instructions = payload.get("instructions") + if instructions: + system_parts.append(str(instructions)) + + rendered_messages: list[dict[str, str]] = [] + input_value = payload["input"] + if isinstance(input_value, str): + rendered_messages.append({"role": "user", "content": input_value}) + else: + for item in input_value: + if not isinstance(item, dict): + _LOGGER.warning( + "TransformersLLMClient dropped unsupported input item type: %s", + type(item).__name__, + ) + continue + + item_type = item.get("type") + if item_type == "message": + role = str(item.get("role", "user")) + content = _render_content_to_text(item.get("content"), context=f"{role} message") + if role in {"developer", "system"}: + system_parts.append(content) + elif role in {"assistant", "user"}: + rendered_messages.append({"role": role, "content": content}) + else: + _LOGGER.warning("TransformersLLMClient dropped unsupported message role: %s", role) + continue + + if item_type == "function_call": + rendered_messages.append( + { + "role": "assistant", + "content": ( + f"Called tool {item.get('name', '')} with call_id {item.get('call_id', '')}:\n" + f"{item.get('arguments', '')}" + ), + } + ) + continue + + if item_type == "function_call_output": + rendered_messages.append( + { + "role": "user", + "content": ( + f"Tool result for call_id {item.get('call_id', '')}:\n" + f"{_render_value_to_text(item.get('output', ''))}" + ), + } + ) + continue + + _LOGGER.warning("TransformersLLMClient dropped unsupported input item type: %s", item_type) + + if system_parts: + rendered_messages.insert(0, {"role": 
"system", "content": "\n\n".join(system_parts)}) + + return rendered_messages + + def _detect_device() -> str: """Auto-detect best available device (CUDA > MPS > CPU).""" if torch.cuda.is_available(): @@ -117,7 +236,7 @@ class TransformersLLMClient(llm_clients.LLMClient): temperature: float = 1.0 @functools.cached_property - def _request_queue(self) -> asyncio.Queue[ValueAndFuture[request.LLMRequest, response.LLMResponse]]: + def _request_queue(self) -> asyncio.Queue[ValueAndFuture[lft.OpenResponsesRequest, response.InferenceResult]]: """Lazy-initialized queue for batching requests.""" return asyncio.Queue() @@ -169,19 +288,19 @@ def _tokenizer(self) -> transformers.PreTrainedTokenizer: tokenizer.pad_token = tokenizer.eos_token return tokenizer - async def __call__(self, req: request.LLMRequest) -> response.LLMResponse: + async def __call__(self, req: lft.OpenResponsesRequest) -> response.InferenceResult: """Queue request and wait for batched inference. The background inference task is started automatically on first call. Args: - req: The LLM request containing messages and optional temperature + req: The Open Responses request to run. Returns: - LLMResponse with the generated completion + InferenceResult with the generated completion """ _ = self._inference_task # Trigger lazy initialization - future: asyncio.Future[response.LLMResponse] = asyncio.Future() + future: asyncio.Future[response.InferenceResult] = asyncio.Future() await self._request_queue.put(ValueAndFuture(value=req, future=future)) return await future @@ -202,7 +321,7 @@ async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> async def _inference_loop(self) -> None: """Background task that batches and processes requests. - Groups requests by (temperature, max_tokens) during collection to preserve per-request + Groups requests by (temperature, max_tokens, top_p) during collection to preserve per-request semantics. 
Collects until batch_wait_ms elapses OR any group reaches max_batch_size. Tradeoff: Grouping at collection time is more efficient than collecting mixed batches @@ -211,21 +330,23 @@ async def _inference_loop(self) -> None: and timer, allowing truly independent batching per parameter set. """ while True: - collected_items: list[ValueAndFuture[request.LLMRequest, response.LLMResponse]] = [] + collected_items: list[ValueAndFuture[lft.OpenResponsesRequest, response.InferenceResult]] = [] try: first_item = await self._request_queue.get() collected_items.append(first_item) - kwargs = openai_chat_converter.to_external(first_item.value, strict=False) - temp = kwargs.get("temperature") - max_tokens = kwargs.get("max_completion_tokens") + temp = first_item.value.temperature + max_tokens = first_item.value.max_output_tokens + top_p = first_item.value.top_p first_params = ( self.temperature if temp is None else temp, self.max_new_tokens if max_tokens is None else max_tokens, + top_p, ) - groups: dict[tuple[float, int], list[ValueAndFuture[request.LLMRequest, response.LLMResponse]]] = ( - collections.defaultdict(list) - ) + groups: dict[ + tuple[float, int, float | None], + list[ValueAndFuture[lft.OpenResponsesRequest, response.InferenceResult]], + ] = collections.defaultdict(list) groups[first_params].append(first_item) deadline = asyncio.get_event_loop().time() + (self.batch_wait_ms / 1000.0) @@ -239,12 +360,13 @@ async def _inference_loop(self) -> None: try: item = await asyncio.wait_for(self._request_queue.get(), timeout=timeout) collected_items.append(item) - kwargs = openai_chat_converter.to_external(item.value, strict=False) - temp = kwargs.get("temperature") - max_tokens = kwargs.get("max_completion_tokens") + temp = item.value.temperature + max_tokens = item.value.max_output_tokens + top_p = item.value.top_p params = ( self.temperature if temp is None else temp, self.max_new_tokens if max_tokens is None else max_tokens, + top_p, ) groups[params].append(item) except 
TimeoutError: @@ -263,28 +385,28 @@ async def _inference_loop(self) -> None: async def _process_batch( self, - batch: list[ValueAndFuture[request.LLMRequest, response.LLMResponse]], - params: tuple[float, int], + batch: list[ValueAndFuture[lft.OpenResponsesRequest, response.InferenceResult]], + params: tuple[float, int, float | None], ) -> None: """Process a batch of requests with homogeneous parameters. Args: batch: List of request-future pairs (all with same temperature/max_tokens) - params: (temperature, max_new_tokens) for this batch + params: (temperature, max_new_tokens, top_p) for this batch """ try: - temperature, max_new_tokens = params + temperature, max_new_tokens, top_p = params chat_conversations = [] for item in batch: - kwargs = openai_chat_converter.to_external(item.value, strict=False) - chat_conversations.append(kwargs["messages"]) + chat_conversations.append(_render_request_to_chat_messages(item.value)) responses = await asyncio.to_thread( # transformers is not async self._generate_batch, chat_conversations, temperature=temperature, max_new_tokens=max_new_tokens, + top_p=top_p, ) for item, resp in zip(batch, responses, strict=True): @@ -301,16 +423,18 @@ def _generate_batch( chat_conversations: list[list[dict[str, str]]], temperature: float, max_new_tokens: int, - ) -> list[response.LLMResponse]: + top_p: float | None, + ) -> list[response.InferenceResult]: """Generate responses for a batch of chat conversations. 
Args: chat_conversations: List of chat message lists temperature: Sampling temperature max_new_tokens: Maximum tokens to generate + top_p: Optional top-p sampling parameter Returns: - List of LLMResponses + List of InferenceResults """ input_texts: list[str] = [] for conv in chat_conversations: @@ -334,12 +458,18 @@ def _generate_batch( input_lengths = inputs["input_ids"].shape[1] # type: ignore[union-attr] with torch.no_grad(): + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "do_sample": temperature > 0, + "pad_token_id": self._tokenizer.pad_token_id, + } + if top_p is not None: + generation_kwargs["top_p"] = top_p + outputs = self._model.generate( **inputs, # type: ignore[arg-type] - max_new_tokens=max_new_tokens, - temperature=temperature, - do_sample=temperature > 0, - pad_token_id=self._tokenizer.pad_token_id, + **generation_kwargs, # type: ignore[arg-type] ) generated_ids = outputs[:, input_lengths:] @@ -350,15 +480,12 @@ def _generate_batch( prompt_tokens = inputs["input_ids"][i].ne(self._tokenizer.pad_token_id).sum().item() # type: ignore[union-attr] generated_tokens = generated_ids[i].ne(self._tokenizer.pad_token_id).sum().item() # type: ignore[arg-type] - responses.append( - response.LLMResponse( - data=[response.TextData(content=text)], - cost=0.0, # Local inference has no API cost - usage=response.Usage( - prompt_tokens=prompt_tokens, - generated_tokens=generated_tokens, - ), - ) + lf_response = response.make_response( + text, + model=self.model_name, + input_tokens=prompt_tokens, + output_tokens=generated_tokens, ) + responses.append(response.InferenceResult(response=lf_response, cost=0.0)) return responses diff --git a/src/ares/contrib/transformers_client_test.py b/src/ares/contrib/transformers_client_test.py index c0fedcdf1..92f0d2e77 100644 --- a/src/ares/contrib/transformers_client_test.py +++ b/src/ares/contrib/transformers_client_test.py @@ -9,7 +9,7 @@ import transformers from ares.contrib import 
transformers_client -from ares.llms import request as request_lib +from ares.llms import open_responses from ares.llms import response as response_lib # Helper functions for mocking @@ -170,6 +170,49 @@ def test_cached_properties(self): assert "_inference_task" not in client.__dict__ +class TestOpenResponsesRendering: + """Tests for direct Open Responses rendering.""" + + def test_render_request_with_instructions_and_tool_history(self): + request = open_responses.make_request( + [ + open_responses.user_message("What is the weather?"), + open_responses.assistant_message("Let me check."), + open_responses.function_call(call_id="call_123", name="get_weather", arguments='{"city":"SF"}'), + open_responses.function_call_output(call_id="call_123", output="Sunny"), + ], + instructions="You are helpful.", + ) + + rendered = transformers_client._render_request_to_chat_messages(request) + + assert rendered == [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "What is the weather?"}, + {"role": "assistant", "content": "Let me check."}, + { + "role": "assistant", + "content": 'Called tool get_weather with call_id call_123:\n{"city":"SF"}', + }, + {"role": "user", "content": "Tool result for call_id call_123:\nSunny"}, + ] + + def test_render_request_logs_dropped_fields(self, caplog): + request = open_responses.make_request( + [open_responses.user_message("Hello")], + metadata={"user_id": "123"}, + tools=[open_responses.function_tool(name="lookup")], + tool_choice=open_responses.specific_tool_choice("lookup"), + stream=True, + ) + + with caplog.at_level("WARNING"): + rendered = transformers_client._render_request_to_chat_messages(request) + + assert rendered == [{"role": "user", "content": "Hello"}] + assert "dropped unsupported Open Responses fields: metadata, stream, tool_choice, tools" in caplog.text + + class TestTransformersLLMClientLifecycle: """Tests for client lifecycle behavior.""" @@ -204,7 +247,7 @@ async def 
test_lazy_task_start(self): mock.patch.object(type(client), "_tokenizer", new_callable=mock.PropertyMock, return_value=mock_tokenizer), ): async with client: - req = request_lib.LLMRequest(messages=[{"role": "user", "content": "test"}]) + req = open_responses.make_request([open_responses.user_message("test")]) # Make request - should start task via cached_property resp = await client(req) @@ -212,7 +255,7 @@ async def test_lazy_task_start(self): # Task should now be cached assert "_inference_task" in client.__dict__ assert isinstance(client._inference_task, asyncio.Task) - assert isinstance(resp, response_lib.LLMResponse) + assert isinstance(resp, response_lib.InferenceResult) class TestTransformersLLMClientBatching: @@ -232,18 +275,16 @@ async def test_single_request_processing(self): with setup_client_mocks(client, mock_model, mock_tokenizer): async with client: - req = request_lib.LLMRequest( - messages=[{"role": "user", "content": "test"}], - ) + req = open_responses.make_request([open_responses.user_message("test")]) resp = await client(req) - assert isinstance(resp, response_lib.LLMResponse) - assert len(resp.data) == 1 - assert resp.data[0].content == "Response text" + assert isinstance(resp, response_lib.InferenceResult) + assert response_lib.extract_text_content(resp.response) == "Response text" assert resp.cost == 0.0 - assert resp.usage.prompt_tokens > 0 - assert resp.usage.generated_tokens > 0 + assert resp.usage is not None + assert resp.usage.input_tokens > 0 + assert resp.usage.output_tokens > 0 @pytest.mark.asyncio async def test_batch_multiple_requests(self): @@ -273,20 +314,46 @@ def tokenizer_side_effect(*args, **_): with setup_client_mocks(client, mock_model, mock_tokenizer): async with client: # Submit 3 requests concurrently - requests = [ - request_lib.LLMRequest(messages=[{"role": "user", "content": f"test {i}"}]) for i in range(3) - ] + requests = [open_responses.make_request([open_responses.user_message(f"test {i}")]) for i in range(3)] 
responses = await asyncio.gather(*[client(req) for req in requests]) assert len(responses) == 3 for i, resp in enumerate(responses): - assert isinstance(resp, response_lib.LLMResponse) - assert resp.data[0].content == f"Response {i + 1}" + assert isinstance(resp, response_lib.InferenceResult) + assert response_lib.extract_text_content(resp.response) == f"Response {i + 1}" # Verify generate was called once with batch mock_model.generate.assert_called_once() + @pytest.mark.asyncio + async def test_request_params_propagate_top_p_and_max_tokens(self): + """Test per-request generation params are passed through batching.""" + client = transformers_client.TransformersLLMClient( + model_name="test-model", + batch_wait_ms=10, + ) + + mock_tokenizer = create_mock_tokenizer() + mock_model = mock.MagicMock() + mock_model.generate.return_value = torch.tensor([[1, 2, 3, 4, 5, 6]]) + + with setup_client_mocks(client, mock_model, mock_tokenizer): + async with client: + req = open_responses.make_request( + [open_responses.user_message("test")], + max_output_tokens=7, + temperature=0.4, + top_p=0.8, + ) + + await client(req) + + _, kwargs = mock_model.generate.call_args + assert kwargs["max_new_tokens"] == 7 + assert kwargs["temperature"] == 0.4 + assert kwargs["top_p"] == 0.8 + @pytest.mark.asyncio async def test_integration_with_minimal_model(): @@ -342,15 +409,13 @@ async def test_integration_with_minimal_model(): ), ): async with client: - req = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - ) + req = open_responses.make_request([open_responses.user_message("Hello")]) resp = await client(req) - assert isinstance(resp, response_lib.LLMResponse) - assert len(resp.data) == 1 - assert isinstance(resp.data[0].content, str) + assert isinstance(resp, response_lib.InferenceResult) + assert isinstance(response_lib.extract_text_content(resp.response), str) assert resp.cost == 0.0 - assert resp.usage.prompt_tokens > 0 - assert resp.usage.generated_tokens > 0 + 
assert resp.usage is not None + assert resp.usage.input_tokens > 0 + assert resp.usage.output_tokens > 0 diff --git a/src/ares/environments/code_env.py b/src/ares/environments/code_env.py index 5c8b30a3d..f13dd6fdd 100644 --- a/src/ares/environments/code_env.py +++ b/src/ares/environments/code_env.py @@ -21,6 +21,7 @@ from harbor.models.task import task as harbor_task from harbor.models.trial import paths as harbor_paths from harbor.registry import client as harbor_dataset_client +from linguafranca import types as lft from ares.code_agents import code_agent_base from ares.code_agents import mini_swe_agent @@ -29,7 +30,6 @@ from ares.environments import base from ares.experiment_tracking import stat_tracker from ares.llms import queue_mediated_client -from ares.llms import request from ares.llms import response _LOGGER = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def list_harbor_datasets() -> tuple[harbor_registry.DatasetSpec, ...]: return tuple(client.get_datasets()) -class CodeEnvironment(base.Environment[response.LLMResponse, request.LLMRequest | None, float, float]): +class CodeEnvironment(base.Environment[response.InferenceResult, lft.OpenResponsesRequest | None, float, float]): """Environment for code agent datasets that computes reward at the end of an episode.""" def __init__( @@ -79,7 +79,7 @@ def __init__( # we can return LLM requests in the reset and step methods. # We should never allow a user to pass a different LLM client. self._llm_client = queue_mediated_client.QueueMediatedLLMClient(q=asyncio.Queue()) - self._llm_req_future: asyncio.Future[response.LLMResponse] | None = None + self._llm_req_future: asyncio.Future[response.InferenceResult] | None = None # State. self._is_active = False @@ -92,7 +92,7 @@ def __init__( # Register for cleanup on exit. 
_ENVIRONMENT_JANITOR.register_for_cleanup(self) - async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: + async def reset(self) -> base.TimeStep[lft.OpenResponsesRequest, float, float]: reset_start_time = time.time() self._assert_active() @@ -126,7 +126,9 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: self._tracker.scalar(f"{self._prefix}/reset", reset_end_time - reset_start_time) return result - async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMRequest | None, float, float]: + async def step( + self, action: response.InferenceResult + ) -> base.TimeStep[lft.OpenResponsesRequest | None, float, float]: step_start_time = time.time() self._assert_active() @@ -162,7 +164,7 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR async def _get_time_step( self, - ) -> base.TimeStep[request.LLMRequest | None, float, float]: + ) -> base.TimeStep[lft.OpenResponsesRequest | None, float, float]: # Wait for the code agent to send another request or complete. 
_LOGGER.debug("[%d] Waiting for code agent or LLM request.", id(self)) with self._tracker.timeit(f"{self._prefix}/get_from_queue"): diff --git a/src/ares/environments/twenty_questions.py b/src/ares/environments/twenty_questions.py index 29c888eb7..5d4d71012 100644 --- a/src/ares/environments/twenty_questions.py +++ b/src/ares/environments/twenty_questions.py @@ -11,11 +11,12 @@ from typing import Self import frozendict +from linguafranca import types as lft from ares.environments import base from ares.experiment_tracking import stat_tracker from ares.llms import chat_completions_compatible -from ares.llms import request +from ares.llms import open_responses from ares.llms import response _LOGGER = logging.getLogger(__name__) @@ -168,7 +169,7 @@ SIMPLE_OBJECT_LIST = ("Football", "Dog", "Banana", "Truck", "Pants", "Computer", "Piano", "Chair", "Pen", "Scissors") -class TwentyQuestionsEnvironment(base.Environment[response.LLMResponse, request.LLMRequest, float, float]): +class TwentyQuestionsEnvironment(base.Environment[response.InferenceResult, lft.OpenResponsesRequest, float, float]): """Environment for twenty questions game using an LLM-based oracle.""" def __init__( @@ -209,7 +210,7 @@ def __init__( self._step_count = 0 self._requires_reset = False - async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: + async def reset(self) -> base.TimeStep[lft.OpenResponsesRequest, float, float]: """Start a new episode by selecting a random object.""" reset_start_time = time.time() self._assert_active() @@ -232,11 +233,9 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: "When you think you know the answer, ask 'Is it [object]?'" ) - observation = request.LLMRequest( - messages=[ - request.UserMessage(role="user", content=initial_prompt), - ], - system_prompt=self._system_prompt, + observation = open_responses.make_request( + [open_responses.user_message(initial_prompt)], + instructions=self._system_prompt, ) reset_end_time = 
time.time() @@ -244,7 +243,7 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: return base.TimeStep(step_type="FIRST", reward=None, discount=None, observation=observation) - async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMRequest, float, float]: + async def step(self, action: response.InferenceResult) -> base.TimeStep[lft.OpenResponsesRequest, float, float]: """Process agent's question and get oracle's answer.""" step_start_time = time.time() self._assert_active() @@ -276,11 +275,9 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR "\n".join(self._conversation_history) + f"\n\nYou win! The object was {self._hidden_object}." ) - observation = request.LLMRequest( - messages=[ - request.UserMessage(role="user", content=observation_content), - ], - system_prompt=self._system_prompt, + observation = open_responses.make_request( + [open_responses.user_message(observation_content)], + instructions=self._system_prompt, ) step_end_time = time.time() @@ -309,11 +306,9 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR + f"\n\nYou've run out of questions! The object was {self._hidden_object}." 
) - observation = request.LLMRequest( - messages=[ - request.UserMessage(role="user", content=observation_content), - ], - system_prompt=self._system_prompt, + observation = open_responses.make_request( + [open_responses.user_message(observation_content)], + instructions=self._system_prompt, ) step_end_time = time.time() @@ -324,11 +319,9 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR # Continue episode observation_content = "\n".join(self._conversation_history) - observation = request.LLMRequest( - messages=[ - request.UserMessage(role="user", content=observation_content), - ], - system_prompt=self._system_prompt, + observation = open_responses.make_request( + [open_responses.user_message(observation_content)], + instructions=self._system_prompt, ) step_end_time = time.time() @@ -344,29 +337,25 @@ async def _get_oracle_answer(self, question: str) -> str: # Create oracle prompt oracle_prompt = ORACLE_PROMPT_TEMPLATE.format(word=self._hidden_object, question=question) - oracle_request = request.LLMRequest( - messages=[ - request.UserMessage(role="user", content=oracle_prompt), - ], - temperature=0.0, # Deterministic answers + oracle_request = open_responses.make_request( + [open_responses.user_message(oracle_prompt)], + temperature=0.0, ) # Call oracle with self._tracker.timeit(f"{self._prefix}/oracle_call"): oracle_response = await self._oracle_client(oracle_request) - # Extract answer from response - LLMResponse has data: list[TextData] - answer_text = oracle_response.data[0].content.strip() + # Extract answer from response + answer_text = response.extract_text_content(oracle_response.response).strip() _LOGGER.debug("[%d] Raw oracle response: %s", id(self), answer_text) return answer_text - def _extract_question_from_response(self, action: response.LLMResponse) -> str: + def _extract_question_from_response(self, action: response.InferenceResult) -> str: """Extract the question text from the agent's response.""" - # Get the text 
content from the first data element - question = action.data[0].content if action.data else "" - return question.strip() + return response.extract_text_content(action.response).strip() def _check_if_correct_guess(self, question: str) -> bool: """Check if the question is a correct guess of the hidden object.""" diff --git a/src/ares/llms/__init__.py b/src/ares/llms/__init__.py index d75ad66f1..37a1a1698 100644 --- a/src/ares/llms/__init__.py +++ b/src/ares/llms/__init__.py @@ -1,31 +1,21 @@ -"""LLM client interfaces and data types.""" +"""LLM client interfaces and data types. + +Canonical request builders and request types live in :mod:`ares.llms.open_responses`. + +Prefer ``from ares.llms import open_responses`` to access request types and builders +rather than importing individual type aliases from this package. +""" -# Request types -# Client protocol from ares.llms.chat_completions_compatible import ChatCompletionCompatibleLLMClient from ares.llms.llm_clients import LLMClient -from ares.llms.request import AssistantMessage -from ares.llms.request import LLMRequest -from ares.llms.request import Message -from ares.llms.request import ToolCallMessage -from ares.llms.request import ToolCallResponseMessage -from ares.llms.request import UserMessage - -# Response types -from ares.llms.response import LLMResponse -from ares.llms.response import TextData -from ares.llms.response import Usage +from ares.llms.response import InferenceResult +from ares.llms.response import extract_text_content +from ares.llms.response import make_response __all__ = [ - "AssistantMessage", "ChatCompletionCompatibleLLMClient", + "InferenceResult", "LLMClient", - "LLMRequest", - "LLMResponse", - "Message", - "TextData", - "ToolCallMessage", - "ToolCallResponseMessage", - "Usage", - "UserMessage", + "extract_text_content", + "make_response", ] diff --git a/src/ares/llms/anthropic_converter.py b/src/ares/llms/anthropic_converter.py deleted file mode 100644 index c76004ad5..000000000 --- 
a/src/ares/llms/anthropic_converter.py +++ /dev/null @@ -1,383 +0,0 @@ -"""Converter for Anthropic Messages API format. - -This module provides bidirectional conversion between ARES's internal LLMRequest format -and the Anthropic Messages API format. The module itself conforms to the -RequestConverter Protocol through its to_external and from_external functions. - -Conversion Notes: - - temperature converted between OpenAI range (0-2) and Claude range (0-1) - - messages must alternate user/assistant (enforced by Claude API) - - system_prompt mapped to/from system parameter - - service_tier options limited to "auto" and "standard_only" - - top_k is Claude-specific (supported) -""" - -import logging -from typing import Any, cast - -import anthropic.types - -from ares.llms import request as llm_request - -_LOGGER = logging.getLogger(__name__) - - -def _tool_to_anthropic(tool: llm_request.Tool) -> anthropic.types.ToolParam: - """Convert Tool from ARES internal format to Anthropic Messages format. - - Args: - tool: Tool in ARES internal format (flat with input_schema) - - Returns: - Tool in Anthropic Messages format (custom tool with type, name, description, input_schema) - """ - return anthropic.types.ToolParam( - type="custom", - name=tool["name"], - description=tool["description"], - input_schema=cast(dict[str, object], tool["input_schema"]), - ) - - -def _tool_from_anthropic( - anthropic_tool: anthropic.types.ToolUnionParam, -) -> llm_request.Tool: - """Convert tool from Anthropic Messages format to ARES internal format. - - Args: - anthropic_tool: Tool in Anthropic format (ToolParam with type='custom'/None, or built-in tool types) - - Returns: - Tool in ARES internal format - - Raises: - ValueError: If tool type is unsupported or required fields are missing - - Note: - Only supports ToolParam with type='custom' or type=None. Built-in tool types - (bash_20250124, text_editor_*, web_search_*) are not supported. 
- """ - # Check tool type - we only accept "custom" (or None which defaults to custom) - # Reject built-in tool types like bash_20250124, text_editor_*, web_search_* - tool_type = anthropic_tool.get("type") - if tool_type is not None and tool_type != "custom": - raise ValueError( - f"Unsupported tool type: {tool_type}. Only 'custom' tools are supported. " - f"Built-in tools (bash, text_editor, web_search) are not supported." - ) - - # Validate required fields - if "name" not in anthropic_tool: - raise ValueError("Tool missing required 'name' field") - - if "input_schema" not in anthropic_tool: - raise ValueError(f"Tool '{anthropic_tool.get('name')}' missing required 'input_schema' field") - - # Validate input_schema structure - input_schema = anthropic_tool["input_schema"] - if not isinstance(input_schema, dict): - raise ValueError(f"Tool '{anthropic_tool['name']}' input_schema must be a dict, got {type(input_schema)}") - - if "type" not in input_schema: - raise ValueError(f"Tool '{anthropic_tool['name']}' input_schema must have a 'type' field") - - return llm_request.Tool( - name=anthropic_tool["name"], - description=anthropic_tool.get("description", ""), - input_schema=cast(llm_request.JSONSchema, input_schema), - ) - - -def _tool_choice_to_anthropic(tool_choice: llm_request.ToolChoice | None) -> dict[str, Any] | None: - """Convert internal ToolChoice to Anthropic Messages format. 
- - Args: - tool_choice: Internal tool choice - - Returns: - Tool choice in Anthropic format: - - {"type": "auto"}: Model decides - - {"type": "any"}: Must use at least one tool - - {"type": "none"}: Must not use any tools - - {"type": "tool", "name": "..."}: Specific tool - """ - if tool_choice is None: - return None - - if tool_choice == "auto": - return {"type": "auto"} - elif tool_choice == "any": - return {"type": "any"} - elif tool_choice == "none": - return {"type": "none"} - elif isinstance(tool_choice, dict) and tool_choice.get("type") == "tool": - return {"type": "tool", "name": tool_choice["name"]} - - return None - - -def _tool_choice_from_anthropic( - tool_choice: dict[str, Any] | None, -) -> llm_request.ToolChoice | None: - """Convert Anthropic Messages tool_choice to internal format. - - Args: - tool_choice: Anthropic tool choice parameter - - Returns: - Internal ToolChoice format - """ - if tool_choice is None: - return None - - if isinstance(tool_choice, dict): - choice_type = tool_choice.get("type") - if choice_type == "auto": - return "auto" - elif choice_type == "any": - return "any" - elif choice_type == "none": - return "none" - elif choice_type == "tool" and "name" in tool_choice: - return llm_request.ToolChoiceTool(type="tool", name=tool_choice["name"]) - - return None - - -def _messages_to_claude_format(messages: list[llm_request.Message], *, strict: bool = True) -> list[dict[str, Any]]: - """Convert messages from Chat format to Claude alternating format. - - Args: - messages: List of messages in internal format - strict: If True, raise ValueError on non-alternating messages. If False, drop - consecutive messages with the same role (keeping first) with a warning. - - Returns: - List of messages in Claude format (user/assistant alternating) - - Raises: - ValueError: If strict=True and messages don't alternate roles - - Note: - Claude requires strict alternation. 
This method filters out system/developer - messages (should be in system_prompt) and ensures alternation. - """ - claude_messages = [] - last_role = None - dropped_count = 0 - - for msg in messages: - msg_dict = dict(msg) # Convert to regular dict for type safety - role = msg_dict["role"] - # Skip system/developer messages (should be in system_prompt) - if role in ("system", "developer"): - continue - # Map tool/function to user role (tool results) - if role in ("tool", "function"): - role = "user" - - # Check for alternation - if last_role == role: - content = str(msg_dict.get("content", ""))[:50] - if strict: - raise ValueError( - f"Messages must alternate between user and assistant roles for Claude API. " - f"Found consecutive '{role}' messages. Message content: {content}..." - ) - else: - _LOGGER.warning( - "Dropping non-alternating message with role '%s' (content: %s...). " - "Claude requires strict alternation between user and assistant.", - role, - content, - ) - dropped_count += 1 - continue - - # Keep only role and content - Claude API only accepts these fields - claude_messages.append( - { - "role": role, - "content": msg_dict.get("content", ""), - } - ) - last_role = role - - if dropped_count > 0 and not strict: - _LOGGER.warning("Dropped %d non-alternating messages for Claude API compliance", dropped_count) - - return claude_messages - - -def to_external(request: llm_request.LLMRequest, *, strict: bool = True) -> dict[str, Any]: - """Convert ARES LLMRequest to Claude Messages format. - - Args: - request: ARES internal request format - strict: If True, raise ValueError on information loss. If False, log warnings. 
- - Returns: - Dictionary of kwargs for anthropic.messages.create() (without model) - - Raises: - ValueError: If strict=True and information would be lost in conversion - - Note: - Model parameter is NOT included - it should be added by the LLMClient - """ - # Check for information loss - lost_info = [] - if request.service_tier not in (None, "auto", "standard_only"): - lost_info.append(f"service_tier='{request.service_tier}' (Claude only supports 'auto' and 'standard_only')") - - # Check for filtered messages - filtered_messages = [] - for msg in request.messages: - msg_dict = dict(msg) - role = msg_dict["role"] - if role in ("system", "developer"): - content = str(msg_dict.get("content", ""))[:50] - filtered_messages.append(f"{role} message: {content}...") - - if filtered_messages: - lost_info.append(f"Messages filtered out (use system_prompt instead): {'; '.join(filtered_messages)}") - - if lost_info: - msg = f"Converting to Claude Messages will lose information: {'; '.join(lost_info)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - kwargs: dict[str, Any] = { - "messages": _messages_to_claude_format(request.messages, strict=strict), - "max_tokens": request.max_output_tokens or 1024, # max_tokens is required by Claude - } - - if request.system_prompt: - kwargs["system"] = request.system_prompt - - if request.temperature is not None: - # Convert from OpenAI range (0-2) to Claude range (0-1) - kwargs["temperature"] = min(request.temperature / 2.0, 1.0) - if request.top_p is not None: - kwargs["top_p"] = request.top_p - if request.top_k is not None: - kwargs["top_k"] = request.top_k - if request.stream: - kwargs["stream"] = True - if request.tools: - # Convert tools to Anthropic format (adds explicit type: "custom") - kwargs["tools"] = [_tool_to_anthropic(tool) for tool in request.tools] - if request.tool_choice is not None: - kwargs["tool_choice"] = _tool_choice_to_anthropic(request.tool_choice) - if request.metadata: - # Claude uses 
metadata.user_id specifically - kwargs["metadata"] = request.metadata - if request.service_tier in ("auto", "standard_only"): - kwargs["service_tier"] = request.service_tier - if request.stop_sequences: - kwargs["stop_sequences"] = request.stop_sequences - - return kwargs - - -def from_external( - kwargs: anthropic.types.MessageCreateParams, - *, - strict: bool = True, -) -> llm_request.LLMRequest: - """Create LLMRequest from Claude Messages API kwargs. - - Args: - kwargs: Claude Messages API parameters - strict: If True, raise ValueError for unhandled parameters. If False, log warnings. - - Returns: - LLMRequest instance - - Raises: - ValueError: If strict=True and there are unhandled parameters - """ - # Define parameters we handle (model is accepted but not stored) - handled_params = { - "model", # Accepted but not stored - managed by LLMClient - "messages", - "max_tokens", - "temperature", - "top_p", - "top_k", - "stream", - "tools", - "tool_choice", - "metadata", - "service_tier", - "stop_sequences", - "system", - } - - # Check for unhandled parameters - unhandled = set(kwargs.keys()) - handled_params - if unhandled: - msg = f"Unhandled Claude Messages parameters (will be ignored): {sorted(unhandled)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - # Convert temperature from Claude range (0-1) to OpenAI range (0-2) - temperature = kwargs.get("temperature") - if temperature is not None: - temperature = temperature * 2.0 - - # Extract system prompt (can be str or list of text blocks) - system_param = kwargs.get("system") - system_prompt = None - if system_param: - system_prompt = llm_request._extract_string_content(system_param, strict=strict, context="System prompt") - - # Filter and validate messages - filtered_messages: list[llm_request.Message] = [] - for msg in kwargs["messages"]: - role = msg.get("role") - - # Validate role is supported - if role not in llm_request._VALID_ROLES: - if strict: - raise ValueError(f"Unsupported message role: 
{role}. Must be one of {llm_request._VALID_ROLES}") - _LOGGER.warning("Skipping message with unsupported role: %s", role) - continue - - # Convert to our Message format, validating content - message_dict = dict(msg) - if "content" in message_dict: - message_dict["content"] = llm_request._extract_string_content( - message_dict["content"], strict=strict, context=f"Message content (role={role})" - ) - filtered_messages.append(cast(llm_request.Message, message_dict)) - - # Convert tools from Anthropic format to internal format - tools_param = kwargs.get("tools") - converted_tools: list[llm_request.Tool] | None = None - if tools_param: - converted_tools = [] - for tool in tools_param: - try: - converted_tools.append(_tool_from_anthropic(tool)) - except ValueError as e: - if strict: - raise - _LOGGER.warning("Skipping invalid tool: %s", e) - - return llm_request.LLMRequest( - messages=filtered_messages, - max_output_tokens=kwargs["max_tokens"], - temperature=temperature, - top_p=kwargs.get("top_p"), - top_k=kwargs.get("top_k"), - stream=bool(kwargs.get("stream", False)), - tools=converted_tools, - tool_choice=_tool_choice_from_anthropic(cast(dict[str, Any] | None, kwargs.get("tool_choice"))), - metadata=cast(dict[str, Any] | None, kwargs.get("metadata")), - service_tier=kwargs.get("service_tier"), - stop_sequences=cast(list[str] | None, kwargs.get("stop_sequences")), - system_prompt=system_prompt, - ) diff --git a/src/ares/llms/anthropic_converter_test.py b/src/ares/llms/anthropic_converter_test.py deleted file mode 100644 index fa575e8bf..000000000 --- a/src/ares/llms/anthropic_converter_test.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Unit tests for Anthropic Messages API converter.""" - -import anthropic.types -import anthropic.types.message_create_params -import pytest - -from ares.llms import anthropic_converter -from ares.llms import request as request_lib - - -class TestStructuredContentHandling: - """Tests for handling structured content (list of blocks) in Messages 
API conversions.""" - - def test_from_messages_with_structured_content_strict(self): - """Test that structured content in Claude messages raises error in strict mode.""" - kwargs = anthropic.types.message_create_params.MessageCreateParamsNonStreaming( - model="claude-3-opus", - max_tokens=100, - messages=[ - anthropic.types.MessageParam( - role="user", - content=[ - anthropic.types.TextBlockParam(type="text", text="Analyze this image"), - anthropic.types.ImageBlockParam( - type="image", - source=anthropic.types.base64_image_source_param.Base64ImageSourceParam( - type="base64", media_type="image/png", data="..." - ), - ), - ], - ) - ], - ) - - with pytest.raises(ValueError, match=r"list of blocks.*structured content"): - anthropic_converter.from_external(kwargs, strict=True) - - def test_system_prompt_with_structured_content_strict(self): - """Test that structured system prompt raises error in strict mode.""" - kwargs = anthropic.types.message_create_params.MessageCreateParamsNonStreaming( - model="claude-3-opus", - max_tokens=100, - system=[ - anthropic.types.TextBlockParam(type="text", text="You are a helpful assistant."), - anthropic.types.TextBlockParam(type="text", text="Always be concise."), - ], - messages=[anthropic.types.MessageParam(role="user", content="Hello")], - ) - - with pytest.raises(ValueError, match=r"list of blocks.*structured content"): - anthropic_converter.from_external(kwargs, strict=True) - - -class TestLLMRequestMessagesConversion: - """Tests for Claude Messages API conversion.""" - - def test_to_messages_minimal(self): - """Test minimal conversion to Messages format.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - ) - kwargs = anthropic_converter.to_external(request) - - assert kwargs["messages"] == [{"role": "user", "content": "Hello"}] - assert kwargs["max_tokens"] == 1024 # Default required by Claude - - def test_to_messages_all_params(self): - """Test conversion with all common parameters.""" 
- request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - max_output_tokens=100, - temperature=1.4, # Will be converted to 0.7 - top_p=0.9, - top_k=40, - stream=True, - tools=[ - { - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {}}, - } - ], - tool_choice="auto", # Internal format - metadata={"user_id": "123"}, - service_tier="auto", - stop_sequences=["STOP", "END"], - ) - kwargs = anthropic_converter.to_external(request) - - assert kwargs["max_tokens"] == 100 - assert kwargs["temperature"] == 0.7 # Converted from 1.4 - assert kwargs["top_p"] == 0.9 - assert kwargs["top_k"] == 40 - assert kwargs["stream"] is True - # Tools are converted to Anthropic format (explicit type: "custom") - assert kwargs["tools"] == [ - { - "type": "custom", - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {}}, - } - ] - assert kwargs["tool_choice"] == {"type": "auto"} - assert kwargs["metadata"] == {"user_id": "123"} - assert kwargs["service_tier"] == "auto" - assert kwargs["stop_sequences"] == ["STOP", "END"] - - def test_to_messages_temperature_conversion(self): - """Test temperature conversion from OpenAI (0-2) to Claude (0-1) range.""" - test_cases = [ - (0.0, 0.0), - (1.0, 0.5), - (2.0, 1.0), - (0.5, 0.25), - (1.5, 0.75), - ] - for openai_temp, claude_temp in test_cases: - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - temperature=openai_temp, - ) - kwargs = anthropic_converter.to_external(request) - assert kwargs["temperature"] == claude_temp - - def test_to_messages_with_system_prompt(self): - """Test system prompt is mapped to system parameter.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - system_prompt="You are a helpful assistant.", - ) - kwargs = anthropic_converter.to_external(request) - - assert kwargs["system"] == "You are a helpful assistant." 
- - def test_to_messages_excludes_invalid_service_tier(self): - """Test that non-Claude service tiers are excluded.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - service_tier="flex", # Not supported by Claude - ) - kwargs = anthropic_converter.to_external(request, strict=False) - - assert "service_tier" not in kwargs - - def test_to_messages_filters_system_messages(self): - """Test that system/developer messages are filtered out.""" - # Note: system messages are not valid in our Message type but we want to test filtering - # so we pass them as-is (they'll be filtered by to_messages_kwargs) - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="Hello"), - request_lib.AssistantMessage(role="assistant", content="Hi"), - ], - system_prompt="System message", - ) - kwargs = anthropic_converter.to_external(request, strict=False) - - assert len(kwargs["messages"]) == 2 - assert kwargs["messages"][0]["role"] == "user" - assert kwargs["messages"][1]["role"] == "assistant" - - def test_to_messages_maps_tool_to_user(self): - """Test that tool/function roles are mapped to user.""" - messages: list[request_lib.Message] = [ - request_lib.UserMessage(role="user", content="What's the weather?"), - request_lib.AssistantMessage(role="assistant", content="Let me check..."), - request_lib.ToolCallResponseMessage(role="tool", content="Sunny, 72°F", tool_call_id="test"), - ] - request = request_lib.LLMRequest( - messages=messages, - ) - kwargs = anthropic_converter.to_external(request) - - assert kwargs["messages"][2]["role"] == "user" - - def test_from_messages_minimal(self): - """Test parsing minimal Messages request.""" - kwargs: anthropic.types.MessageCreateParams = { - "model": "claude-sonnet-4-5-20250929", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 100, - } - request = anthropic_converter.from_external(kwargs) - - assert list(request.messages) == [{"role": "user", 
"content": "Hello"}] - assert request.max_output_tokens == 100 - - def test_from_messages_all_params(self): - """Test parsing Messages with all parameters.""" - tool: anthropic.types.ToolParam = { - "type": "custom", - "name": "test", - "input_schema": {"type": "object", "properties": {}}, - } - tool_choice: anthropic.types.ToolChoiceAutoParam = {"type": "auto"} - msg: anthropic.types.MessageParam = {"role": "user", "content": "Hello"} - metadata: anthropic.types.MetadataParam = {"user_id": "123"} - - kwargs: anthropic.types.MessageCreateParams = { - "model": "claude-sonnet-4-5-20250929", - "messages": [msg], - "max_tokens": 100, - "temperature": 0.7, # Will be converted to 1.4 - "top_p": 0.9, - "top_k": 40, - "stream": True, - "tools": [tool], - "tool_choice": tool_choice, - "metadata": metadata, - "service_tier": "auto", - "stop_sequences": ["STOP"], - "system": "You are helpful.", - } - request = anthropic_converter.from_external(kwargs) - - assert request.max_output_tokens == 100 - assert request.temperature == 1.4 # Converted from 0.7 - assert request.top_p == 0.9 - assert request.top_k == 40 - assert request.stream is True - # Tools are converted to internal format (name, description, input_schema) - assert request.tools == [ - {"name": "test", "description": "", "input_schema": {"type": "object", "properties": {}}} - ] - assert request.tool_choice == "auto" # Internal format - assert request.metadata == {"user_id": "123"} - assert request.service_tier == "auto" - assert request.stop_sequences == ["STOP"] - assert request.system_prompt == "You are helpful." 
- - def test_from_messages_temperature_conversion(self): - """Test temperature conversion from Claude (0-1) to OpenAI (0-2) range.""" - test_cases = [ - (0.0, 0.0), - (0.5, 1.0), - (1.0, 2.0), - (0.25, 0.5), - (0.75, 1.5), - ] - for claude_temp, openai_temp in test_cases: - kwargs: anthropic.types.MessageCreateParams = { - "model": "claude-sonnet-4-5-20250929", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 100, - "temperature": claude_temp, - } - request = anthropic_converter.from_external(kwargs) - assert request.temperature == openai_temp - - def test_to_messages_strict_alternation_error(self): - """Test that non-alternating messages raise error in strict mode.""" - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="Hello"), - request_lib.UserMessage(role="user", content="How are you?"), - ], - ) - with pytest.raises(ValueError, match="Messages must alternate"): - anthropic_converter.to_external(request, strict=True) - - def test_to_messages_non_strict_drops_duplicates(self): - """Test that non-alternating messages are dropped in non-strict mode.""" - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="Hello"), - request_lib.UserMessage(role="user", content="How are you?"), - request_lib.AssistantMessage(role="assistant", content="I'm fine"), - request_lib.AssistantMessage(role="assistant", content="Thanks"), - request_lib.UserMessage(role="user", content="Great"), - ], - ) - kwargs = anthropic_converter.to_external(request, strict=False) - - # Should keep first of each consecutive group - assert len(kwargs["messages"]) == 3 - assert kwargs["messages"][0] == {"role": "user", "content": "Hello"} - assert kwargs["messages"][1] == {"role": "assistant", "content": "I'm fine"} - assert kwargs["messages"][2] == {"role": "user", "content": "Great"} diff --git a/src/ares/llms/chat_completions_compatible.py b/src/ares/llms/chat_completions_compatible.py index 
f98ec4d57..9728ae96d 100644 --- a/src/ares/llms/chat_completions_compatible.py +++ b/src/ares/llms/chat_completions_compatible.py @@ -5,6 +5,7 @@ import threading import httpx +from linguafranca import types as lft import openai import openai.types.chat.chat_completion import tenacity @@ -12,8 +13,7 @@ from ares import config from ares.llms import accounting from ares.llms import llm_clients -from ares.llms import openai_chat_converter -from ares.llms import request +from ares.llms import open_responses from ares.llms import response _LOGGER = logging.getLogger(__name__) @@ -47,9 +47,9 @@ def _get_llm_client(base_url: str, api_key: str) -> openai.AsyncClient: before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), ) async def _query_llm_with_retry( - llm_client: openai.AsyncClient, model: str, req: request.LLMRequest + llm_client: openai.AsyncClient, req: lft.OpenResponsesRequest ) -> openai.types.chat.chat_completion.ChatCompletion: - response = await llm_client.chat.completions.create(model=model, **openai_chat_converter.to_external(req)) + response = await llm_client.chat.completions.create(**open_responses.to_chat_completions_kwargs(req)) return response @@ -59,22 +59,27 @@ class ChatCompletionCompatibleLLMClient(llm_clients.LLMClient): base_url: str = config.CONFIG.chat_completion_api_base_url api_key: str = config.CONFIG.chat_completion_api_key - async def __call__(self, request: request.LLMRequest) -> response.LLMResponse: + async def __call__(self, request: lft.OpenResponsesRequest) -> response.InferenceResult: _LOGGER.debug("[%d] Requesting LLM.", id(self)) + request = open_responses.with_model(request, self.model) + # GPT-5 models don't support temperature. 
         if self.model.startswith("openai/gpt-5"):
             request = dataclasses.replace(request, temperature=None)
 
-        resp = await _query_llm_with_retry(_get_llm_client(self.base_url, self.api_key), self.model, request)
+        resp = await _query_llm_with_retry(_get_llm_client(self.base_url, self.api_key), request)
         _LOGGER.debug("[%d] LLM response received.", id(self))
 
         cost = accounting.get_llm_cost(self.model, resp, cost_mapping=accounting.martian_cost_list())
         cost = float(cost)
 
         content = resp.choices[0].message.content or ""
-        usage = response.Usage(
-            prompt_tokens=resp.usage.prompt_tokens if resp.usage else 0,
-            generated_tokens=resp.usage.completion_tokens if resp.usage else 0,
+        lf_response = response.make_response(
+            content,
+            model=self.model,
+            input_tokens=resp.usage.prompt_tokens if resp.usage else 0,
+            output_tokens=resp.usage.completion_tokens if resp.usage else 0,
+            response_id=resp.id,
         )
-        return response.LLMResponse(data=[response.TextData(content=content)], cost=cost, usage=usage)
+        return response.InferenceResult(response=lf_response, cost=cost)
diff --git a/src/ares/llms/llm_clients.py b/src/ares/llms/llm_clients.py
index 7a9ebc73d..e4a4d1e3e 100644
--- a/src/ares/llms/llm_clients.py
+++ b/src/ares/llms/llm_clients.py
@@ -2,7 +2,8 @@
 
 from typing import Protocol
 
-from ares.llms import request
+from linguafranca import types as lft
+
 from ares.llms import response
 
 
@@ -18,4 +19,4 @@ def __init__(self, message: str, truncated_response: str | None = None):
 
 
 class LLMClient(Protocol):
-    async def __call__(self, request: request.LLMRequest) -> response.LLMResponse: ...
+    async def __call__(self, request: lft.OpenResponsesRequest) -> response.InferenceResult: ...
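The `LLMClient` protocol in the hunk above is a single async `__call__` that takes a request and returns a result carrying a cost. A minimal, self-contained sketch of that shape follows. `StubRequest`, `StubResult`, `StubLLMClient`, `EchoClient`, and `run_once` are hypothetical stand-ins invented for illustration; they are not the linguafranca or ARES types, whose fields are not shown in this diff.

```python
import asyncio
import dataclasses
from typing import Protocol


@dataclasses.dataclass(frozen=True)
class StubRequest:
    """Hypothetical stand-in for lft.OpenResponsesRequest (not the real type)."""

    text: str


@dataclasses.dataclass(frozen=True)
class StubResult:
    """Hypothetical stand-in for response.InferenceResult (not the real type)."""

    text: str
    cost: float


class StubLLMClient(Protocol):
    """Mirrors the shape of the LLMClient protocol: one async __call__."""

    async def __call__(self, request: StubRequest) -> StubResult: ...


class EchoClient:
    """Satisfies StubLLMClient structurally; it does not inherit from it."""

    async def __call__(self, request: StubRequest) -> StubResult:
        # A real client would query a model here and compute an actual cost.
        return StubResult(text=f"echo: {request.text}", cost=0.0)


async def run_once(client: StubLLMClient, prompt: str) -> StubResult:
    # The caller depends only on the protocol, so any conforming client fits.
    return await client(StubRequest(text=prompt))


result = asyncio.run(run_once(EchoClient(), "Hello"))
```

Because the protocol is structural, `EchoClient` needs no base class: any object with a matching async `__call__` type-checks as a client. That is presumably what lets test doubles and differently implemented clients be swapped in without changing callers.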
diff --git a/src/ares/llms/open_responses.py b/src/ares/llms/open_responses.py new file mode 100644 index 000000000..92aa00a0b --- /dev/null +++ b/src/ares/llms/open_responses.py @@ -0,0 +1,384 @@ +"""Helpers for ARES's canonical Open Responses request type.""" + +from collections.abc import Sequence +import dataclasses +import enum +import logging +from typing import Any, cast + +import linguafranca as lf +from linguafranca import types as lft + +_LOGGER = logging.getLogger(__name__) + +# Placeholder model identifier used when the actual model isn't known at request creation time. +# In ARES's RL loop, code agents create requests without knowing which model will handle them - +# the model is determined later by the policy/LLM client. This stub is replaced via with_model() +# or stripped with a warning if it leaks into API calls (see to_chat_completions_kwargs). +MODEL_STUB = "__ARES_MODEL_UNSET__" +_JSON_VALUE = int | float | str | bool | None +_JSONABLE = dict[str, "_JSONABLE"] | list["_JSONABLE"] | _JSON_VALUE + + +def user_message(content: str) -> lft.InputItemMessage: + """Create a user message input item. + + Args: + content: The text content of the message. + + Returns: + An InputItemMessage with role "user". + """ + return lft.InputItemMessage( + content=content, + role=lft.MessageRole.user, # type: ignore[arg-type] + type="message", + ) + + +def assistant_message(content: str) -> lft.InputItemMessage: + """Create an assistant message input item. + + Args: + content: The text content of the message. + + Returns: + An InputItemMessage with role "assistant". + """ + return lft.InputItemMessage( + content=content, + role=lft.MessageRole.assistant, # type: ignore[arg-type] + type="message", + ) + + +def function_call(*, call_id: str, name: str, arguments: str) -> lft.InputItemFunctionCall: + """Create a function call input item representing an assistant's tool invocation. + + Args: + call_id: Unique identifier for this tool call. 
+ name: Name of the function being called. + arguments: JSON-encoded arguments to the function. + + Returns: + An InputItemFunctionCall representing the tool invocation. + """ + return lft.InputItemFunctionCall(arguments=arguments, call_id=call_id, name=name, type="function_call") + + +def function_call_output(*, call_id: str, output: str) -> lft.InputItemFunctionCallOutput: + """Create a function call output item representing the result of a tool invocation. + + Args: + call_id: Identifier matching the original function_call. + output: The string result of the function execution. + + Returns: + An InputItemFunctionCallOutput containing the tool result. + """ + return lft.InputItemFunctionCallOutput(call_id=call_id, output=output, type="function_call_output") + + +def function_tool( + *, name: str, description: str | None = None, parameters: Any | None = None, strict: bool | None = None +) -> lft.Tool: + """Create a function tool definition for use in requests. + + Args: + name: The name of the function. + description: Human-readable description of what the function does. + parameters: JSON Schema object describing the function's parameters. + strict: Whether to enforce strict schema validation. + + Returns: + A ToolFunction that can be passed to make_request. + """ + return lft.ToolFunction( + name=name, + type="function", + description=description, + parameters=parameters, + strict=strict, + ) + + +def specific_tool_choice(name: str) -> lft.SpecificToolChoiceFunction: + """Create a tool choice that forces the model to call a specific function. + + Args: + name: Name of the function the model must call. + + Returns: + A SpecificToolChoiceFunction for use as tool_choice in make_request. 
+ """ + return lft.SpecificToolChoiceFunction(type="function", name=name) + + +def make_request( + items: str | Sequence[lft.InputItem], + *, + model: str = MODEL_STUB, + instructions: str | None = None, + max_output_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + stream: bool | None = None, + tools: Sequence[lft.Tool] | None = None, + tool_choice: lft.ToolChoice | None = None, + metadata: dict[str, Any] | None = None, + parallel_tool_calls: bool | None = None, + service_tier: lft.ServiceTier | None = None, +) -> lft.OpenResponsesRequest: + """Create an OpenResponsesRequest from input items and optional parameters. + + Args: + items: Either a string (treated as a user message) or a sequence of input items. + model: Model identifier. Defaults to MODEL_STUB which must be replaced before API calls. + instructions: System instructions for the model. + max_output_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0 to 2.0). + top_p: Nucleus sampling parameter. + stream: Whether to stream the response. + tools: List of tool definitions the model may call. + tool_choice: How the model should choose tools ("auto", "none", or specific). + metadata: Arbitrary metadata to attach to the request. + parallel_tool_calls: Whether the model can make multiple tool calls in parallel. + service_tier: Service tier for the request. + + Returns: + A fully-formed OpenResponsesRequest ready for use with an LLMClient. 
+ """ + input_value = items if isinstance(items, str) else list(items) + return lft.OpenResponsesRequest( + input=input_value, # type: ignore[arg-type] + model=model, + instructions=instructions, + max_output_tokens=max_output_tokens, + temperature=temperature, + top_p=top_p, + stream=stream, + tools=list(tools) if tools is not None else None, # type: ignore[arg-type] + tool_choice=tool_choice, # type: ignore[arg-type] + metadata=metadata, + parallel_tool_calls=parallel_tool_calls, + service_tier=service_tier, # type: ignore[arg-type] + ) + + +def with_model(request: lft.OpenResponsesRequest, model: str) -> lft.OpenResponsesRequest: + """Return a copy of the request with the model field replaced. + + Args: + request: The original request. + model: The new model identifier. + + Returns: + A new request with the updated model field. + """ + return dataclasses.replace(request, model=model) + + +def input_items(request: lft.OpenResponsesRequest) -> list[lft.InputItem]: + """Extract input items from a request, normalizing string inputs to user messages. + + Args: + request: The request to extract items from. + + Returns: + List of input items. If the request input was a string, returns a single-element + list containing a user message with that content. + """ + if isinstance(request.input, str): + return [user_message(request.input)] + return list(request.input) + + +def message_items(request: lft.OpenResponsesRequest) -> list[lft.InputItemMessage]: + """Extract only message items from a request, filtering out function calls and other types. + + Args: + request: The request to extract messages from. + + Returns: + List of InputItemMessage objects (user and assistant messages only). 
+ """ + return [item for item in input_items(request) if isinstance(item, lft.InputItemMessage)] + + +def extract_text_content( + content: str | Sequence[lft.InputContentPart], *, strict: bool = True, context: str = "content" +) -> str: + if isinstance(content, str): + return content + + text_parts: list[str] = [] + unsupported_types: list[str] = [] + + for part in content: + if isinstance(part, lft.ContentPartInputText): + text_parts.append(part.text) + continue + unsupported_types.append(getattr(part, "type", type(part).__name__)) + + if unsupported_types: + msg = f"{context} contains unsupported parts: {', '.join(unsupported_types)}" + if strict: + raise ValueError(msg) + _LOGGER.warning(msg) + + return "".join(text_parts) + + +def message_text(message: lft.InputItemMessage, *, strict: bool = True) -> str: + """Extract text content from a message. + + Args: + message: The message to extract text from. + strict: If True, raise ValueError for unsupported content types. If False, log a + warning and skip them. + + Returns: + The concatenated text content of the message. + + Raises: + ValueError: If strict=True and the message contains unsupported content parts. + """ + return extract_text_content(message.content, strict=strict, context=f"{message.role} message content") + + +def handle_conversion_warnings( + warnings: Sequence[lf.ConversionWarning], + *, + strict: bool, + context: str, +) -> None: + if not warnings: + return + + formatted = "; ".join(f"{warning.field}: {warning.message}" for warning in warnings) + if strict: + raise ValueError(f"Lossy conversion during {context}: {formatted}") + + for warning in warnings: + _LOGGER.warning("%s warning for %s: %s", context, warning.field, warning.message) + + +def to_jsonable(value: Any) -> _JSONABLE: + # TODO: Replace this with frfr. + # The issue is that OpenResponsesRequest has enum values, which frfr doesn't handle correctly yet. + # It won't, in general, either. 
So we should make sure OpenResponsesRequest enums are StrEnums where appropriate. + if dataclasses.is_dataclass(value): + return {field.name: to_jsonable(getattr(value, field.name)) for field in dataclasses.fields(value)} + if isinstance(value, enum.Enum): + return value.value + if isinstance(value, list): + return [to_jsonable(item) for item in value] + if isinstance(value, tuple): + return [to_jsonable(item) for item in value] + if isinstance(value, dict): + return {key: to_jsonable(item) for key, item in value.items()} + return value + + +def request_to_jsonable(request: lft.OpenResponsesRequest) -> dict[str, Any]: + """Convert an OpenResponsesRequest to a JSON-serializable dictionary. + + Args: + request: The request to convert. + + Returns: + A dictionary suitable for JSON serialization. + """ + return cast(dict[str, Any], to_jsonable(request)) + + +def _flatten_chat_tool_call_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + flattened: list[dict[str, Any]] = [] + for message in messages: + tool_calls = message.get("tool_calls") + if tool_calls and flattened and flattened[-1].get("role") == "assistant": + flattened[-1]["tool_calls"] = tool_calls + continue + flattened.append(message) + return flattened + + +def _strip_chat_tool_strict_flags(payload: dict[str, Any]) -> None: + tools = payload.get("tools") + if not isinstance(tools, list): + return + + for tool in tools: + if not isinstance(tool, dict): + continue + function = tool.get("function") + if isinstance(function, dict): + function.pop("strict", None) + + +def normalize_chat_completions_payload(payload: dict[str, Any]) -> dict[str, Any]: + messages = payload.get("messages") + if isinstance(messages, list): + payload["messages"] = _flatten_chat_tool_call_messages(cast(list[dict[str, Any]], messages)) + + _strip_chat_tool_strict_flags(payload) + + if payload.get("stream") is False: + payload.pop("stream", None) + + return payload + + +def to_chat_completions_kwargs( + request:
lft.OpenResponsesRequest, *, model: str | None = None, strict: bool = True +) -> dict[str, Any]: + """Convert an OpenResponsesRequest to OpenAI Chat Completions API kwargs. + + Args: + request: The Open Responses request to convert. + model: Optional model override. If provided, replaces the request's model field. + strict: If True, raise ValueError on lossy conversions. If False, log warnings. + + Returns: + Dictionary of kwargs suitable for passing to openai.chat.completions.create(). + + Raises: + ValueError: If strict=True and the conversion would lose information. + """ + request_with_model = with_model(request, model) if model is not None else request + result = lf.convert_request_json( + request_to_jsonable(request_with_model), + source_format=lf.FormatName.OPEN_RESPONSES, + target_format=lf.FormatName.OPENAI_CHAT_COMPLETIONS, + ) + handle_conversion_warnings(result.warnings, strict=strict, context="Open Responses -> Chat Completions") + payload = cast(dict[str, Any], result.value) + # Guard against the MODEL_STUB leaking into API calls. + if payload.get("model") == MODEL_STUB: + _LOGGER.warning("MODEL_STUB is still set on request; stripping from Chat Completions payload.") + payload.pop("model", None) + return normalize_chat_completions_payload(payload) + + +def to_chat_messages( + request: lft.OpenResponsesRequest, *, model: str | None = None, strict: bool = True +) -> list[dict[str, Any]]: + """Convert an OpenResponsesRequest to Chat Completions message format. + + This is a convenience wrapper around to_chat_completions_kwargs that returns + only the messages array. + + Args: + request: The Open Responses request to convert. + model: Optional model override. + strict: If True, raise ValueError on lossy conversions. + + Returns: + List of message dictionaries in Chat Completions format. + + Raises: + ValueError: If strict=True and the conversion would lose information. 
+ """ + kwargs = to_chat_completions_kwargs(request, model=model, strict=strict) + return cast(list[dict[str, Any]], kwargs["messages"]) + diff --git a/src/ares/llms/open_responses_test.py b/src/ares/llms/open_responses_test.py new file mode 100644 index 000000000..9f0785884 --- /dev/null +++ b/src/ares/llms/open_responses_test.py @@ -0,0 +1,100 @@ +"""Tests for the canonical Open Responses helpers.""" + +from ares.llms import open_responses + + +def test_make_request_defaults_to_model_stub(): + request = open_responses.make_request([open_responses.user_message("Hello")]) + + assert request.model == open_responses.MODEL_STUB + assert len(open_responses.message_items(request)) == 1 + assert open_responses.message_text(open_responses.message_items(request)[0]) == "Hello" + + +def test_input_items_wraps_string_input_as_user_message(): + request = open_responses.make_request("Hello") + + messages = open_responses.message_items(request) + assert len(messages) == 1 + assert messages[0].role.value == "user" + assert open_responses.message_text(messages[0]) == "Hello" + + +def test_to_chat_completions_kwargs_maps_instructions_and_messages(): + request = open_responses.make_request( + [open_responses.user_message("Hi"), open_responses.assistant_message("Hello")], + model="test-model", + instructions="Be concise.", + temperature=0.25, + ) + + kwargs = open_responses.to_chat_completions_kwargs(request) + + assert kwargs["model"] == "test-model" + assert kwargs["temperature"] == 0.25 + assert kwargs["messages"][0] == {"role": "system", "content": "Be concise."} + assert kwargs["messages"][1] == {"role": "user", "content": "Hi"} + assert kwargs["messages"][2] == {"role": "assistant", "content": "Hello"} + + +def test_to_chat_completions_kwargs_flattens_tool_calls_and_strips_tool_strict(): + request = open_responses.make_request( + [ + open_responses.user_message("What is the weather?"), + open_responses.assistant_message("Let me check."), + 
open_responses.function_call(call_id="call_123", name="get_weather", arguments='{"location":"SF"}'), + ], + model="test-model", + tools=[ + open_responses.function_tool( + name="get_weather", + description="Look up weather.", + parameters={"type": "object", "properties": {}}, + strict=True, + ) + ], + ) + + kwargs = open_responses.to_chat_completions_kwargs(request) + + assert kwargs["messages"] == [ + {"role": "user", "content": "What is the weather?"}, + { + "role": "assistant", + "content": "Let me check.", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location":"SF"}'}, + } + ], + }, + ] + assert kwargs["tools"][0]["function"]["name"] == "get_weather" + assert "strict" not in kwargs["tools"][0]["function"] + + +def test_ensure_request_accepts_canonical_request(): + request = open_responses.make_request([open_responses.user_message("Hello")]) + result = open_responses.ensure_request(request) + assert result is request + + +def test_to_chat_completions_kwargs_strips_model_stub(caplog): + request = open_responses.make_request([open_responses.user_message("Hello")]) + assert request.model == open_responses.MODEL_STUB + + with caplog.at_level("WARNING"): + kwargs = open_responses.to_chat_completions_kwargs(request) + + assert "model" not in kwargs + assert "MODEL_STUB" in caplog.text + + +def test_to_chat_completions_kwargs_preserves_real_model(): + request = open_responses.make_request([open_responses.user_message("Hello")], model="gpt-4o") + + kwargs = open_responses.to_chat_completions_kwargs(request) + + assert kwargs["model"] == "gpt-4o" diff --git a/src/ares/llms/openai_chat_converter.py b/src/ares/llms/openai_chat_converter.py deleted file mode 100644 index 8b6b7885a..000000000 --- a/src/ares/llms/openai_chat_converter.py +++ /dev/null @@ -1,395 +0,0 @@ -"""Converter for OpenAI Chat Completions API format. 
- -This module provides bidirectional conversion between ARES's internal LLMRequest format -and the OpenAI Chat Completions API format. The module itself conforms to the -RequestConverter Protocol through its to_external and from_external functions. - -Conversion Notes: - - top_k is not supported (Claude-specific) - - service_tier="standard_only" is not supported - - stop_sequences truncated to 4 (OpenAI limit) - - system_prompt is converted to/from system message in messages list - - ToolCallMessage flattened into AssistantMessage.tool_calls -""" - -import logging -from typing import Any, cast - -import openai.types.chat -import openai.types.chat.completion_create_params - -from ares.llms import request as llm_request - -_LOGGER = logging.getLogger(__name__) - - -def _tool_to_chat_completions(tool: llm_request.Tool) -> openai.types.chat.ChatCompletionToolParam: - """Convert Tool from ARES internal format to OpenAI Chat Completions format. - - Args: - tool: Tool in ARES internal format (flat with input_schema) - - Returns: - Tool in OpenAI Chat Completions format (nested with type and function.parameters) - """ - return openai.types.chat.ChatCompletionToolParam( - type="function", - function=openai.types.shared_params.FunctionDefinition( - name=tool["name"], - description=tool["description"], - parameters=cast(dict[str, object], tool["input_schema"]), - ), - ) - - -def _tool_from_chat_completions(chat_completions_tool: openai.types.chat.ChatCompletionToolParam) -> llm_request.Tool: - """Convert tool from OpenAI Chat Completions format to ARES internal format. 
- - Args: - chat_completions_tool: Tool in OpenAI Chat Completions format (nested with type and function.parameters) - - Returns: - Tool in ARES internal format (flat with input_schema) - """ - function = chat_completions_tool["function"] - parameters = function.get("parameters", {"type": "object", "properties": {}}) - - # Validate that parameters is a valid JSONSchema - if not isinstance(parameters, dict): - raise ValueError(f"Tool parameters must be a dict, got {type(parameters)}") - if "type" not in parameters: - raise ValueError("Tool parameters must have a 'type' field") - - return llm_request.Tool( - name=function["name"], - description=function.get("description", ""), - input_schema=cast(llm_request.JSONSchema, parameters), - ) - - -def _tool_choice_to_openai(tool_choice: llm_request.ToolChoice | None) -> str | dict[str, Any] | None: - """Convert ARES internal ToolChoice to OpenAI Chat Completions format. - - Args: - tool_choice: ARES internal tool choice - - Returns: - Tool choice in OpenAI format: - - "auto": Model decides - - "required": Must use at least one tool - - "none": Must not use any tools - - {"type": "function", "function": {"name": "..."}}: Specific function - """ - if tool_choice is None: - return None - - if tool_choice == "auto": - return "auto" - elif tool_choice == "any": - return "required" # Map "any" to OpenAI's "required" - elif tool_choice == "none": - return "none" - elif isinstance(tool_choice, dict) and tool_choice.get("type") == "tool": - return { - "type": "function", - "function": {"name": tool_choice["name"]}, - } - - return None - - -def _tool_choice_from_openai( - tool_choice: str | dict[str, Any] | None, -) -> llm_request.ToolChoice | None: - """Convert OpenAI Chat Completions tool_choice to internal format. 
- - Args: - tool_choice: OpenAI tool choice parameter - - Returns: - Internal ToolChoice format - """ - if tool_choice is None: - return None - - if isinstance(tool_choice, str): - from typing import Literal - - result = {"auto": "auto", "required": "any", "none": "none"}.get(tool_choice) - if not result: - raise ValueError(f"Unsupported tool choice: {tool_choice}") - return cast(Literal["auto", "any", "none"], result) - - elif isinstance(tool_choice, dict): - choice_type = tool_choice.get("type") - if choice_type == "function": - # {"type": "function", "function": {"name": "x"}} -> {"type": "tool", "name": "x"} - function_data = tool_choice.get("function", {}) - if isinstance(function_data, dict) and "name" in function_data: - return llm_request.ToolChoiceTool(type="tool", name=function_data["name"]) - - return None - - -def to_external(request: llm_request.LLMRequest, *, strict: bool = True) -> dict[str, Any]: - """Convert ARES LLMRequest to OpenAI Chat Completions format. - - Args: - request: ARES internal request format - strict: If True, raise ValueError on information loss. If False, log warnings. 
- - Returns: - Dictionary of kwargs for openai.ChatCompletion.create() (without model) - - Raises: - ValueError: If strict=True and information would be lost in conversion - - Note: - Model parameter is NOT included - it should be added by the LLMClient - """ - # Check for information loss - lost_info = [] - if request.top_k is not None: - lost_info.append(f"top_k={request.top_k} (Claude-specific, not supported)") - if request.service_tier == "standard_only": - lost_info.append("service_tier='standard_only' (not supported by Chat API)") - if request.stop_sequences and len(request.stop_sequences) > 4: - lost_info.append( - f"stop_sequences truncated from {len(request.stop_sequences)} to 4 " - f"(Chat API limit: {request.stop_sequences[4:]} will be dropped)" - ) - - if lost_info: - msg = f"Converting to Chat Completions will lose information: {'; '.join(lost_info)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - # Convert messages, flattening ToolCallMessage into AssistantMessage.tool_calls - chat_messages: list[dict[str, Any]] = [] - pending_tool_calls: list[dict[str, Any]] = [] - - for msg in request.messages: - msg_dict = dict(msg) - - # ToolCallMessage → collect for previous assistant message - if "call_id" in msg_dict and "name" in msg_dict and "arguments" in msg_dict: - # This is a ToolCallMessage - pending_tool_calls.append( - { - "id": msg_dict["call_id"], - "type": "function", - "function": { - "name": msg_dict["name"], - "arguments": msg_dict["arguments"], - }, - } - ) - else: - # Flush any pending tool calls to the last assistant message - if pending_tool_calls and chat_messages: - last_msg = chat_messages[-1] - if last_msg.get("role") == "assistant": - last_msg["tool_calls"] = pending_tool_calls - pending_tool_calls = [] - else: - if strict: - role = last_msg.get("role") - raise ValueError(f"ToolCallMessage found but previous message is not assistant (role={role})") - _LOGGER.warning( - "ToolCallMessage found but previous message is not 
assistant, discarding tool calls" - ) - pending_tool_calls = [] - - # Add the current message - chat_messages.append(msg_dict) - - # Flush any remaining tool calls - if pending_tool_calls and chat_messages: - last_msg = chat_messages[-1] - if last_msg.get("role") == "assistant": - last_msg["tool_calls"] = pending_tool_calls - elif strict: - raise ValueError("ToolCallMessage at end but last message is not assistant") - - kwargs: dict[str, Any] = { - "messages": chat_messages, - } - - # Add system prompt as first message if present - if request.system_prompt: - kwargs["messages"] = [ - {"role": "system", "content": request.system_prompt}, - *kwargs["messages"], - ] - - # Add optional parameters (filter None values) - if request.max_output_tokens is not None: - kwargs["max_completion_tokens"] = request.max_output_tokens - if request.temperature is not None: - kwargs["temperature"] = request.temperature - if request.top_p is not None: - kwargs["top_p"] = request.top_p - if request.stream: - kwargs["stream"] = True - if request.tools: - kwargs["tools"] = [_tool_to_chat_completions(tool) for tool in request.tools] - if request.tool_choice is not None: - kwargs["tool_choice"] = _tool_choice_to_openai(request.tool_choice) - if request.metadata: - kwargs["metadata"] = request.metadata - if request.service_tier and request.service_tier != "standard_only": - kwargs["service_tier"] = request.service_tier - if request.stop_sequences: - # OpenAI Chat supports up to 4 stop sequences - kwargs["stop"] = request.stop_sequences[:4] - - return kwargs - - -def from_external( - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams, - *, - strict: bool = True, -) -> llm_request.LLMRequest: - """Create LLMRequest from OpenAI Chat Completions API kwargs. - - Args: - kwargs: OpenAI Chat Completions API parameters - strict: If True, raise ValueError for unhandled parameters. If False, log warnings. 
- - Returns: - LLMRequest instance - - Raises: - ValueError: If strict=True and there are unhandled parameters - - Note: - Model parameter is ignored - it should be managed by the LLMClient - """ - # Define parameters we actually store/handle - handled_params = { - "model", # Handled by being ignored (LLMClient manages this) - "messages", - "max_completion_tokens", - "max_tokens", # Fallback for max_output_tokens - "temperature", - "top_p", - "stream", - "tools", - "tool_choice", - "metadata", - "service_tier", - "stop", - } - - # Check for unhandled parameters - unhandled = set(kwargs.keys()) - handled_params - if unhandled: - msg = f"Unhandled Chat Completions parameters (will be ignored): {sorted(unhandled)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - # Extract system prompt and filter messages - system_prompt = None - filtered_messages: list[llm_request.Message] = [] - - for msg in kwargs["messages"]: - role = msg.get("role") - - # Extract system/developer messages as system_prompt (use first one) - if role in ("system", "developer"): - if system_prompt is None: - content = msg.get("content", "") - system_prompt = llm_request._extract_string_content( - content, strict=strict, context=f"System/developer message content (role={role})" - ) - continue - - # Validate role is supported - if role not in llm_request._VALID_ROLES: - if strict: - raise ValueError(f"Unsupported message role: {role}. 
Must be one of {llm_request._VALID_ROLES}") - _LOGGER.warning("Skipping message with unsupported role: %s", role) - continue - - # Extract tool_calls from assistant messages - if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]: - # Add assistant message without tool_calls (or with content only) - assistant_msg = dict(msg) - # Remove tool_calls from the message we store - tool_calls_list = assistant_msg.pop("tool_calls", []) - - # Validate content if present - if "content" in assistant_msg: - assistant_msg["content"] = llm_request._extract_string_content( - assistant_msg["content"], strict=strict, context="Assistant message content" - ) - - filtered_messages.append(cast(llm_request.Message, assistant_msg)) - - # Create separate ToolCallMessage for each tool call - if isinstance(tool_calls_list, list): - for tool_call in tool_calls_list: - if tool_call.get("type") == "function": - function = tool_call.get("function", {}) - filtered_messages.append( - cast( - llm_request.Message, - { - "call_id": tool_call.get("id", ""), - "name": function.get("name", ""), - "arguments": function.get("arguments", ""), - }, - ) - ) - else: - # Convert to our Message format, validating content - message_dict = dict(msg) - if "content" in message_dict: - message_dict["content"] = llm_request._extract_string_content( - message_dict["content"], strict=strict, context=f"Message content (role={role})" - ) - filtered_messages.append(cast(llm_request.Message, message_dict)) - - # Convert tools from OpenAI to Claude format - tools_param = kwargs.get("tools") - converted_tools: list[llm_request.Tool] | None = None - if tools_param: - converted_tools = [] - for tool in tools_param: - tool_type = tool.get("type") - if tool_type != "function": - if strict: - raise ValueError(f"Unsupported tool type: {tool_type}. 
Only 'function' tools are supported.") - _LOGGER.warning("Skipping tool with unsupported type: %s", tool_type) - continue - converted_tools.append(_tool_from_chat_completions(cast(openai.types.chat.ChatCompletionToolParam, tool))) - - # Handle stop sequences - convert single string to list - stop_param = kwargs.get("stop") - stop_sequences: list[str] | None = None - if isinstance(stop_param, list): - stop_sequences = stop_param - elif isinstance(stop_param, str): - stop_sequences = [stop_param] - - # Handle system prompt - extract string from various formats - final_system_prompt: str | None = None - if system_prompt: - final_system_prompt = llm_request._extract_string_content(system_prompt, strict=strict, context="System prompt") - - return llm_request.LLMRequest( - messages=filtered_messages, - max_output_tokens=kwargs.get("max_completion_tokens") or kwargs.get("max_tokens"), - temperature=kwargs.get("temperature"), - top_p=kwargs.get("top_p"), - stream=bool(kwargs.get("stream", False)), - tools=converted_tools, - tool_choice=_tool_choice_from_openai(cast(str | dict[str, Any] | None, kwargs.get("tool_choice"))), - metadata=cast(dict[str, Any] | None, kwargs.get("metadata")), - service_tier=kwargs.get("service_tier"), - stop_sequences=stop_sequences, - system_prompt=final_system_prompt, - ) diff --git a/src/ares/llms/openai_chat_converter_test.py b/src/ares/llms/openai_chat_converter_test.py deleted file mode 100644 index 1ca1d3705..000000000 --- a/src/ares/llms/openai_chat_converter_test.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Unit tests for OpenAI Chat Completions converter.""" - -import openai.types.chat -import openai.types.chat.chat_completion_content_part_image_param -import openai.types.chat.completion_create_params -import openai.types.shared_params -import pytest - -from ares.llms import openai_chat_converter -from ares.llms import request as request_lib - - -class TestStructuredContentHandling: - """Tests for handling structured content (list of blocks) 
in Chat Completions API conversions.""" - - def test_from_chat_completion_with_structured_content_strict(self): - """Test that structured content in chat messages raises error in strict mode.""" - kwargs = openai.types.chat.completion_create_params.CompletionCreateParamsNonStreaming( - model="gpt-4", - messages=[ - openai.types.chat.ChatCompletionUserMessageParam( - role="user", - content=[ - openai.types.chat.ChatCompletionContentPartTextParam(type="text", text="What's in this image?"), - openai.types.chat.ChatCompletionContentPartImageParam( - type="image_url", - image_url=openai.types.chat.chat_completion_content_part_image_param.ImageURL( - url="https://example.com/image.png" - ), - ), - ], - ) - ], - ) - - with pytest.raises(ValueError, match=r"list of blocks.*structured content"): - openai_chat_converter.from_external(kwargs, strict=True) - - -class TestLLMRequestChatCompletionConversion: - """Tests for Chat Completions API conversion.""" - - def test_to_chat_completion_minimal(self): - """Test minimal conversion to Chat Completions format.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - ) - kwargs = openai_chat_converter.to_external(request) - - assert kwargs["messages"] == [{"role": "user", "content": "Hello"}] - assert "temperature" not in kwargs - assert "stream" not in kwargs - - def test_to_chat_completion_all_params(self): - """Test conversion with all common parameters.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - max_output_tokens=100, - temperature=0.7, - top_p=0.9, - stream=True, - tools=[ - { - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {}}, - } - ], - tool_choice="auto", - metadata={"user_id": "123"}, - service_tier="default", - stop_sequences=["STOP", "END"], - ) - kwargs = openai_chat_converter.to_external(request) - - assert kwargs["max_completion_tokens"] == 100 - assert kwargs["temperature"] 
== 0.7 - assert kwargs["top_p"] == 0.9 - assert kwargs["stream"] is True - # Tools should be converted to OpenAI format - assert kwargs["tools"] == [ - { - "type": "function", - "function": { - "name": "test", - "description": "A test function", - "parameters": {"type": "object", "properties": {}}, - }, - } - ] - assert kwargs["tool_choice"] == "auto" - assert kwargs["metadata"] == {"user_id": "123"} - assert kwargs["service_tier"] == "default" - assert kwargs["stop"] == ["STOP", "END"] - - def test_to_chat_completion_with_system_prompt(self): - """Test system prompt is added as first message.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - system_prompt="You are a helpful assistant.", - ) - kwargs = openai_chat_converter.to_external(request) - - assert len(kwargs["messages"]) == 2 - assert kwargs["messages"][0] == {"role": "system", "content": "You are a helpful assistant."} - assert kwargs["messages"][1] == {"role": "user", "content": "Hello"} - - def test_to_chat_completion_stop_sequences_truncated(self): - """Test that stop sequences are truncated to 4 (OpenAI limit).""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - stop_sequences=["A", "B", "C", "D", "E", "F"], - ) - kwargs = openai_chat_converter.to_external(request, strict=False) - - assert kwargs["stop"] == ["A", "B", "C", "D"] - - def test_to_chat_completion_excludes_top_k(self): - """Test that top_k (Claude-specific) is excluded.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - top_k=40, - ) - kwargs = openai_chat_converter.to_external(request, strict=False) - - assert "top_k" not in kwargs - - def test_to_chat_completion_excludes_standard_only_tier(self): - """Test that standard_only service tier is excluded.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - service_tier="standard_only", - ) - kwargs = 
openai_chat_converter.to_external(request, strict=False) - - assert "service_tier" not in kwargs - - def test_from_chat_completion_minimal(self): - """Test parsing minimal Chat Completions request.""" - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - } - request = openai_chat_converter.from_external(kwargs) - - assert list(request.messages) == [{"role": "user", "content": "Hello"}] - assert request.max_output_tokens is None - assert request.temperature is None - - def test_from_chat_completion_all_params(self): - """Test parsing Chat Completions with all parameters.""" - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - "max_completion_tokens": 100, - "temperature": 0.7, - "top_p": 0.9, - "stream": True, - "tools": [ - openai.types.chat.ChatCompletionToolParam( - type="function", - function=openai.types.shared_params.FunctionDefinition( - name="test", - description="A test function", - parameters={"type": "object", "properties": {"arg": {"type": "string"}}}, - ), - ) - ], - "tool_choice": "auto", - "metadata": {"user_id": "123"}, - "service_tier": "default", - "stop": ["STOP", "END"], - } - request = openai_chat_converter.from_external(kwargs) - - assert request.max_output_tokens == 100 - assert request.temperature == 0.7 - assert request.top_p == 0.9 - assert request.stream is True - # Tools are converted from OpenAI to Claude format internally - assert request.tools == [ - { - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {"arg": {"type": "string"}}}, - } - ] - assert request.tool_choice == "auto" - assert request.metadata == {"user_id": "123"} - assert request.service_tier == "default" - assert request.stop_sequences == ["STOP", "END"] - - def test_from_chat_completion_extracts_system_prompt(self): 
- """Test that system message is extracted to system_prompt.""" - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hello"}, - ], - } - request = openai_chat_converter.from_external(kwargs) - - assert request.system_prompt == "You are helpful." - assert list(request.messages) == [{"role": "user", "content": "Hello"}] - - def test_from_chat_completion_handles_max_tokens_fallback(self): - """Test that deprecated max_tokens is used as fallback.""" - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 100, - } - request = openai_chat_converter.from_external(kwargs) - - assert request.max_output_tokens == 100 - - def test_to_chat_completion_flattens_tool_calls(self): - """Test that ToolCallMessage is flattened into AssistantMessage.tool_calls.""" - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="What's the weather?"), - request_lib.AssistantMessage(role="assistant", content=""), - request_lib.ToolCallMessage(call_id="call_123", name="get_weather", arguments='{"location":"LA"}'), - ], - ) - kwargs = openai_chat_converter.to_external(request) - - # Should have 2 messages (user + assistant with tool_calls) - assert len(kwargs["messages"]) == 2 - assert kwargs["messages"][1]["role"] == "assistant" - assert "tool_calls" in kwargs["messages"][1] - assert len(kwargs["messages"][1]["tool_calls"]) == 1 - - tool_call = kwargs["messages"][1]["tool_calls"][0] - assert tool_call["id"] == "call_123" - assert tool_call["type"] == "function" - assert tool_call["function"]["name"] == "get_weather" - assert tool_call["function"]["arguments"] == '{"location":"LA"}' - - def test_from_chat_completion_extracts_tool_calls(self): - """Test that AssistantMessage.tool_calls are extracted 
as separate ToolCallMessage.""" - kwargs: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [ - {"role": "user", "content": "What's the weather?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_789", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"location":"Seattle"}'}, - } - ], - }, - ], - } - request = openai_chat_converter.from_external(kwargs) - - # Should have 3 messages: user, assistant, tool_call - assert len(request.messages) == 3 - assert request.messages[0].get("role") == "user" - assert request.messages[1].get("role") == "assistant" - - # Third message should be ToolCallMessage - tool_call_msg = request.messages[2] - assert tool_call_msg.get("call_id") == "call_789" - assert tool_call_msg.get("name") == "get_weather" - assert tool_call_msg.get("arguments") == '{"location":"Seattle"}' - - def test_roundtrip_chat_completion(self): - """Test that Chat Completions roundtrip preserves data.""" - original: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - "max_completion_tokens": 100, - "temperature": 0.7, - "top_p": 0.9, - "stream": True, - "tools": [ - openai.types.chat.ChatCompletionToolParam( - type="function", - function=openai.types.shared_params.FunctionDefinition( - name="test", - description="Test tool", - parameters={"type": "object", "properties": {"x": {"type": "number"}}}, - ), - ) - ], - "tool_choice": "auto", - "metadata": {"user_id": "123"}, - } - request = openai_chat_converter.from_external(original) - converted = openai_chat_converter.to_external(request) - - assert converted["messages"] == original["messages"] - assert converted["max_completion_tokens"] == original["max_completion_tokens"] - assert converted["temperature"] == original["temperature"] - assert converted["top_p"] == original["top_p"] - assert converted["stream"] == 
original["stream"] - assert converted["tools"] == original["tools"] - assert converted["tool_choice"] == original["tool_choice"] - assert converted["metadata"] == original["metadata"] diff --git a/src/ares/llms/openai_responses_converter.py b/src/ares/llms/openai_responses_converter.py deleted file mode 100644 index d28cedecf..000000000 --- a/src/ares/llms/openai_responses_converter.py +++ /dev/null @@ -1,435 +0,0 @@ -"""Converter for OpenAI Responses API format. - -This module provides bidirectional conversion between ARES's internal LLMRequest format -and the OpenAI Responses API format. The module itself conforms to the -RequestConverter Protocol through its to_external and from_external functions. - -Conversion Notes: - - stop_sequences are not supported in Responses API - - top_k is not supported (Claude-specific) - - service_tier="standard_only" is not supported - - system_prompt mapped to/from instructions parameter - - messages converted to/from input items -""" - -import logging -from typing import Any, cast - -import openai.types.responses -import openai.types.responses.response_create_params - -from ares.llms import request as llm_request - -_LOGGER = logging.getLogger(__name__) - - -def _tool_to_responses(tool: llm_request.Tool) -> openai.types.responses.FunctionToolParam: - """Convert Tool from ARES internal format to OpenAI Responses format. - - Args: - tool: Tool in ARES internal format (flat with input_schema) - - Returns: - Tool in OpenAI Responses format (flat with type, name, description, parameters) - """ - return openai.types.responses.FunctionToolParam( - type="function", - name=tool["name"], - description=tool["description"], - parameters=cast(dict[str, object], tool["input_schema"]), - strict=True, - ) - - -def _tool_from_responses(responses_tool: openai.types.responses.ToolParam) -> llm_request.Tool: - """Convert tool from OpenAI Responses format to ARES internal format. 
- - Args: - responses_tool: Tool in OpenAI Responses format (flat with type, name, parameters) - - Returns: - Tool in ARES internal format (flat with input_schema) - - Note: - Currently only supports FunctionToolParam. Other tool types are not converted. - """ - # Only handle FunctionToolParam for now - if responses_tool.get("type") == "function": - # Type guard: if type is "function", this is FunctionToolParam - func_tool = cast(openai.types.responses.FunctionToolParam, responses_tool) - parameters = func_tool.get("parameters") or {"type": "object", "properties": {}} - - # Validate that parameters is a valid JSONSchema - if not isinstance(parameters, dict): - raise ValueError(f"Tool parameters must be a dict, got {type(parameters)}") - if "type" not in parameters: - raise ValueError("Tool parameters must have a 'type' field") - - return llm_request.Tool( - name=func_tool["name"], - description=func_tool.get("description") or "", - input_schema=cast(llm_request.JSONSchema, parameters), - ) - # For other tool types, we can't convert them to Claude format - raise ValueError(f"Unsupported tool type for conversion: {responses_tool.get('type')}") - - -def _tool_choice_to_responses(tool_choice: llm_request.ToolChoice | None) -> str | dict[str, Any] | None: - """Convert ARES internal ToolChoice to OpenAI Responses format. 
- - Args: - tool_choice: ARES internal tool choice - - Returns: - Tool choice in OpenAI Responses format: - - "auto": Model decides - - "required": Must use at least one tool - - "none": Must not use any tools - - {"type": "function", "name": "..."}: Specific function (flat structure) - """ - if tool_choice is None: - return None - - if tool_choice == "auto": - return "auto" - elif tool_choice == "any": - return "required" # Map "any" to OpenAI's "required" - elif tool_choice == "none": - return "none" - elif isinstance(tool_choice, dict) and tool_choice.get("type") == "tool": - # Responses API uses flat structure: {"type": "function", "name": "..."} - return { - "type": "function", - "name": tool_choice["name"], - } - - return None - - -def _tool_choice_from_openai( - tool_choice: str | dict[str, Any] | None, -) -> llm_request.ToolChoice | None: - """Convert OpenAI Chat Completions tool_choice to internal format. - - Args: - tool_choice: OpenAI tool choice parameter - - Returns: - Internal ToolChoice format - """ - if tool_choice is None: - return None - - if isinstance(tool_choice, str): - from typing import Literal - - result = {"auto": "auto", "required": "any", "none": "none"}.get(tool_choice) - if not result: - raise ValueError(f"Unsupported tool choice: {tool_choice}") - return cast(Literal["auto", "any", "none"], result) - - elif isinstance(tool_choice, dict): - choice_type = tool_choice.get("type") - if choice_type == "function": - # {"type": "function", "function": {"name": "x"}} -> {"type": "tool", "name": "x"} - # Responses API uses flat format: {"type": "function", "name": "x"} - if "name" in tool_choice: - # Flat format (Responses API) - return llm_request.ToolChoiceTool(type="tool", name=tool_choice["name"]) - else: - # Nested format (Chat API) - function_data = tool_choice.get("function", {}) - if isinstance(function_data, dict) and "name" in function_data: - return llm_request.ToolChoiceTool(type="tool", name=function_data["name"]) - - return None 
- - -def _messages_to_responses_input(messages: list[llm_request.Message]) -> list[dict[str, Any]]: - """Convert messages from internal format to Responses input items. - - Args: - messages: List of messages in internal format - - Returns: - List of input items for Responses API - - Note: - - ToolCallMessage → function_call items - - ToolCallResponseMessage → function_call_output items - - Other messages → message items - """ - input_items = [] - for msg in messages: - msg_dict = dict(msg) # Convert to regular dict for type safety - - # ToolCallMessage (tool invocation) → function_call - if "call_id" in msg_dict and "name" in msg_dict and "arguments" in msg_dict: - item: dict[str, Any] = { - "type": "function_call", - "call_id": msg_dict["call_id"], - "name": msg_dict["name"], - "arguments": msg_dict["arguments"], - } - input_items.append(item) - - # ToolCallResponseMessage (tool result) → function_call_output - elif msg_dict.get("role") == "tool": - item = { - "type": "function_call_output", - "output": msg_dict.get("content", ""), - } - # Include call_id if present (required for routing) - if "tool_call_id" in msg_dict: - item["call_id"] = msg_dict["tool_call_id"] - - input_items.append(item) - - # Regular messages → message items - else: - role = msg_dict.get("role") - item = { - "type": "message", - "role": role, - "content": msg_dict.get("content", ""), - } - - # Include optional name field if present - if "name" in msg_dict: - item["name"] = msg_dict["name"] - - input_items.append(item) - - return input_items - - -def to_external(request: llm_request.LLMRequest, *, strict: bool = True) -> dict[str, Any]: - """Convert ARES LLMRequest to OpenAI Responses format. - - Args: - request: ARES internal request format - strict: If True, raise ValueError on information loss. If False, log warnings. 
- - Returns: - Dictionary of kwargs for openai.Responses.create() (without model) - - Raises: - ValueError: If strict=True and information would be lost in conversion - - Note: - Model parameter is NOT included - it should be added by the LLMClient - """ - # Check for information loss - lost_info = [] - if request.stop_sequences: - lost_info.append(f"stop_sequences={request.stop_sequences} (not supported by Responses API)") - if request.top_k is not None: - lost_info.append(f"top_k={request.top_k} (Claude-specific, not supported)") - if request.service_tier == "standard_only": - lost_info.append("service_tier='standard_only' (not supported by Responses API)") - - if lost_info: - msg = f"Converting to Responses will lose information: {'; '.join(lost_info)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - kwargs: dict[str, Any] = { - "input": _messages_to_responses_input(request.messages), - } - - if request.system_prompt: - kwargs["instructions"] = request.system_prompt - - if request.max_output_tokens is not None: - kwargs["max_output_tokens"] = request.max_output_tokens - if request.temperature is not None: - kwargs["temperature"] = request.temperature - if request.top_p is not None: - kwargs["top_p"] = request.top_p - if request.stream: - kwargs["stream"] = True - if request.tools: - kwargs["tools"] = [_tool_to_responses(tool) for tool in request.tools] - if request.tool_choice is not None: - kwargs["tool_choice"] = _tool_choice_to_responses(request.tool_choice) - if request.metadata: - kwargs["metadata"] = request.metadata - if request.service_tier and request.service_tier != "standard_only": - kwargs["service_tier"] = request.service_tier - - return kwargs - - -def from_external( - kwargs: openai.types.responses.response_create_params.ResponseCreateParamsBase, - *, - strict: bool = True, -) -> llm_request.LLMRequest: - """Create LLMRequest from OpenAI Responses API kwargs. 
- - Args: - kwargs: OpenAI Responses API parameters - strict: If True, raise ValueError for unhandled parameters. If False, log warnings. - - Returns: - LLMRequest instance - - Raises: - ValueError: If strict=True and there are unhandled parameters - - Note: - Model parameter is ignored - it should be managed by the LLMClient - """ - # Define parameters we actually store/handle - handled_params = { - "model", # Handled by being ignored (LLMClient manages this) - "input", - "max_output_tokens", - "temperature", - "top_p", - "stream", - "tools", - "tool_choice", - "metadata", - "service_tier", - "instructions", - } - - # Check for unhandled parameters - unhandled = set(kwargs.keys()) - handled_params - if unhandled: - msg = f"Unhandled Responses parameters (will be ignored): {sorted(unhandled)}" - if strict: - raise ValueError(msg) - _LOGGER.warning(msg) - - # Convert input items to messages - input_param = kwargs.get("input", []) - filtered_messages: list[llm_request.Message] = [] - - if isinstance(input_param, str): - filtered_messages = [llm_request.UserMessage(role="user", content=input_param)] - elif isinstance(input_param, list): - for item in input_param: - item_type = item.get("type") - - # Handle function_call (tool invocations) - if item_type == "function_call": - call_id = item.get("call_id") - name = item.get("name") - arguments = item.get("arguments", "") - - if call_id is None or name is None: - if strict: - raise ValueError( - f"Tool call (function_call) missing required 'call_id' or 'name' field. Item: {item}" - ) - _LOGGER.warning("Tool call (function_call) missing required fields, skipping. 
Item: %s", item) - continue - - # Create ToolCallMessage - filtered_messages.append( - cast( - llm_request.Message, - { - "call_id": call_id, - "name": name, - "arguments": arguments if isinstance(arguments, str) else str(arguments), - }, - ) - ) - - # Handle function_call_output (tool results) - elif item_type == "function_call_output": - call_id = item.get("call_id") - output = item.get("output", "") - output_str = output if isinstance(output, str) else str(output) - - if call_id is None: - if strict: - raise ValueError( - "Tool result (function_call_output) missing required 'call_id' field for routing. " - f"Output: {output_str[:50]}..." - ) - _LOGGER.warning( - "Tool result (function_call_output) missing 'call_id' field (output: %s...). " - "This may cause routing issues.", - output_str[:50], - ) - # Create tool message without call_id - filtered_messages.append(cast(llm_request.Message, {"role": "tool", "content": output_str})) - else: - # Create tool message with call_id - filtered_messages.append( - cast(llm_request.Message, {"role": "tool", "content": output_str, "tool_call_id": call_id}) - ) - - # Handle regular messages - elif item_type == "message": - role = item.get("role") - - # Validate role is supported - if role not in llm_request._VALID_ROLES: - if strict: - raise ValueError(f"Unsupported message role: {role}. 
Must be one of {llm_request._VALID_ROLES}") - _LOGGER.warning("Skipping message with unsupported role: %s", role) - continue - - # Extract content - use helper to detect unsupported block formats - content_param = item.get("content", "") - content_str = llm_request._extract_string_content( - content_param, strict=strict, context=f"Message content (role={role})" - ) - - # Build message dict with required fields - message_dict: dict[str, Any] = {"role": role, "content": content_str} - - # Include optional name field if present - if "name" in item: - message_dict["name"] = item["name"] - - # Cast to Message after validating role and building dict - filtered_messages.append(cast(llm_request.Message, message_dict)) - - # Convert tools from Responses format to Claude format - tools_param = kwargs.get("tools") - converted_tools: list[llm_request.Tool] | None = None - if tools_param: - temp_tools: list[llm_request.Tool] = [] - for tool in tools_param: - try: - temp_tools.append(_tool_from_responses(tool)) - except ValueError as e: - if strict: - raise - _LOGGER.warning("Skipping tool that cannot be converted: %s", e) - # Only set converted_tools if we successfully converted at least one tool - if temp_tools: - converted_tools = temp_tools - - # Convert tool_choice from Responses flat format to Chat nested format - # Responses: {"type": "function", "name": "..."} - # Chat: {"type": "function", "function": {"name": "..."}} - tool_choice_param = kwargs.get("tool_choice") - if ( - isinstance(tool_choice_param, dict) - and tool_choice_param.get("type") == "function" - and "name" in tool_choice_param - ): - tool_choice_param = {"type": "function", "function": {"name": tool_choice_param["name"]}} - - resolved_tool_choice = _tool_choice_from_openai(cast(str | dict[str, Any] | None, tool_choice_param)) - - return llm_request.LLMRequest( - messages=filtered_messages, - max_output_tokens=kwargs.get("max_output_tokens"), - temperature=kwargs.get("temperature"), - 
top_p=kwargs.get("top_p"), - stream=bool(kwargs.get("stream", False)), - tools=converted_tools, - tool_choice=resolved_tool_choice, - metadata=cast(dict[str, Any] | None, kwargs.get("metadata")), - service_tier=kwargs.get("service_tier"), - system_prompt=kwargs.get("instructions"), - ) diff --git a/src/ares/llms/openai_responses_converter_test.py b/src/ares/llms/openai_responses_converter_test.py deleted file mode 100644 index ce1279178..000000000 --- a/src/ares/llms/openai_responses_converter_test.py +++ /dev/null @@ -1,406 +0,0 @@ -"""Unit tests for OpenAI Responses API converter.""" - -from typing import cast - -import openai.types.responses -import openai.types.responses.response_create_params -import pytest - -from ares.llms import openai_responses_converter -from ares.llms import request as request_lib - - -class TestStructuredContentHandling: - """Tests for handling structured content (list of blocks) in Responses API conversions.""" - - def test_from_responses_with_structured_content_strict(self): - """Test that structured content raises error in strict mode.""" - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsBase( - model="gpt-4", - input=[ - openai.types.responses.EasyInputMessageParam( - type="message", - role="user", - content=[ - openai.types.responses.ResponseInputTextParam(type="input_text", text="Hello"), - openai.types.responses.ResponseInputImageParam( - type="input_image", detail="auto", image_url="data:image/png;base64,..." 
- ), - ], - ) - ], - ) - - with pytest.raises(ValueError, match=r"list of blocks.*structured content"): - openai_responses_converter.from_external(kwargs, strict=True) - - def test_from_responses_with_structured_content_non_strict(self): - """Test that structured content returns empty string in non-strict mode.""" - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsBase( - model="gpt-4", - input=[ - openai.types.responses.EasyInputMessageParam( - type="message", - role="user", - content=[ - openai.types.responses.ResponseInputTextParam(type="input_text", text="Hello"), - openai.types.responses.ResponseInputImageParam( - type="input_image", detail="auto", image_url="data:image/png;base64,..." - ), - ], - ) - ], - ) - - # Should not raise, but content will be empty - request = openai_responses_converter.from_external(kwargs, strict=False) - assert len(request.messages) == 1 - assert request.messages[0].get("content") == "" - - -class TestLLMRequestResponsesConversion: - """Tests for Responses API conversion.""" - - def test_to_responses_minimal(self): - """Test minimal conversion to Responses format.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - ) - kwargs = openai_responses_converter.to_external(request) - - assert kwargs["input"] == [{"type": "message", "role": "user", "content": "Hello"}] - assert "temperature" not in kwargs - - def test_to_responses_all_params(self): - """Test conversion with all common parameters.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - max_output_tokens=100, - temperature=0.7, - top_p=0.9, - stream=True, - tools=[ - { - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {}}, - } - ], - tool_choice="auto", - metadata={"user_id": "123"}, - service_tier="default", - ) - kwargs = openai_responses_converter.to_external(request) - - assert kwargs["max_output_tokens"] == 100 - 
assert kwargs["temperature"] == 0.7 - assert kwargs["top_p"] == 0.9 - assert kwargs["stream"] is True - # Tools should be converted to Responses format (flat structure) - assert kwargs["tools"] == [ - { - "type": "function", - "name": "test", - "description": "A test function", - "parameters": {"type": "object", "properties": {}}, - "strict": True, - } - ] - assert kwargs["tool_choice"] == "auto" - assert kwargs["metadata"] == {"user_id": "123"} - assert kwargs["service_tier"] == "default" - - def test_to_responses_with_instructions(self): - """Test system prompt is mapped to instructions.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - system_prompt="You are a helpful assistant.", - ) - kwargs = openai_responses_converter.to_external(request) - - assert kwargs["instructions"] == "You are a helpful assistant." - - def test_to_responses_excludes_stop_sequences(self): - """Test that stop_sequences are not included (not supported).""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - stop_sequences=["STOP"], - ) - kwargs = openai_responses_converter.to_external(request, strict=False) - - assert "stop_sequences" not in kwargs - assert "stop" not in kwargs - - def test_to_responses_excludes_top_k(self): - """Test that top_k is excluded.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - top_k=40, - ) - kwargs = openai_responses_converter.to_external(request, strict=False) - - assert "top_k" not in kwargs - - def test_from_responses_minimal(self): - """Test parsing minimal Responses request.""" - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": "Hello", - } - request = openai_responses_converter.from_external(kwargs) - - assert list(request.messages) == [{"role": "user", "content": "Hello"}] - - def test_from_responses_all_params(self): - """Test parsing Responses with all parameters.""" 
- tool: openai.types.responses.FunctionToolParam = openai.types.responses.FunctionToolParam( - type="function", - name="test", - description="A test function", - parameters={"type": "object", "properties": {"x": {"type": "number"}}}, - strict=None, - ) - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": [{"type": "message", "role": "user", "content": "Hello"}], - "max_output_tokens": 100, - "temperature": 0.7, - "top_p": 0.9, - "stream": True, - "tools": [tool], - "tool_choice": "auto", - "metadata": {"user_id": "123"}, - "service_tier": "default", - "instructions": "You are helpful.", - } - request = openai_responses_converter.from_external(kwargs) - - assert request.max_output_tokens == 100 - assert request.temperature == 0.7 - assert request.top_p == 0.9 - assert request.stream is True - # Tools should be converted from OpenAI to Claude format - assert request.tools == [ - { - "name": "test", - "description": "A test function", - "input_schema": {"type": "object", "properties": {"x": {"type": "number"}}}, - } - ] - assert request.tool_choice == "auto" - assert request.metadata == {"user_id": "123"} - assert request.service_tier == "default" - assert request.system_prompt == "You are helpful." 
- - def test_from_responses_string_input(self): - """Test parsing Responses with string input.""" - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "input": "Hello, world!", - } - request = openai_responses_converter.from_external(kwargs) - - assert list(request.messages) == [{"role": "user", "content": "Hello, world!"}] - - def test_from_responses_list_input(self): - """Test parsing Responses with list input.""" - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": [ - {"type": "message", "role": "user", "content": "Hello"}, - {"type": "message", "role": "assistant", "content": "Hi!"}, - ], - } - request = openai_responses_converter.from_external(kwargs) - - assert list(request.messages) == [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi!"}, - ] - - def test_to_responses_includes_tool_call_id(self): - """Test that tool messages are converted to function_call_output format with call_id.""" - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="What's the weather?"), - request_lib.AssistantMessage(role="assistant", content="Let me check..."), - request_lib.ToolCallResponseMessage(role="tool", content="Sunny, 72°F", tool_call_id="call_123"), - ], - ) - kwargs = openai_responses_converter.to_external(request) - - # Tool message should be converted to function_call_output format - tool_output = kwargs["input"][2] - assert tool_output["type"] == "function_call_output" - assert tool_output["output"] == "Sunny, 72°F" - assert tool_output["call_id"] == "call_123" - - def test_from_responses_reads_tool_call_id(self): - """Test that tool_call_id is read from Responses input items (function_call_output format).""" - inputs: list[openai.types.responses.ResponseInputItemParam] = [ - openai.types.responses.EasyInputMessageParam(type="message", role="user", content="What's the weather?"), - 
openai.types.responses.response_input_item_param.FunctionCallOutput( - type="function_call_output", call_id="call_456", output="Sunny, 72°F" - ), - ] - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsNonStreaming( - model="gpt-4o", - input=inputs, - ) - request = openai_responses_converter.from_external(kwargs) - - tool_msg = request.messages[1] - assert tool_msg.get("role") == "tool" - assert tool_msg.get("content") == "Sunny, 72°F" - assert tool_msg.get("tool_call_id") == "call_456" - - def test_from_responses_tool_missing_id_strict(self): - """Test that missing call_id raises error in strict mode.""" - # Note: This tests a malformed input - function_call_output requires call_id - # We construct a dict to simulate a malformed payload - inputs: list[dict[str, str]] = [ - {"type": "function_call_output", "output": "Result"}, # Missing required call_id - ] - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsNonStreaming( - model="gpt-4o", - input=inputs, # type: ignore[arg-type] - ) - with pytest.raises(ValueError, match=r"Tool .*missing required.*call_id"): - openai_responses_converter.from_external(kwargs, strict=True) - - def test_from_responses_tool_missing_id_non_strict(self): - """Test that missing call_id logs warning in non-strict mode.""" - # Note: This tests a malformed input - function_call_output requires call_id - inputs: list[dict[str, str]] = [ - {"type": "function_call_output", "output": "Result"}, # Missing required call_id - ] - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsNonStreaming( - model="gpt-4o", - input=inputs, # type: ignore[arg-type] - ) - request = openai_responses_converter.from_external(kwargs, strict=False) - - # Should still parse the message but without tool_call_id - assert len(request.messages) == 1 - assert request.messages[0].get("role") == "tool" - assert request.messages[0].get("content") == "Result" - assert "tool_call_id" not in 
request.messages[0] - - def test_to_responses_includes_tool_calls(self): - """Test that ToolCallMessage is converted to function_call format.""" - request = request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="What's the weather?"), - request_lib.AssistantMessage(role="assistant", content="Let me check..."), - request_lib.ToolCallMessage(call_id="call_123", name="get_weather", arguments='{"location":"SF"}'), - ], - ) - kwargs = openai_responses_converter.to_external(request) - - # ToolCallMessage should be converted to function_call format - tool_call = kwargs["input"][2] - assert tool_call["type"] == "function_call" - assert tool_call["call_id"] == "call_123" - assert tool_call["name"] == "get_weather" - assert tool_call["arguments"] == '{"location":"SF"}' - - def test_from_responses_reads_tool_calls(self): - """Test that function_call items are read as ToolCallMessage.""" - inputs: list[openai.types.responses.ResponseInputItemParam] = [ - openai.types.responses.EasyInputMessageParam(type="message", role="user", content="What's the weather?"), - openai.types.responses.response_function_tool_call_param.ResponseFunctionToolCallParam( - type="function_call", - call_id="call_456", - name="get_weather", - arguments='{"location":"NYC"}', - ), - ] - kwargs = openai.types.responses.response_create_params.ResponseCreateParamsNonStreaming( - model="gpt-4o", - input=inputs, - ) - request = openai_responses_converter.from_external(kwargs) - - tool_call_msg = request.messages[1] - assert "call_id" in tool_call_msg - assert tool_call_msg["call_id"] == "call_456" # type: ignore[typeddict-item] - assert tool_call_msg["name"] == "get_weather" # type: ignore[typeddict-item] - assert tool_call_msg["arguments"] == '{"location":"NYC"}' # type: ignore[typeddict-item] - - def test_responses_roundtrip_preserves_tool_call_id(self): - """Test that tool_call_id is preserved in roundtrip conversion (via function_call_output).""" - original = 
request_lib.LLMRequest( - messages=[ - request_lib.UserMessage(role="user", content="Check weather"), - request_lib.ToolCallResponseMessage(role="tool", content="Sunny", tool_call_id="call_789"), - ], - ) - # Convert to Responses format (should use function_call_output) - responses_kwargs = openai_responses_converter.to_external(original) - - # Verify it uses function_call_output format - assert responses_kwargs["input"][1]["type"] == "function_call_output" - assert responses_kwargs["input"][1]["call_id"] == "call_789" - - # Convert back and verify tool_call_id is preserved - request2 = openai_responses_converter.from_external( - cast(openai.types.responses.response_create_params.ResponseCreateParamsBase, responses_kwargs) - ) - - tool_msg = cast(request_lib.ToolCallResponseMessage, request2.messages[1]) - assert tool_msg["tool_call_id"] == "call_789" - - def test_from_responses_unsupported_tool_type_strict(self): - """Test that unsupported tool type raises error in strict mode.""" - # Create a tool with unsupported type (not "function") - unsupported_tool: dict[str, str] = {"type": "browser_tool", "name": "search_web"} - - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": "Hello", - "tools": [unsupported_tool], # type: ignore[list-item] - } - - with pytest.raises(ValueError, match=r"Unsupported tool type for conversion"): - openai_responses_converter.from_external(kwargs, strict=True) - - def test_from_responses_unsupported_tool_type_non_strict(self): - """Test that unsupported tool type is skipped in non-strict mode.""" - # Mix of supported and unsupported tools - supported_tool: openai.types.responses.FunctionToolParam = openai.types.responses.FunctionToolParam( - type="function", - name="valid_tool", - description="A valid function tool", - parameters={"type": "object", "properties": {}}, - strict=None, - ) - unsupported_tool: dict[str, str] = {"type": "browser_tool", "name": "search_web"} - - kwargs: 
openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": "Hello", - "tools": [supported_tool, unsupported_tool], # type: ignore[list-item] - } - - # Should not raise, but only include the supported tool - request = openai_responses_converter.from_external(kwargs, strict=False) - - assert request.tools is not None - assert len(request.tools) == 1 - assert request.tools[0]["name"] == "valid_tool" - - def test_from_responses_all_unsupported_tools_non_strict(self): - """Test that tools is None when all tools are unsupported in non-strict mode.""" - # Only unsupported tools - unsupported_tool1: dict[str, str] = {"type": "browser_tool", "name": "search_web"} - unsupported_tool2: dict[str, str] = {"type": "computer_tool", "name": "run_code"} - - kwargs: openai.types.responses.response_create_params.ResponseCreateParams = { - "model": "gpt-4o", - "input": "Hello", - "tools": [unsupported_tool1, unsupported_tool2], # type: ignore[list-item] - } - - # Should not raise, but tools should be None since no valid tools were converted - request = openai_responses_converter.from_external(kwargs, strict=False) - - assert request.tools is None diff --git a/src/ares/llms/queue_mediated_client.py b/src/ares/llms/queue_mediated_client.py index 588231882..29575d679 100644 --- a/src/ares/llms/queue_mediated_client.py +++ b/src/ares/llms/queue_mediated_client.py @@ -3,9 +3,10 @@ import asyncio import dataclasses +from linguafranca import types as lft + from ares import async_utils from ares.llms import llm_clients -from ares.llms import request from ares.llms import response @@ -40,11 +41,11 @@ class QueueMediatedLLMClient(llm_clients.LLMClient): awaiting __call__ will block forever. 
""" - q: asyncio.Queue[async_utils.ValueAndFuture[request.LLMRequest, response.LLMResponse]] = dataclasses.field( - default_factory=asyncio.Queue + q: asyncio.Queue[async_utils.ValueAndFuture[lft.OpenResponsesRequest, response.InferenceResult]] = ( + dataclasses.field(default_factory=asyncio.Queue) ) - async def __call__(self, req: request.LLMRequest) -> response.LLMResponse: - future = asyncio.Future[response.LLMResponse]() + async def __call__(self, req: lft.OpenResponsesRequest) -> response.InferenceResult: + future = asyncio.Future[response.InferenceResult]() await self.q.put(async_utils.ValueAndFuture(value=req, future=future)) return await future diff --git a/src/ares/llms/queue_mediated_client_test.py b/src/ares/llms/queue_mediated_client_test.py new file mode 100644 index 000000000..d3cac73af --- /dev/null +++ b/src/ares/llms/queue_mediated_client_test.py @@ -0,0 +1,27 @@ +"""Tests for the queue-mediated LLM client.""" + +import asyncio + +import pytest + +from ares.llms import open_responses +from ares.llms import queue_mediated_client +from ares.llms import response + + +@pytest.mark.asyncio +async def test_queue_mediated_client_roundtrips_canonical_requests(): + client = queue_mediated_client.QueueMediatedLLMClient() + request = open_responses.make_request([open_responses.user_message("Hello")]) + + async def answer_request() -> None: + queued = await client.q.get() + assert queued.value == request + lf_response = response.make_response("Hi", input_tokens=1, output_tokens=1) + queued.future.set_result(response.InferenceResult(response=lf_response, cost=0.0)) + + answer_task = asyncio.create_task(answer_request()) + llm_response = await client(request) + await answer_task + + assert response.extract_text_content(llm_response.response) == "Hi" diff --git a/src/ares/llms/request.py b/src/ares/llms/request.py deleted file mode 100644 index 6e5422ccb..000000000 --- a/src/ares/llms/request.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Unified LLM request abstraction 
supporting multiple API formats.""" - -import dataclasses -import logging -from typing import Any, Literal, NotRequired, Protocol, Required, TypedDict - -_LOGGER = logging.getLogger(__name__) - - -def _extract_string_content(content: Any, *, strict: bool = True, context: str = "content") -> str: - """Extract string from content, raising error for unsupported block formats. - - Args: - content: Content value - should be a string - strict: If True, raise ValueError for non-string content. If False, log warning and return empty string. - context: Description of where this content came from (for error messages) - - Returns: - The content string - - Raises: - ValueError: If strict=True and content is not a plain string - - Note: - This function currently does NOT support content blocks (lists of text/image/tool blocks). - If you encounter a ValueError about list content, this means the API returned structured - content that needs proper handling - see the issue about extracting text from blocks. - """ - if isinstance(content, str): - return content - - if not content: - return "" - - # Non-string content - this is where we lose information! - content_type = type(content).__name__ - preview = str(content)[:100] if content else "" - - if isinstance(content, list): - msg = ( - f"{context} is a list of blocks (structured content), but we only support plain strings. " - f"This will lose information. Preview: {preview}..." - ) - else: - msg = f"{context} has unsupported type {content_type}, expected str. Preview: {preview}..." 
- - if strict: - raise ValueError(msg) - - _LOGGER.warning(msg) - return "" - - -class UserMessage(TypedDict): - """User message in a conversation.""" - - role: Literal["user"] - content: str - name: NotRequired[str] # Optional - for identifying the user - - -class AssistantMessage(TypedDict, total=False): - """Assistant message in a conversation.""" - - role: Required[Literal["assistant"]] - content: str # Optional - might not have content if tool_calls present - name: str # Optional - tool_calls: list[dict[str, Any]] # Optional - for tool usage - - -class ToolCallMessage(TypedDict): - """Tool call message (tool invocation by assistant). - - Note: In Chat Completions, these are embedded in AssistantMessage.tool_calls. - In Responses/Messages APIs, these are separate items. - """ - - call_id: str - name: str - arguments: str - - -class ToolCallResponseMessage(TypedDict): - """Tool call response message (tool result).""" - - role: Literal["tool"] - content: str - tool_call_id: str - name: NotRequired[str] # Optional - - -# Union type for all supported message types -Message = UserMessage | AssistantMessage | ToolCallMessage | ToolCallResponseMessage - -# Valid roles (excludes system/developer which go in system_prompt) -_VALID_ROLES = frozenset(["user", "assistant", "tool"]) - - -class JSONSchema(TypedDict): - """JSON Schema definition for tool parameters/input.""" - - type: Literal["object"] - properties: dict[str, Any] - required: NotRequired[list[str]] # Optional - list of required property names - # Additional schema fields (additionalProperties, etc.) can be passed via properties - - -class Tool(TypedDict): - """Unified tool definition. - - Uses Claude's simpler format internally (flat structure with input_schema). - Converts to/from OpenAI's nested format (type: "function" with function.parameters). 
- """ - - name: str - description: str - input_schema: JSONSchema - - -class ToolChoiceTool(TypedDict): - """Tool choice: model must use a specific named tool.""" - - type: Literal["tool"] - name: str - - -# Internal tool choice format -# - "auto": Model decides whether to use tools -# - "any": Model must use at least one tool (maps to OpenAI "required") -# - "none": Model must not use any tools -# - ToolChoiceTool: Model must use a specific named tool -ToolChoice = Literal["auto", "any", "none"] | ToolChoiceTool - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class LLMRequest: - """Unified request format for OpenAI Chat Completions, OpenAI Responses, and Claude Messages APIs. - - This class provides a common abstraction over three major LLM API formats, making it easy - to convert between them. It includes the 9 core parameters that exist across all APIs. - - Core Parameters (all APIs): - messages: List of user/assistant/tool messages (system messages go in system_prompt) - max_output_tokens: Maximum tokens to generate (field name varies by API) - temperature: Sampling temperature (range 0-2 for OpenAI, auto-converted to 0-1 for Claude) - top_p: Nucleus sampling parameter - stream: Enable streaming responses - tools: Tool definitions (schema varies by API, will be converted) - tool_choice: Tool selection strategy (options vary by API, will be converted) - metadata: Custom key-value pairs (location varies by API) - - Extended Parameters (partial support): - service_tier: Processing tier (options differ by API) - stop_sequences: Stop sequences (not supported in OpenAI Responses) - system_prompt: System instructions (single source of truth, not in messages list) - top_k: Top-K sampling (Claude only) - - Note: - - Model is NOT stored here - it should be managed by the LLMClient - - Messages only include user/assistant/tool roles (system/developer go in system_prompt) - - Temperature is stored in OpenAI range (0-2), converted to Claude range (0-1) on export - - 
Tool schemas are converted as needed for each API - - Some parameters may be lost or unsupported when converting between APIs - """ - - messages: list[Message] - max_output_tokens: int | None = None - temperature: float | None = None - top_p: float | None = None - stream: bool = False - tools: list[Tool] | None = None - tool_choice: ToolChoice | None = None - metadata: dict[str, Any] | None = None - service_tier: Literal["auto", "default", "flex", "scale", "priority", "standard_only"] | None = None - stop_sequences: list[str] | None = None - system_prompt: str | None = None - top_k: int | None = None - - -class RequestConverter[RequestType](Protocol): - """Converts between ARES LLMRequest and external API formats. - - This protocol defines the interface for bidirectional conversion between ARES's internal - LLMRequest format and external API request formats (OpenAI Chat Completions, OpenAI Responses, - Anthropic Messages, etc.). - - Type Parameters: - RequestType: The external API's request parameters type (e.g., dict[str, Any] for kwargs) - - Note: - Implementations are provided as modules with module-level to_external() and from_external() - functions. The module itself conforms to this Protocol through structural subtyping. - The model parameter is NOT included in conversions - it should be managed by the LLMClient. - - Available converters: - - openai_chat_converter: OpenAI Chat Completions API - - openai_responses_converter: OpenAI Responses API - - anthropic_converter: Anthropic Messages API - """ - - def to_external(self, request: LLMRequest, *, strict: bool = True) -> RequestType: - """Convert ARES LLMRequest to external API format. - - Args: - request: ARES internal request format - strict: If True, raise ValueError on information loss. If False, log warnings. - - Returns: - Request parameters in external API format (without model parameter) - - Raises: - ValueError: If strict=True and information would be lost in conversion - """ - ... 
- - def from_external(self, request: RequestType, *, strict: bool = True) -> LLMRequest: - """Convert external API format to ARES LLMRequest. - - Args: - request: External API request parameters - strict: If True, raise ValueError for unhandled parameters. If False, log warnings. - - Returns: - LLMRequest instance - - Raises: - ValueError: If strict=True and there are unhandled parameters - """ - ... diff --git a/src/ares/llms/request_test.py b/src/ares/llms/request_test.py deleted file mode 100644 index 1e7340979..000000000 --- a/src/ares/llms/request_test.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Integration tests for cross-API conversions between different LLM request formats.""" - -from typing import cast - -import anthropic.types -import openai.types.chat.completion_create_params -import openai.types.responses.response_create_params - -from ares.llms import anthropic_converter -from ares.llms import openai_chat_converter -from ares.llms import openai_responses_converter -from ares.llms import request as request_lib - - -class TestLLMRequestCrossAPIConversion: - """Tests for converting between different APIs.""" - - def test_chat_to_responses_to_chat(self): - """Test Chat -> Responses -> Chat roundtrip.""" - original_chat: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - "max_completion_tokens": 100, - "temperature": 0.7, - } - request = openai_chat_converter.from_external(original_chat) - responses_kwargs = openai_responses_converter.to_external(request) - request2 = openai_responses_converter.from_external( - cast(openai.types.responses.response_create_params.ResponseCreateParamsBase, responses_kwargs) - ) - final_chat = openai_chat_converter.to_external(request2) - - # Note: model will be dropped in the conversion. 
- assert "model" not in final_chat - assert final_chat["max_completion_tokens"] == original_chat["max_completion_tokens"] - assert final_chat["temperature"] == original_chat["temperature"] - - def test_chat_to_messages_temperature_conversion(self): - """Test Chat -> Messages converts temperature correctly.""" - chat_params: openai.types.chat.completion_create_params.CompletionCreateParams = { - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello"}], - "temperature": 1.0, - } - request = openai_chat_converter.from_external(chat_params) - claude_kwargs = anthropic_converter.to_external(request) - - assert claude_kwargs["temperature"] == 0.5 # 1.0 / 2 - - def test_messages_to_chat_temperature_conversion(self): - """Test Messages -> Chat converts temperature correctly.""" - messages_params: anthropic.types.MessageCreateParams = { - "model": "claude-sonnet-4-5-20250929", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 100, - "temperature": 0.5, - } - request = anthropic_converter.from_external(messages_params) - chat_kwargs = openai_chat_converter.to_external(request) - - assert chat_kwargs["temperature"] == 1.0 # 0.5 * 2 - - def test_all_apis_preserve_core_params(self): - """Test that core parameters are preserved across all conversions.""" - # Create a request with all common parameters - original = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - max_output_tokens=100, - temperature=1.0, - top_p=0.9, - stream=True, - metadata={"user_id": "123"}, - ) - - # Convert to all formats - chat_kwargs = openai_chat_converter.to_external(original) - responses_kwargs = openai_responses_converter.to_external(original) - messages_kwargs = anthropic_converter.to_external(original) - - # Verify core params are present in all - assert chat_kwargs["max_completion_tokens"] == 100 - assert responses_kwargs["max_output_tokens"] == 100 - assert messages_kwargs["max_tokens"] == 100 - - assert chat_kwargs["temperature"] == 
1.0 - assert responses_kwargs["temperature"] == 1.0 - assert messages_kwargs["temperature"] == 0.5 # Converted - - assert chat_kwargs["top_p"] == 0.9 - assert responses_kwargs["top_p"] == 0.9 - assert messages_kwargs["top_p"] == 0.9 - - assert chat_kwargs["stream"] is True - assert responses_kwargs["stream"] is True - assert messages_kwargs["stream"] is True - - assert chat_kwargs["metadata"] == {"user_id": "123"} - assert responses_kwargs["metadata"] == {"user_id": "123"} - assert messages_kwargs["metadata"] == {"user_id": "123"} - - def test_responses_tool_choice_flat_format(self): - """Test that Responses API uses flat tool_choice format.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - tools=[ - { - "name": "search", - "description": "Search the web", - "input_schema": {"type": "object", "properties": {}}, - } - ], - tool_choice={"type": "tool", "name": "search"}, - ) - kwargs = openai_responses_converter.to_external(request) - - # Responses API uses flat format: {"type": "function", "name": "search"} - assert kwargs["tool_choice"] == {"type": "function", "name": "search"} - - def test_chat_tool_choice_nested_format(self): - """Test that Chat Completions API uses nested tool_choice format.""" - request = request_lib.LLMRequest( - messages=[{"role": "user", "content": "Hello"}], - tools=[ - { - "name": "search", - "description": "Search the web", - "input_schema": {"type": "object", "properties": {}}, - } - ], - tool_choice={"type": "tool", "name": "search"}, - ) - kwargs = openai_chat_converter.to_external(request) - - # Chat Completions API uses nested format: {"type": "function", "function": {"name": "search"}} - assert kwargs["tool_choice"] == {"type": "function", "function": {"name": "search"}} diff --git a/src/ares/llms/response.py b/src/ares/llms/response.py index dead7aa31..ef8d86849 100644 --- a/src/ares/llms/response.py +++ b/src/ares/llms/response.py @@ -1,38 +1,99 @@ -"""LLM response model.""" +"""LLM 
response model wrapping linguafranca types.""" import dataclasses +import time +import uuid + +from linguafranca import types as lft @dataclasses.dataclass(frozen=True) -class Usage: - """Token usage information for an LLM call.""" +class InferenceResult: + """Result from an LLM inference call. - prompt_tokens: int - generated_tokens: int + Attributes: + response: The linguafranca Open Responses response. + cost: Cost of the inference call in USD. + """ - @property - def total_tokens(self) -> int: - """Total tokens used (prompt + generated).""" - return self.prompt_tokens + self.generated_tokens + response: lft.OpenResponsesResponse + cost: float + @property + def usage(self) -> lft.Usage | None: + """Token usage information.""" + return self.response.usage -@dataclasses.dataclass(frozen=True) -class TextData: - """Text content from an LLM response.""" - content: str +_TEXT_CONTENT_TYPES = ( + lft.ContentPartInputText, + lft.ContentPartOutputText, + lft.ContentPartText, + lft.ContentPartSummaryText, + lft.ContentPartReasoningText, +) -@dataclasses.dataclass(frozen=True) -class LLMResponse: - """Response from an LLM call. +def extract_text_content(response: lft.OpenResponsesResponse) -> str: + """Extract text content from an Open Responses response. - Attributes: - data: List of content blocks (currently only TextData, but extensible to ImageData, etc.) - cost: Cost of the LLM call in USD. - usage: Token usage information. + Returns the concatenated text from all output message content parts. + Returns empty string if no text content is found. 
""" + text_parts: list[str] = [] + for item in response.output: + if not isinstance(item, lft.OutputItemMessage): + continue + for part in item.content: + if isinstance(part, _TEXT_CONTENT_TYPES): + text_parts.append(part.text) + return "".join(text_parts) - data: list[TextData] - cost: float - usage: Usage + +def make_response( + content: str, + *, + model: str = "", + input_tokens: int = 0, + output_tokens: int = 0, + response_id: str | None = None, +) -> lft.OpenResponsesResponse: + """Create an Open Responses response with the given text content. + + This is a convenience function for constructing lft.OpenResponsesResponse + objects from simple text completions. + + Args: + content: The text content of the response. + model: The model that generated the response. + input_tokens: Number of input/prompt tokens. + output_tokens: Number of output/completion tokens. + response_id: Optional response ID. Generates a UUID if not provided. + + Returns: + A fully-formed lft.OpenResponsesResponse. 
+ """ + output: list[lft.OutputItem] = [ + lft.OutputItemMessage( + content=[lft.ContentPartOutputText(text=content, type="output_text", annotations=None, logprobs=None)], + id=f"msg_{uuid.uuid4().hex}", + role=lft.MessageRole.assistant, + status=lft.ItemStatus.completed, + type="message", + ) + ] + return lft.OpenResponsesResponse( + created_at=int(time.time()), + id=response_id or f"resp_{uuid.uuid4().hex}", + model=model, + object="response", + output=output, + status=lft.ResponseStatus.completed, + usage=lft.Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + input_tokens_details=None, + output_tokens_details=None, + ), + ) diff --git a/src/ares/presets.py b/src/ares/presets.py index 7429cfa1a..beda66349 100644 --- a/src/ares/presets.py +++ b/src/ares/presets.py @@ -125,14 +125,20 @@ def _register_default_presets() -> None: This function is called automatically when the presets module is imported, ensuring built-in presets are always available. """ + existing_preset_names = set(registry._list_presets()) + seen_preset_names: set[str] = set() for ds_spec in code_env.list_harbor_datasets(): for code_agent_id, code_agent_factory in [ ("mswea", mini_swe_agent.MiniSWECodeAgent), ("terminus2", terminus2_agent.Terminus2Agent), ]: ds_id = _make_harbor_dataset_id(ds_spec.name, ds_spec.version) + preset_name = f"{ds_id}-{code_agent_id}" + if preset_name in seen_preset_names or preset_name in existing_preset_names: + continue + seen_preset_names.add(preset_name) registry.register_preset( - f"{ds_id}-{code_agent_id}", + preset_name, HarborSpec( ds_spec=ds_spec, dataset_id=ds_id, @@ -140,9 +146,12 @@ def _register_default_presets() -> None: code_agent_id=code_agent_id, ), ) + existing_preset_names.add(preset_name) # Twenty Questions — lightweight, no Docker needed. 
- registry.register_preset("20q", TwentyQuestionsSpec()) + if "20q" not in existing_preset_names: + registry.register_preset("20q", TwentyQuestionsSpec()) + existing_preset_names.add("20q") _LOGGER.debug("Registered %d default presets", len(registry._list_presets())) diff --git a/src/ares/presets_test.py b/src/ares/presets_test.py new file mode 100644 index 000000000..2597528ac --- /dev/null +++ b/src/ares/presets_test.py @@ -0,0 +1,35 @@ +"""Tests for built-in preset behavior.""" + +import pytest + +from linguafranca import types as lft + +import ares +from ares import presets +from ares import registry +from ares.llms import open_responses +from ares.llms import response + + +@pytest.mark.asyncio +async def test_make_twenty_questions_preset_uses_open_responses_observations(): + preset_name = "20q-open-responses-test" + if preset_name in registry._list_presets(): + registry.unregister_preset(preset_name) + + registry.register_preset(preset_name, presets.TwentyQuestionsSpec(objects=("Basketball",))) + + try: + async with ares.make(f"{preset_name}:0") as env: + ts = await env.reset() + assert isinstance(ts.observation, lft.OpenResponsesRequest) + assert open_responses.request_to_jsonable(ts.observation)["input"][0]["role"] == "user" + lf_response = response.make_response("Is it Basketball?", input_tokens=1, output_tokens=1) + ts = await env.step(response.InferenceResult(response=lf_response, cost=0.0)) + + assert ts.last() + assert ts.reward == 0.0 + assert isinstance(ts.observation, lft.OpenResponsesRequest) + assert open_responses.request_to_jsonable(ts.observation)["input"][0]["role"] == "user" + finally: + registry.unregister_preset(preset_name) diff --git a/src/ares/testing/mock_llm.py b/src/ares/testing/mock_llm.py index 79975ca4c..3f80fbbd2 100644 --- a/src/ares/testing/mock_llm.py +++ b/src/ares/testing/mock_llm.py @@ -3,7 +3,9 @@ from collections.abc import Callable import dataclasses -from ares.llms import request +from linguafranca import types as lft + 
+from ares.llms import open_responses from ares.llms import response @@ -15,27 +17,27 @@ class MockLLMClient: API requests. It records all requests and allows configuring responses. Attributes: - requests: List of all LLMRequest objects received. + requests: List of all Open Responses requests received. responses: List of response strings to return (cycles through them). response_handler: Optional function to dynamically generate responses. default_response: Default response if no responses configured. call_count: Number of times the client has been called. """ - requests: list[request.LLMRequest] = dataclasses.field(default_factory=list) + requests: list[lft.OpenResponsesRequest] = dataclasses.field(default_factory=list) responses: list[str] = dataclasses.field(default_factory=list) - response_handler: Callable[[request.LLMRequest], str] | None = None + response_handler: Callable[[lft.OpenResponsesRequest], str] | None = None default_response: str = "Mock LLM response" call_count: int = 0 - async def __call__(self, request: request.LLMRequest) -> response.LLMResponse: + async def __call__(self, request: lft.OpenResponsesRequest) -> response.InferenceResult: """Process LLM request and return mock response. Args: request: The LLM request to process. Returns: - LLMResponse with mock data. + InferenceResult with mock data. 
""" self.requests.append(request) self.call_count += 1 @@ -49,21 +51,23 @@ async def __call__(self, request: request.LLMRequest) -> response.LLMResponse: else: response_text = self.default_response - return response.LLMResponse( - data=[response.TextData(content=response_text)], - cost=0.0, - usage=response.Usage(prompt_tokens=100, generated_tokens=50), + lf_response = response.make_response( + response_text, + model="mock-model", + input_tokens=100, + output_tokens=50, ) + return response.InferenceResult(response=lf_response, cost=0.0) - def get_last_request(self) -> request.LLMRequest | None: + def get_last_request(self) -> lft.OpenResponsesRequest | None: """Get the most recent request, or None if no requests.""" return self.requests[-1] if self.requests else None - def get_request_messages(self, index: int = -1) -> list[request.Message]: - """Get messages from a specific request (default: last request).""" + def get_request_messages(self, index: int = -1) -> list[lft.InputItem]: + """Get input items from a specific request (default: last request).""" if not self.requests: return [] - return self.requests[index].messages + return open_responses.input_items(self.requests[index]) def reset(self) -> None: """Clear all recorded data.""" diff --git a/src/ares/testing/mock_llm_test.py b/src/ares/testing/mock_llm_test.py index e82a78ae3..35516e6fd 100644 --- a/src/ares/testing/mock_llm_test.py +++ b/src/ares/testing/mock_llm_test.py @@ -1,8 +1,12 @@ """Tests for mock LLM client implementation.""" +from typing import cast + +from linguafranca import types as lft import pytest -from ares.llms import request +from ares.llms import open_responses +from ares.llms import response from ares.testing import mock_llm @@ -11,8 +15,8 @@ async def test_mock_llm_client_records_requests(): """Test that mock LLM client records all requests.""" client = mock_llm.MockLLMClient() - request1 = request.LLMRequest(messages=[{"role": "user", "content": "Hello"}]) - request2 = 
request.LLMRequest(messages=[{"role": "user", "content": "World"}])
+    request1 = open_responses.make_request([open_responses.user_message("Hello")])
+    request2 = open_responses.make_request([open_responses.user_message("World")])

     await client(request1)
     await client(request2)
@@ -27,11 +31,11 @@ async def test_mock_llm_client_default_response():
     """Test that mock LLM client returns default response."""
     client = mock_llm.MockLLMClient()

-    req = request.LLMRequest(messages=[{"role": "user", "content": "test"}])
-    response = await client(req)
+    req = open_responses.make_request([open_responses.user_message("test")])
+    llm_response = await client(req)

-    assert response.data[0].content == "Mock LLM response"
-    assert response.cost == 0.0
+    assert response.extract_text_content(llm_response.response) == "Mock LLM response"
+    assert llm_response.cost == 0.0


 @pytest.mark.asyncio
@@ -39,34 +43,34 @@ async def test_mock_llm_client_configured_responses():
     """Test that mock LLM client cycles through configured responses."""
     client = mock_llm.MockLLMClient(responses=["First", "Second", "Third"])

-    req = request.LLMRequest(messages=[{"role": "user", "content": "test"}])
+    req = open_responses.make_request([open_responses.user_message("test")])

     response1 = await client(req)
     response2 = await client(req)
     response3 = await client(req)
     response4 = await client(req)  # Should cycle back to first

-    assert response1.data[0].content == "First"
-    assert response2.data[0].content == "Second"
-    assert response3.data[0].content == "Third"
-    assert response4.data[0].content == "First"
+    assert response.extract_text_content(response1.response) == "First"
+    assert response.extract_text_content(response2.response) == "Second"
+    assert response.extract_text_content(response3.response) == "Third"
+    assert response.extract_text_content(response4.response) == "First"


 @pytest.mark.asyncio
 async def test_mock_llm_client_response_handler():
     """Test that mock LLM client uses custom response handler."""

-    def handler(req: request.LLMRequest) -> str:
+    def handler(req: lft.OpenResponsesRequest) -> str:
         # Echo back the user's message
-        user_msg = req.messages[-1].get("content", "")
+        user_msg = open_responses.message_text(open_responses.message_items(req)[-1])
         return f"You said: {user_msg}"

     client = mock_llm.MockLLMClient(response_handler=handler)

-    req = request.LLMRequest(messages=[{"role": "user", "content": "Hello AI"}])
-    response = await client(req)
+    req = open_responses.make_request([open_responses.user_message("Hello AI")])
+    llm_response = await client(req)

-    assert response.data[0].content == "You said: Hello AI"
+    assert response.extract_text_content(llm_response.response) == "You said: Hello AI"


 @pytest.mark.asyncio
@@ -74,7 +78,7 @@ async def test_mock_llm_client_call_count():
     """Test that mock LLM client tracks call count."""
     client = mock_llm.MockLLMClient()

-    req = request.LLMRequest(messages=[{"role": "user", "content": "test"}])
+    req = open_responses.make_request([open_responses.user_message("test")])

     assert client.call_count == 0

@@ -92,8 +96,8 @@ async def test_mock_llm_client_get_last_request():

     assert client.get_last_request() is None

-    request1 = request.LLMRequest(messages=[{"role": "user", "content": "First"}])
-    request2 = request.LLMRequest(messages=[{"role": "user", "content": "Second"}])
+    request1 = open_responses.make_request([open_responses.user_message("First")])
+    request2 = open_responses.make_request([open_responses.user_message("Second")])

     await client(request1)
     assert client.get_last_request() == request1
@@ -107,17 +111,33 @@ async def test_mock_llm_client_get_request_messages():
     """Test getting messages from specific requests."""
     client = mock_llm.MockLLMClient()

-    req = request.LLMRequest(
-        messages=[
-            {"role": "user", "content": "Hello"},
-        ],
-    )
+    req = open_responses.make_request([open_responses.user_message("Hello")])

     await client(req)

     messages = client.get_request_messages()
     assert len(messages) == 1
-    assert messages[0].get("content", "") == "Hello"
+    assert open_responses.message_text(cast(lft.InputItemMessage, messages[0])) == "Hello"
+
+
+@pytest.mark.asyncio
+async def test_mock_llm_client_get_request_messages_includes_tool_items():
+    """Test that request item helper returns tool call items too."""
+    client = mock_llm.MockLLMClient()
+
+    req = open_responses.make_request(
+        [
+            open_responses.function_call(call_id="call_1", name="search", arguments="{}"),
+            open_responses.function_call_output(call_id="call_1", output="done"),
+        ]
+    )
+
+    await client(req)
+
+    items = client.get_request_messages()
+    assert len(items) == 2
+    assert items[0].type == "function_call"
+    assert items[1].type == "function_call_output"


 @pytest.mark.asyncio
@@ -125,7 +145,7 @@ async def test_mock_llm_client_reset():
     """Test that reset() clears all data."""
     client = mock_llm.MockLLMClient()

-    req = request.LLMRequest(messages=[{"role": "user", "content": "test"}])
+    req = open_responses.make_request([open_responses.user_message("test")])
     await client(req)
     await client(req)

@@ -143,20 +163,19 @@ async def test_mock_llm_response_structure():
     """Test that mock response has correct structure."""
     client = mock_llm.MockLLMClient()

-    req = request.LLMRequest(messages=[{"role": "user", "content": "test"}])
-    response = await client(req)
+    req = open_responses.make_request([open_responses.user_message("test")])
+    llm_response = await client(req)

     # Check response structure
-    assert hasattr(response, "data")
-    assert hasattr(response, "cost")
-    assert hasattr(response, "usage")
+    assert hasattr(llm_response, "response")
+    assert hasattr(llm_response, "cost")
+    assert hasattr(llm_response, "usage")

-    # Check data structure
-    assert len(response.data) == 1
-    assert hasattr(response.data[0], "content")
-    assert response.data[0].content == "Mock LLM response"
+    # Check text content extraction
+    assert response.extract_text_content(llm_response.response) == "Mock LLM response"

     # Check usage structure
-    assert response.usage.prompt_tokens == 100
-    assert response.usage.generated_tokens == 50
-    assert response.usage.total_tokens == 150
+    assert llm_response.usage is not None
+    assert llm_response.usage.input_tokens == 100
+    assert llm_response.usage.output_tokens == 50
+    assert llm_response.usage.total_tokens == 150