Skip to content
This repository was archived by the owner on May 30, 2026. It is now read-only.

Commit f844f2e

Browse files
author
a.kaznacheev
committed
fixes
1 parent 8e34123 commit f844f2e

6 files changed

Lines changed: 313 additions & 16 deletions

File tree

ouroboros/llm.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
from __future__ import annotations
99

10+
import json
1011
import logging
1112
import os
1213
import re
1314
import time
14-
from typing import Any, Dict, List, Optional, Tuple
15+
from typing import Any, Dict, List, Optional, Set, Tuple
1516

1617
log = logging.getLogger(__name__)
1718

@@ -111,24 +112,32 @@ def __init__(
111112
api_key: Optional[str] = None,
112113
base_url: str = "https://openrouter.ai/api/v1",
113114
):
115+
self._api_key_override = api_key
114116
self._api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
115117
self._base_url = base_url
116118
self._client = None
119+
self._client_api_key: Optional[str] = None
117120
self._local_client = None
118121
self._local_port: Optional[int] = None
119122

120123
def _get_client(self):
    """Return the OpenRouter client, rebuilding it when the API key changes.

    The effective key is the one given to the constructor or, failing that,
    the current value of the ``OPENROUTER_API_KEY`` environment variable.
    The environment is re-read on every call, so a changed variable is
    picked up by the next call instead of being frozen at construction.

    Returns:
        The cached ``openai.OpenAI`` instance, or a freshly built one when
        no client exists yet or the effective key differs from the key the
        cached client was built with.
    """
    # Same fallback rule as __init__ (`api_key or env`): an empty-string
    # override means "no override" rather than pinning a blank key forever.
    current_api_key = self._api_key_override or os.environ.get(
        "OPENROUTER_API_KEY", ""
    )

    if self._client is None or self._client_api_key != current_api_key:
        # Imported lazily so the module loads without the openai package.
        from openai import OpenAI

        self._client = OpenAI(
            base_url=self._base_url,
            api_key=current_api_key,
            max_retries=0,
            default_headers={
                "HTTP-Referer": "https://ouroboros.local/",
                "X-Title": "Ouroboros",
            },
        )
        # Remember which key this client was built with so the next call
        # can detect a change.
        self._client_api_key = current_api_key
        self._api_key = current_api_key
    return self._client
133142

