Skip to content

Commit 81c34dd

Browse files
committed
refactor agent
Signed-off-by: yaacov <yzamir@redhat.com>
1 parent 3b3c636 commit 81c34dd

19 files changed

Lines changed: 620 additions & 240 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ mcp.json
2727

2828
# Cache
2929
.cache/
30+
31+
# Debug dumps
32+
dumps/

mtv_agent/server/agent.py

Lines changed: 83 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
import json
66
import logging
7-
from collections.abc import AsyncGenerator, Callable, Awaitable
7+
from collections.abc import AsyncGenerator
88
from typing import Any
99

1010
from mtv_agent.server.llm.client import LLMClient
1111
from mtv_agent.server.mcp.manager import MCPManager
12+
from mtv_agent.server.tools import ApproveFunc, execute_tool_call, trim_history
1213

1314
logger = logging.getLogger(__name__)
1415

15-
ApproveFunc = Callable[[str, dict], Awaitable[tuple[bool, str | None]]]
16-
1716

1817
async def run_stream(
1918
message: str,
@@ -24,34 +23,62 @@ async def run_stream(
2423
history: list[dict] | None = None,
2524
namespace: str | None = None,
2625
command: str | None = None,
26+
session_id: str | None = None,
2727
max_iterations: int = 20,
28+
max_history_chars: int = 80_000,
2829
) -> AsyncGenerator[dict[str, Any], None]:
2930
"""Run the agent loop, yielding SSE-ready event dicts.
3031
3132
Iterates until the LLM produces a text response or hits *max_iterations*.
33+
34+
**Initial message setup** (built by ``_build_messages``):
35+
36+
1. System prompt -- always first, sets the agent persona and instructions.
37+
2. History -- previous user/assistant turns from the chat session, trimmed
38+
from the oldest end to stay within *max_history_chars* so we don't
39+
exceed the LLM context window.
40+
3. User message -- the current request. When the user invokes a
41+
slash-command (e.g. ``/check-cluster-health``), its body replaces the
42+
plain user message, with the original input appended for context.
43+
44+
**Iteration loop** (each pass through the ``for`` loop):
45+
46+
- Send the messages + tool definitions to the LLM.
47+
- If the LLM responds with plain text (no tool calls), yield it and stop.
48+
- If the LLM requests tool calls, append its response (via ``model_dump()``)
49+
to the messages list, then execute each tool and append the result.
50+
The OpenAI API requires this pairing: an assistant message with
51+
``tool_calls`` followed by a ``tool`` message for each call.
52+
- The loop then repeats, giving the LLM the tool results so it can
53+
decide to call more tools or produce a final text answer.
54+
55+
Example messages list after one tool-call iteration::
56+
57+
[
58+
{"role": "system", "content": "<system prompt>"},
59+
{"role": "user", "content": "<history msg 1>"},
60+
{"role": "assistant", "content": "<history msg 2>"},
61+
{"role": "user", "content": "<user message or command + user message>"},
62+
{"role": "assistant", "tool_calls": [{"id": "...", ...}]},
63+
{"role": "tool", "tool_call_id": "...", "content": "<result>"},
64+
]
3265
"""
3366
tools = mcp.get_tool_definitions()
3467

35-
tools_with_flags = {
36-
td["function"]["name"]
37-
for td in tools
38-
if "flags" in td.get("function", {}).get("parameters", {}).get("properties", {})
39-
}
68+
messages = _build_messages(
69+
system_prompt, command, history, message, max_history_chars
70+
)
4071

41-
trimmed_history = _trim_history(history or [])
42-
messages: list[dict] = [
43-
{"role": "system", "content": system_prompt},
44-
]
45-
if command:
46-
messages.append(
47-
{"role": "system", "content": f"Follow this command:\n\n{command}"}
48-
)
49-
messages.extend([*trimmed_history, {"role": "user", "content": message}])
72+
if llm.dumper and session_id:
73+
llm.dumper.set_session(session_id)
5074

5175
for iteration in range(max_iterations):
5276
logger.debug("Agent iteration %d", iteration + 1)
5377
yield {"event": "thinking"}
5478

79+
if llm.dumper:
80+
llm.dumper.next_iteration()
81+
5582
try:
5683
response = await llm.chat(messages, tools or None)
5784
except Exception as exc:
@@ -62,108 +89,59 @@ async def run_stream(
6289
choice = response.choices[0]
6390

6491
if not choice.message.tool_calls:
65-
yield {
66-
"event": "content",
67-
"content": choice.message.content or "",
68-
}
92+
yield {"event": "content", "content": choice.message.content or ""}
6993
return
7094

7195
messages.append(choice.message.model_dump())
7296

7397
for tc in choice.message.tool_calls:
7498
name = tc.function.name
75-
try:
76-
args = json.loads(tc.function.arguments)
77-
except json.JSONDecodeError:
78-
args = {}
79-
80-
if namespace and name in tools_with_flags:
81-
flags = args.setdefault("flags", {})
82-
if "namespace" not in flags:
83-
flags["namespace"] = namespace
84-
85-
policy = mcp.check_policy(name, args)
86-
87-
if policy == "reject":
88-
result = "Tool call rejected by policy."
89-
yield {
90-
"event": "tool_call",
91-
"name": name,
92-
"arguments": args,
93-
"pending": False,
94-
}
95-
yield {
96-
"event": "tool_rejected",
97-
"name": name,
98-
"reason": "blocked by policy",
99-
}
100-
elif policy == "ask" and approve_fn:
101-
yield {
102-
"event": "tool_call",
103-
"name": name,
104-
"arguments": args,
105-
"pending": True,
106-
}
107-
approved, reason = await approve_fn(name, args)
108-
if not approved:
109-
result = f"Tool call denied by user. {reason or ''}"
110-
yield {
111-
"event": "tool_rejected",
112-
"name": name,
113-
"reason": reason or "denied",
114-
}
99+
args = _parse_args(tc)
100+
if namespace:
101+
args.setdefault("flags", {}).setdefault("namespace", namespace)
102+
103+
result = ""
104+
async for event in execute_tool_call(name, args, mcp, approve_fn):
105+
if "_result" in event:
106+
result = event["_result"]
115107
else:
116-
try:
117-
result = await mcp.call_tool(name, args)
118-
except Exception as exc:
119-
logger.exception("Tool call %s failed", name)
120-
result = f"Error executing tool: {exc}"
121-
yield {
122-
"event": "tool_result",
123-
"name": name,
124-
"result": _truncate(result),
125-
}
126-
else:
127-
yield {
128-
"event": "tool_call",
129-
"name": name,
130-
"arguments": args,
131-
"pending": False,
132-
}
133-
try:
134-
result = await mcp.call_tool(name, args)
135-
except Exception as exc:
136-
logger.exception("Tool call %s failed", name)
137-
result = f"Error executing tool: {exc}"
138-
yield {
139-
"event": "tool_result",
140-
"name": name,
141-
"result": _truncate(result),
142-
}
108+
yield event
143109

144110
messages.append({"role": "tool", "tool_call_id": tc.id, "content": result})
145111

146112
yield {"event": "error", "message": "Max iterations reached"}
147113

148114

149-
MAX_HISTORY_CHARS = 80_000
115+
# ---------------------------------------------------------------------------
116+
# Helpers
117+
# ---------------------------------------------------------------------------
150118

151119

152-
def _trim_history(history: list[dict]) -> list[dict]:
153-
"""Keep only recent history that fits within a character budget."""
154-
total = 0
155-
result: list[dict] = []
156-
for msg in reversed(history):
157-
size = len(msg.get("content", ""))
158-
if total + size > MAX_HISTORY_CHARS:
159-
break
160-
result.append(msg)
161-
total += size
162-
result.reverse()
163-
return result
120+
def _build_messages(
121+
system_prompt: str,
122+
command: str | None,
123+
history: list[dict] | None,
124+
user_message: str,
125+
max_history_chars: int = 80_000,
126+
) -> list[dict]:
127+
"""Assemble the initial message list for the LLM."""
128+
msgs: list[dict] = [{"role": "system", "content": system_prompt}]
129+
msgs.extend(trim_history(history or [], max_history_chars))
130+
if command:
131+
msgs.append(
132+
{
133+
"role": "user",
134+
"content": f"Follow this command:\n\n{command}\n\nUser message: {user_message}",
135+
}
136+
)
137+
else:
138+
msgs.append({"role": "user", "content": user_message})
139+
return msgs
164140

165141

166-
def _truncate(text: str, limit: int = 80_000) -> str:
167-
if len(text) <= limit:
168-
return text
169-
return text[:limit] + "\n... (truncated)"
142+
def _parse_args(tc: object) -> dict:
143+
"""JSON-parse tool-call arguments with a safe fallback."""
144+
try:
145+
return json.loads(tc.function.arguments)
146+
except json.JSONDecodeError:
147+
return {}

mtv_agent/server/app.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def _write_startup_error(msg: str) -> None:
5959

6060
@asynccontextmanager
6161
async def lifespan(_app: FastAPI):
62-
global llm
62+
global llm, store
63+
64+
store = ChatStore(settings.cache_dir)
6365

6466
model = settings.llm_model
6567
if not model:
@@ -75,10 +77,12 @@ async def lifespan(_app: FastAPI):
7577
_write_startup_error(msg)
7678
raise SystemExit(1) from None
7779

80+
dump_dir = str(Path(settings.dump_dir).expanduser()) if settings.dump_llm else None
7881
llm = LLMClient(
7982
base_url=settings.llm_base_url,
8083
api_key=settings.llm_api_key,
8184
model=model,
85+
dump_dir=dump_dir,
8286
)
8387

8488
try:
@@ -201,7 +205,9 @@ async def event_generator():
201205
history=history,
202206
namespace=namespace,
203207
command=command_body,
208+
session_id=session_id,
204209
max_iterations=settings.max_iterations,
210+
max_history_chars=settings.max_history_chars,
205211
):
206212
if cancel_evt.is_set():
207213
yield {
@@ -324,10 +330,8 @@ def main():
324330
_write_startup_error(str(exc))
325331
raise SystemExit(1) from None
326332

327-
store = ChatStore(settings.cache_dir)
328-
329333
uvicorn.run(
330-
"mtv_agent.server.app:app",
334+
app,
331335
host=settings.host,
332336
port=settings.port,
333337
log_level="info",

mtv_agent/server/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_commands(commands_dir: str) -> dict[str, dict]:
3636
3737
Returns a dict mapping command name to command data.
3838
"""
39-
base = Path(commands_dir)
39+
base = Path(commands_dir).expanduser()
4040
commands: dict[str, dict] = {}
4141
if not base.is_dir():
4242
logger.warning("Commands directory not found: %s", base)

mtv_agent/server/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def bundled_mcp_example() -> Path:
3333
return bundled_data_path("mcp.json.example")
3434

3535

36+
def bundled_policies_example() -> Path:
37+
return bundled_data_path("policies.json.example")
38+
39+
3640
# ---------------------------------------------------------------------------
3741
# Config file discovery
3842
# ---------------------------------------------------------------------------
@@ -147,7 +151,10 @@ class Settings:
147151
commands_dir: str = _BUNDLED_COMMANDS
148152
cache_dir: str = "~/.mtv-agent/cache"
149153
max_iterations: int = 20
154+
max_history_chars: int = 80_000
150155
mcp_config: str | None = None
156+
dump_llm: bool = False
157+
dump_dir: str = "~/.mtv-agent/dumps"
151158

152159

153160
def load_settings(override: str | None = None) -> Settings:
@@ -159,6 +166,7 @@ def load_settings(override: str | None = None) -> Settings:
159166
commands = data.get("commands", {})
160167
cache = data.get("cache", {})
161168
agent = data.get("agent", {})
169+
debug = data.get("debug", {})
162170

163171
return Settings(
164172
llm_base_url=llm.get("baseUrl", Settings.llm_base_url),
@@ -170,6 +178,9 @@ def load_settings(override: str | None = None) -> Settings:
170178
commands_dir=commands.get("dir", Settings.commands_dir),
171179
cache_dir=cache.get("dir", Settings.cache_dir),
172180
max_iterations=agent.get("maxIterations", Settings.max_iterations),
181+
max_history_chars=agent.get("maxHistoryChars", Settings.max_history_chars),
182+
dump_llm=debug.get("dumpLlm", Settings.dump_llm),
183+
dump_dir=debug.get("dumpDir", Settings.dump_dir),
173184
)
174185

175186

mtv_agent/server/data/config.json.example

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@
1818
"dir": "~/.mtv-agent/cache"
1919
},
2020
"agent": {
21-
"maxIterations": 20
21+
"maxIterations": 20,
22+
"maxHistoryChars": 80000
2223
},
2324
"tui": {
2425
"theme": "textual-dark"
26+
},
27+
"debug": {
28+
"dumpLlm": false,
29+
"dumpDir": "~/.mtv-agent/dumps"
2530
}
2631
}

0 commit comments

Comments
 (0)