|
7 | 7 |
|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
| 10 | +import json |
10 | 11 | import logging |
11 | 12 | import os |
12 | 13 | import re |
13 | 14 | import time |
14 | | -from typing import Any, Dict, List, Optional, Tuple |
| 15 | +from typing import Any, Dict, List, Optional, Set, Tuple |
15 | 16 |
|
16 | 17 | log = logging.getLogger(__name__) |
17 | 18 |
|
@@ -111,24 +112,32 @@ def __init__( |
111 | 112 | api_key: Optional[str] = None, |
112 | 113 | base_url: str = "https://openrouter.ai/api/v1", |
113 | 114 | ): |
| 115 | + self._api_key_override = api_key |
114 | 116 | self._api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "") |
115 | 117 | self._base_url = base_url |
116 | 118 | self._client = None |
| 119 | + self._client_api_key: Optional[str] = None |
117 | 120 | self._local_client = None |
118 | 121 | self._local_port: Optional[int] = None |
119 | 122 |
|
120 | 123 | def _get_client(self): |
121 | | - if self._client is None: |
| 124 | + current_api_key = self._api_key_override |
| 125 | + if current_api_key is None: |
| 126 | + current_api_key = os.environ.get("OPENROUTER_API_KEY", "") |
| 127 | + |
| 128 | + if self._client is None or self._client_api_key != current_api_key: |
122 | 129 | from openai import OpenAI |
123 | 130 | self._client = OpenAI( |
124 | 131 | base_url=self._base_url, |
125 | | - api_key=self._api_key, |
| 132 | + api_key=current_api_key, |
126 | 133 | max_retries=0, |
127 | 134 | default_headers={ |
128 | 135 | "HTTP-Referer": "https://ouroboros.local/", |
129 | 136 | "X-Title": "Ouroboros", |
130 | 137 | }, |
131 | 138 | ) |
| 139 | + self._client_api_key = current_api_key |
| 140 | + self._api_key = current_api_key |
132 | 141 | return self._client |
133 | 142 |
|
134 | 143 | def _get_local_client(self): |
@@ -273,9 +282,82 @@ def _chat_local( |
273 | 282 | choices = resp_dict.get("choices") or [{}] |
274 | 283 | msg = (choices[0] if choices else {}).get("message") or {} |
275 | 284 |
|
| 285 | + if not msg.get("tool_calls") and msg.get("content") and clean_tools: |
| 286 | + allowed_tool_names = { |
| 287 | + str(t.get("function", {}).get("name", "")).strip() |
| 288 | + for t in clean_tools |
| 289 | + if isinstance(t, dict) |
| 290 | + } |
| 291 | + msg = self._parse_tool_calls_from_content(msg, allowed_tool_names) |
| 292 | + |
276 | 293 | usage["cost"] = 0.0 |
277 | 294 | return msg, usage |
278 | 295 |
|
| 296 | + @staticmethod |
| 297 | + def _parse_tool_calls_from_content( |
| 298 | + msg: Dict[str, Any], |
| 299 | + allowed_tool_names: Optional[Set[str]] = None, |
| 300 | + ) -> Dict[str, Any]: |
| 301 | + """Parse <tool_call> XML tags from content into structured tool_calls. |
| 302 | +
|
| 303 | + Works around llama-cpp-python not parsing Qwen/Hermes-style tool calls |
| 304 | + (https://github.com/abetlen/llama-cpp-python/issues/1784). |
| 305 | + """ |
| 306 | + content = str(msg.get("content", "") or "") |
| 307 | + stripped = content.strip() |
| 308 | + if not stripped: |
| 309 | + return msg |
| 310 | + |
| 311 | + # Safety: only upgrade the response when it consists solely of |
| 312 | + # one or more <tool_call> blocks. If the model mixed prose with |
| 313 | + # examples or explanations, leave it as plain text. |
| 314 | + full_pattern = re.compile( |
| 315 | + r"^(?:\s*<tool_call>\s*\{.*?\}\s*</tool_call>\s*)+$", |
| 316 | + re.DOTALL, |
| 317 | + ) |
| 318 | + if not full_pattern.fullmatch(stripped): |
| 319 | + return msg |
| 320 | + |
| 321 | + matches = re.findall(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", stripped, re.DOTALL) |
| 322 | + if not matches: |
| 323 | + return msg |
| 324 | + |
| 325 | + allowed = {name for name in (allowed_tool_names or set()) if name} |
| 326 | + tool_calls = [] |
| 327 | + for i, raw in enumerate(matches): |
| 328 | + try: |
| 329 | + obj = json.loads(raw) |
| 330 | + if not isinstance(obj, dict): |
| 331 | + raise ValueError("tool_call payload must be an object") |
| 332 | + name = str(obj.get("name", "")).strip() |
| 333 | + args = obj.get("arguments", {}) |
| 334 | + if not name: |
| 335 | + raise ValueError("tool_call missing function name") |
| 336 | + if allowed and name not in allowed: |
| 337 | + raise ValueError(f"unknown tool '{name}'") |
| 338 | + if not isinstance(args, dict): |
| 339 | + raise ValueError("tool_call arguments must be an object") |
| 340 | + tool_calls.append({ |
| 341 | + "id": f"call_local_{i}", |
| 342 | + "type": "function", |
| 343 | + "function": { |
| 344 | + "name": name, |
| 345 | + "arguments": json.dumps(args), |
| 346 | + }, |
| 347 | + }) |
| 348 | + except (json.JSONDecodeError, ValueError) as exc: |
| 349 | + log.warning("Rejected local <tool_call> block: %s (%s)", raw[:200], exc) |
| 350 | + return msg |
| 351 | + |
| 352 | + if not tool_calls: |
| 353 | + return msg |
| 354 | + |
| 355 | + msg = dict(msg) |
| 356 | + msg["tool_calls"] = tool_calls |
| 357 | + msg["content"] = None |
| 358 | + log.info("Parsed %d local tool call(s) from text output", len(tool_calls)) |
| 359 | + return msg |
| 360 | + |
279 | 361 | @staticmethod |
280 | 362 | def _truncate_messages_for_context( |
281 | 363 | messages: List[Dict[str, Any]], ctx_len: int, max_tokens: int, |
|
0 commit comments