Skip to content

Commit 4e53356

Browse files
committed
Refactor intent classification and update tests
- Removed the lightweight regex-based `classify` function from `agent_config.py` and replaced it with LLM-based classification using `SapAgentFactory.classify_intent()`.
- Updated `AgentConfig` and related documentation to reflect the new classification method.
- Removed associated tests for the old classification method from `agent_config_test.py`.
- Added new tests for LLM-based intent classification in `agent_test.py`, ensuring correct mapping of intents based on LLM responses.
- Refactored history provider tests in `history_provider_test.py` to utilize a new message construction method for consistency and clarity.
- Ensured all tests pass with the new classification approach and updated assertions where necessary.
1 parent a228e6f commit 4e53356

6 files changed

Lines changed: 718 additions & 388 deletions

File tree

src/agents/ag_ui.py

Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
77
Uses a ``SapWorkflow`` subclass of ``AgentFrameworkWorkflow`` that
88
handles conversation persistence at the workflow boundary, so
9-
individual agent sessions inside ``SequentialBuilder`` do not need
10-
to know about the AG-UI ``thread_id``.
9+
agent sessions do not need to know about the AG-UI ``thread_id``.
10+
11+
Architecture:
12+
- **TRIAGE/TEST**: HandoffBuilder with Coordinator → Investigator / TestRunner.
13+
Specialist text is emitted as ThinkingText events; Coordinator's final
14+
response becomes the user-visible answer.
15+
- **GENERAL/KNOWLEDGE**: Single agent with all tools. All text is user-visible.
16+
Tool calls stream naturally between reasoning segments.
1117
"""
1218

1319
from __future__ import annotations
@@ -38,25 +44,32 @@
3844
AgentFrameworkWorkflow,
3945
add_agent_framework_fastapi_endpoint,
4046
)
47+
from agent_framework_ag_ui._workflow_run import run_workflow_stream
4148

4249
from agent_framework import Message as AFMessage
4350
from agent_framework._types import Content
4451

4552
from src.agents.agent import SapAgentFactory
46-
from src.agents.agent_config import TRIAGE_CONFIG
53+
from src.agents.agent_config import config_for_intent
4754
from src.core.models.conversation import Conversation
4855
from src.core.storage.conversation_store import ConversationStore
4956

5057
logger = logging.getLogger(__name__)
5158

5259

5360
class SapWorkflow(AgentFrameworkWorkflow):
54-
"""Workflow that persists messages at the workflow boundary.
61+
"""Workflow that classifies intent per request and persists messages.
5562
56-
Wraps ``AgentFrameworkWorkflow`` and intercepts ``run()`` to:
57-
1. Auto-create the conversation if it does not exist.
58-
2. Save the user message and final assistant response.
59-
3. Fire-and-forget title generation on first turn.
63+
Overrides ``run()`` to bypass ``workflow_factory`` entirely.
64+
Instead, each request:
65+
66+
1. Extracts the user message from the AG-UI input.
67+
2. Classifies intent via ``SapAgentFactory.classify_intent()`` (LLM-based).
68+
3. Builds a fresh ``Workflow`` via ``SapAgentFactory.create_workflow()``
69+
with the correct ``AgentConfig``, ``user_query``, and ``thread_id``.
70+
4. Calls ``run_workflow_stream()`` directly to convert workflow
71+
events into AG-UI events.
72+
5. Persists user + assistant messages to SQLite.
6073
6174
:param factory: Agent factory with MCP connections.
6275
:param conversation_store: SQLite conversation store.
@@ -68,29 +81,23 @@ def __init__(
6881
conversation_store: ConversationStore | None,
6982
**kwargs: Any,
7083
) -> None:
71-
super().__init__(
72-
workflow_factory=lambda thread_id: factory.create_workflow(
73-
config=TRIAGE_CONFIG,
74-
thread_id=thread_id,
75-
),
76-
**kwargs,
77-
)
84+
# No workflow or workflow_factory — we create workflows in run().
85+
super().__init__(**kwargs)
7886
self._factory = factory
7987
self._store = conversation_store
8088

81-
_THINKING_STEPS = frozenset({"Planner"})
89+
_THINKING_AGENTS = frozenset({"Investigator", "TestRunner"})
8290

8391
async def run(
8492
self,
8593
input_data: dict[str, Any],
8694
) -> AsyncGenerator[BaseEvent]:
87-
"""Run the workflow, convert intermediate text to thinking, persist.
95+
"""Run the workflow, stream events, and persist messages.
8896
89-
Planner/Executor text is re-emitted as ``ThinkingTextMessage*``
90-
events so the UI can show it as small ephemeral text (like
91-
VS Code Copilot's reasoning display). Tool call events pass
92-
through unchanged for progress visibility. Only the Analyst's
93-
text becomes the permanent assistant response.
97+
Creates a fresh workflow per request with dynamic intent
98+
classification. Specialist text (Investigator/TestRunner)
99+
is emitted as ``ThinkingTextMessage*`` events; Coordinator
100+
and single-agent text is user-visible.
94101
95102
:param input_data: AG-UI input dict with ``thread_id`` and
96103
``messages``.
@@ -100,33 +107,61 @@ async def run(
100107
run_id = input_data.get("run_id", str(uuid4()))
101108
user_text = self._extract_user_text(input_data)
102109

110+
# Classify intent from the actual user message via LLM.
111+
intent = await self._factory.classify_intent(user_text)
112+
config = config_for_intent(intent)
113+
103114
logger.info(
104-
"AG-UI run: thread_id=%r, run_id=%s, user_text=%s, "
105-
"msg_count=%d, keys=%s",
115+
"AG-UI run: thread_id=%r, run_id=%s, intent=%s, "
116+
"user_text=%s, msg_count=%d",
106117
thread_id,
107118
run_id[:12] if run_id else "(none)",
119+
intent.value,
108120
bool(user_text),
109121
len(input_data.get("messages", [])),
110-
list(input_data.keys()),
111122
)
112123

113124
if self._store and thread_id:
114125
self._ensure_conversation(thread_id)
115126

127+
# Build a fresh workflow with the classified config.
128+
workflow = self._factory.create_workflow(
129+
config=config,
130+
user_query=user_text,
131+
thread_id=thread_id,
132+
)
133+
116134
ordered_parts: list[dict[str, Any]] = []
117135
pending_text: list[str] = []
118-
current_step: str = ""
136+
current_agent: str = ""
119137
thinking_msg_ids: set[str] = set()
120-
thinking_step_open: bool = False
138+
thinking_open: bool = False
121139
open_tool_call_ids: list[str] = []
122140
tool_call_names: dict[str, str] = {}
123141
tool_call_args: dict[str, list[str]] = {}
124142
completed_tools: list[dict[str, str]] = []
125143

126-
async for event in super().run(input_data):
127-
if open_tool_call_ids and not isinstance(event, (ToolCallArgsEvent, ToolCallEndEvent)):
144+
async for event in run_workflow_stream(input_data, workflow):
145+
# ── Skip handoff tool calls (internal routing) ──
146+
if isinstance(event, ToolCallStartEvent):
147+
name = event.tool_call_name or ""
148+
if name.startswith("handoff_to_"):
149+
tool_call_names[event.tool_call_id] = name
150+
continue
151+
if isinstance(
152+
event, (ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent),
153+
):
154+
tc_id = event.tool_call_id
155+
if tool_call_names.get(tc_id, "").startswith("handoff_to_"):
156+
continue
157+
# ── Flush orphan tool calls when a non-tool event arrives ──
158+
if open_tool_call_ids and not isinstance(
159+
event, (ToolCallArgsEvent, ToolCallEndEvent),
160+
):
128161
if pending_text:
129-
ordered_parts.append({"type": "text", "text": "".join(pending_text)})
162+
ordered_parts.append(
163+
{"type": "text", "text": "".join(pending_text)},
164+
)
130165
pending_text.clear()
131166
for tc_id in open_tool_call_ids:
132167
ordered_parts.append({"type": "tool_ref", "id": tc_id})
@@ -145,9 +180,11 @@ async def run(
145180
)
146181
open_tool_call_ids.clear()
147182

183+
# ── Tool call lifecycle ──
148184
if isinstance(event, ToolCallStartEvent):
185+
name = event.tool_call_name or "tool"
149186
open_tool_call_ids.append(event.tool_call_id)
150-
tool_call_names[event.tool_call_id] = event.tool_call_name or "tool"
187+
tool_call_names[event.tool_call_id] = name
151188
tool_call_args[event.tool_call_id] = []
152189
yield event
153190
continue
@@ -157,7 +194,9 @@ async def run(
157194
if tc_id in open_tool_call_ids:
158195
open_tool_call_ids.remove(tc_id)
159196
if pending_text:
160-
ordered_parts.append({"type": "text", "text": "".join(pending_text)})
197+
ordered_parts.append(
198+
{"type": "text", "text": "".join(pending_text)},
199+
)
161200
pending_text.clear()
162201
ordered_parts.append({"type": "tool_ref", "id": tc_id})
163202
result_text = f"{tool_call_names.get(tc_id, 'tool')} completed"
@@ -190,32 +229,35 @@ async def run(
190229
yield event
191230
continue
192231

232+
# ── Step tracking (maps to agent names) ──
193233
if isinstance(event, StepStartedEvent):
194-
current_step = event.step_name or ""
195-
logger.info("Step started: %r", current_step)
234+
current_agent = event.step_name or ""
235+
logger.info("Agent started: %r", current_agent)
196236
yield event
197237
continue
198238
if isinstance(event, StepFinishedEvent):
199-
logger.info("Step finished: %r", current_step)
200-
current_step = ""
239+
logger.info("Agent finished: %r", current_agent)
240+
current_agent = ""
201241
yield event
202242
continue
203243

244+
# ── Text message end (close thinking if needed) ──
204245
if isinstance(event, TextMessageEndEvent):
205246
if event.message_id in thinking_msg_ids:
206247
thinking_msg_ids.discard(event.message_id)
207248
yield ThinkingTextMessageEndEvent()
208249
if not thinking_msg_ids:
209250
yield ThinkingEndEvent()
210-
thinking_step_open = False
251+
thinking_open = False
211252
continue
212253

213-
if current_step in self._THINKING_STEPS:
254+
# ── Specialist text → thinking bubbles ──
255+
if current_agent in self._THINKING_AGENTS:
214256
if isinstance(event, TextMessageStartEvent):
215257
thinking_msg_ids.add(event.message_id)
216-
if not thinking_step_open:
258+
if not thinking_open:
217259
yield ThinkingStartEvent()
218-
thinking_step_open = True
260+
thinking_open = True
219261
yield ThinkingTextMessageStartEvent()
220262
continue
221263
if isinstance(event, TextMessageContentEvent):
@@ -226,13 +268,21 @@ async def run(
226268
)
227269
continue
228270

271+
# ── Default: pass through (user-visible text) ──
229272
if isinstance(event, TextMessageContentEvent):
230273
pending_text.append(event.delta)
231274

232275
yield event
233276

277+
# ── Flush remaining state ──
278+
if thinking_open:
279+
yield ThinkingEndEvent()
280+
thinking_open = False
281+
234282
if pending_text:
235-
ordered_parts.append({"type": "text", "text": "".join(pending_text)})
283+
ordered_parts.append(
284+
{"type": "text", "text": "".join(pending_text)},
285+
)
236286
pending_text.clear()
237287
for tc_id in open_tool_call_ids:
238288
ordered_parts.append({"type": "tool_ref", "id": tc_id})
@@ -251,6 +301,7 @@ async def run(
251301
role="tool",
252302
)
253303

304+
# ── Persist ──
254305
if self._store and thread_id:
255306
if user_text:
256307
self._save_user_message(thread_id, user_text)

0 commit comments

Comments
 (0)