134143
def _get_local_client(self):
@@ -273,9 +282,82 @@ def _chat_local(
273282
choices = resp_dict.get("choices") or [{}]
274283
msg = (choices[0] if choices else {}).get("message") or {}
275284

285+
if not msg.get("tool_calls") and msg.get("content") and clean_tools:
286+
allowed_tool_names = {
287+
str(t.get("function", {}).get("name", "")).strip()
288+
for t in clean_tools
289+
if isinstance(t, dict)
290+
}
291+
msg = self._parse_tool_calls_from_content(msg, allowed_tool_names)
292+
276293
usage["cost"] = 0.0
277294
return msg, usage
278295

296+
@staticmethod
def _parse_tool_calls_from_content(
    msg: Dict[str, Any],
    allowed_tool_names: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Parse <tool_call> XML tags from content into structured tool_calls.

    Works around llama-cpp-python not parsing Qwen/Hermes-style tool calls
    (https://github.com/abetlen/llama-cpp-python/issues/1784).

    Args:
        msg: Assistant message dict; only its ``content`` is inspected.
        allowed_tool_names: Tool names the caller offered. When non-empty,
            a call naming any other tool rejects the whole message.

    Returns:
        ``msg`` itself, unchanged, unless the content consists of nothing
        but well-formed ``<tool_call>{...}</tool_call>`` blocks. In that
        case a shallow copy is returned with ``tool_calls`` populated and
        ``content`` set to ``None``.
    """
    content = str(msg.get("content", "") or "")
    stripped = content.strip()
    if not stripped:
        return msg

    allowed = {name for name in (allowed_tool_names or set()) if name}
    open_tag = re.compile(r"\s*<tool_call>\s*")
    close_tag = re.compile(r"\s*</tool_call>\s*")
    decoder = json.JSONDecoder()

    # Safety: only upgrade the response when it consists solely of one or
    # more <tool_call> blocks. The text is consumed strictly left to right,
    # so prose before, between, or after blocks leaves it as plain text.
    # A single lazy regex cannot guarantee that, because its wildcard can
    # stretch across block boundaries.
    tool_calls: List[Dict[str, Any]] = []
    pos = 0
    while pos < len(stripped):
        opened = open_tag.match(stripped, pos)
        if opened is None:
            return msg
        start = opened.end()
        if not stripped.startswith("{", start):
            return msg

        # raw_decode finds the true end of the JSON object, so a "}" or a
        # literal "</tool_call>" inside a string argument cannot cut the
        # payload short.
        try:
            obj, end = decoder.raw_decode(stripped, start)
        except json.JSONDecodeError as exc:
            log.warning(
                "Rejected local <tool_call> block: %s (%s)",
                stripped[start:start + 200],
                exc,
            )
            return msg

        closed = close_tag.match(stripped, end)
        if closed is None:
            return msg

        name = str(obj.get("name", "")).strip()
        args = obj.get("arguments", {})
        if not name:
            problem = "tool_call missing function name"
        elif allowed and name not in allowed:
            problem = f"unknown tool '{name}'"
        elif not isinstance(args, dict):
            problem = "tool_call arguments must be an object"
        else:
            problem = ""
        if problem:
            # One bad block rejects the whole message; never run a subset.
            log.warning(
                "Rejected local <tool_call> block: %s (%s)",
                stripped[start:end][:200],
                problem,
            )
            return msg

        tool_calls.append({
            "id": f"call_local_{len(tool_calls)}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": json.dumps(args),
            },
        })
        pos = closed.end()

    # Copy before mutating so the caller's dict is left untouched.
    msg = dict(msg)
    msg["tool_calls"] = tool_calls
    msg["content"] = None
    log.info("Parsed %d local tool call(s) from text output", len(tool_calls))
    return msg
360+
279361
@staticmethod
280362
def _truncate_messages_for_context(
281363
messages: List[Dict[str, Any]], ctx_len: int, max_tokens: int,

ouroboros/local_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,19 @@ def test_tool_calling(self) -> Dict[str, Any]:
427427
max_tokens=256,
428428
)
429429
msg = resp.choices[0].message if resp.choices else None
430-
if msg and msg.tool_calls:
430+
tool_calls = list(getattr(msg, "tool_calls", None) or []) if msg else []
431+
if msg and not tool_calls and getattr(msg, "content", None):
432+
from ouroboros.llm import LLMClient
433+
434+
parsed = LLMClient._parse_tool_calls_from_content(
435+
{
436+
"content": msg.content,
437+
"tool_calls": [],
438+
},
439+
{"get_time"},
440+
)
441+
tool_calls = parsed.get("tool_calls") or []
442+
if tool_calls:
431443
result["tool_call_ok"] = True
432444
else:
433445
result["details"] = "Model returned text instead of tool_call"

tests/test_llm_client_refresh.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
import sys
3+
import types
4+
import unittest
5+
from unittest.mock import patch
6+
7+
8+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
9+
10+
11+
class _FakeOpenAI:
    """Stand-in for ``openai.OpenAI`` that records how it was constructed."""

    # Every instance built so far, in construction order. Kept on the class
    # so a test can inspect instances it did not create itself.
    created = []

    def __init__(self, **kwargs):
        # Keep the constructor keywords so tests can assert on them.
        self.kwargs = kwargs
        self.__class__.created.append(self)
17+
18+
19+
class TestLlmClientRefresh(unittest.TestCase):
    """Checks when ``LLMClient._get_client`` rebuilds its underlying client."""

    def setUp(self):
        # The fake keeps its instances on the class; start each test clean.
        _FakeOpenAI.created.clear()

    def _clients_before_and_after_env_change(self, **client_kwargs):
        """Fetch the client under a blank env key, then again under a new one."""
        from ouroboros.llm import LLMClient

        openai_stub = types.SimpleNamespace(OpenAI=_FakeOpenAI)
        blank_env = {"OPENROUTER_API_KEY": ""}
        fresh_env = {"OPENROUTER_API_KEY": "sk-or-new-key"}
        with patch.dict(sys.modules, {"openai": openai_stub}):
            with patch.dict(os.environ, blank_env, clear=False):
                llm = LLMClient(**client_kwargs)
                before = llm._get_client()

            with patch.dict(os.environ, fresh_env, clear=False):
                after = llm._get_client()
        return before, after

    def test_runtime_client_refreshes_when_env_key_changes(self):
        before, after = self._clients_before_and_after_env_change()

        self.assertIsNot(before, after)
        built_with = [fake.kwargs["api_key"] for fake in _FakeOpenAI.created]
        self.assertEqual(built_with, ["", "sk-or-new-key"])

    def test_explicit_api_key_does_not_track_env_changes(self):
        before, after = self._clients_before_and_after_env_change(
            api_key="explicit-key",
        )

        self.assertIs(before, after)
        built_with = [fake.kwargs["api_key"] for fake in _FakeOpenAI.created]
        self.assertEqual(built_with, ["explicit-key"])

tests/test_local_tool_parsing.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import json
2+
import os
3+
import sys
4+
import unittest
5+
6+
7+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8+
9+
10+
class TestLocalToolCallParsing(unittest.TestCase):
    """Exercises ``LLMClient._parse_tool_calls_from_content`` on raw text."""

    @staticmethod
    def _run_parser(text, tool_names):
        """Wrap ``text`` in a message dict and return (message, parser result)."""
        from ouroboros.llm import LLMClient

        message = {"content": text, "tool_calls": []}
        outcome = LLMClient._parse_tool_calls_from_content(message, tool_names)
        return message, outcome

    def _assert_left_untouched(self, text):
        """Assert the parser hands the message back unchanged for ``text``."""
        message, outcome = self._run_parser(text, {"repo_read"})
        self.assertEqual(outcome, message)

    def test_parses_pure_tool_call_blocks(self):
        _, outcome = self._run_parser(
            """
<tool_call>
{"name": "repo_read", "arguments": {"path": "README.md"}}
</tool_call>
<tool_call>
{"name": "repo_write", "arguments": {"path": "notes.txt", "content": "hello"}}
</tool_call>
""",
            {"repo_read", "repo_write"},
        )

        calls = outcome["tool_calls"]
        self.assertEqual(len(calls), 2)
        self.assertIsNone(outcome["content"])
        first_function = calls[0]["function"]
        self.assertEqual(first_function["name"], "repo_read")
        self.assertEqual(
            json.loads(first_function["arguments"]),
            {"path": "README.md"},
        )

    def test_rejects_mixed_prose_and_tool_calls(self):
        self._assert_left_untouched(
            """
Sure, I will use the tool now.

<tool_call>
{"name": "repo_read", "arguments": {"path": "README.md"}}
</tool_call>
"""
        )

    def test_rejects_unknown_tool_names(self):
        self._assert_left_untouched(
            """
<tool_call>
{"name": "repo_delete_everything", "arguments": {}}
</tool_call>
"""
        )

    def test_rejects_non_object_arguments(self):
        self._assert_left_untouched(
            """
<tool_call>
{"name": "repo_read", "arguments": "README.md"}
</tool_call>
"""
        )

tests/test_settings_ui_guards.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import pathlib
3+
import sys
4+
import unittest
5+
6+
7+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8+
9+
REPO = pathlib.Path(__file__).resolve().parents[1]
10+
11+
12+
class TestSettingsUiGuards(unittest.TestCase):
    """Source-level guards: each test asserts exact snippets exist in settings.js."""

    def _assert_settings_js_contains(self, *snippets):
        """Read web/modules/settings.js and require every snippet verbatim."""
        js_path = REPO / "web/modules/settings.js"
        text = js_path.read_text(encoding="utf-8")
        for snippet in snippets:
            self.assertIn(snippet, text)

    def test_save_checks_http_status(self):
        self._assert_settings_js_contains(
            "if (!resp.ok) throw new Error(data.error || `HTTP ${resp.status}`);",
        )

    def test_save_does_not_overwrite_masked_secrets(self):
        self._assert_settings_js_contains(
            "if (orKey && !orKey.includes('...')) body.OPENROUTER_API_KEY = orKey;",
            "if (oaiKey && !oaiKey.includes('...')) body.OPENAI_API_KEY = oaiKey;",
            "if (antKey && !antKey.includes('...')) body.ANTHROPIC_API_KEY = antKey;",
            "if (ghToken && !ghToken.includes('...')) body.GITHUB_TOKEN = ghToken;",
        )

    def test_masked_secret_inputs_clear_on_focus(self):
        self._assert_settings_js_contains(
            "if (input.value.includes('...')) input.value = '';",
        )

    def test_models_section_explains_local_switching(self):
        self._assert_settings_js_contains(
            "These fields are cloud model IDs.",
            "through the GGUF server configured above.",
        )

    def test_save_reloads_settings_after_success(self):
        self._assert_settings_js_contains("await loadSettings();")

0 commit comments

Comments
 (0)