Skip to content

Commit d9ab514

Browse files
committed
refactor(chat): drive MR pill via STATE_SNAPSHOT, split api/views
Move from a post-run ``daiv:repo_state`` CustomEvent to AG-UI's native STATE_SNAPSHOT stream for the composer MR pill. ``GitMiddleware`` now publishes ``merge_request`` on the public output schema and seeds it with any pre-existing open MR on the current branch, so the pill reflects reality from the first turn — no end-of-run checkpoint probe. Split ``chat/api/views.py`` into ``streaming.py`` (SSE generator + ``RuntimeContextLangGraphAGUIAgent``), ``threads.py`` (run-slot claim/ release + ref persistence), and ``event_filter.py`` (the subagent event reorder/suppress that previously lived inline). Drop the running-task step counter; the filter now handles nested frames server-side, so the client-side workaround is dead weight. Tool-stream polish: - web_search returns a JSON array; the renderer parses it into per-hit cards and a hit-count badge - web_fetch / gitlab / gh prefix failures with ``error:`` and the gitlab CLI truncation appends a sentinel — both feed result-row badges - new body renderers for web_fetch, web_search, gitlab, gh - activity-stream uses class-map objects so Alpine swaps status variants cleanly instead of letting old classes linger
1 parent a36d8c4 commit d9ab514

22 files changed

Lines changed: 1451 additions & 411 deletions

File tree

daiv/activity/static/activity/js/activity-stream.js

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,22 @@
22
* Alpine.js components for real-time activity status updates via SSE.
33
*
44
* activityStream (list page) — tracks multiple activities in place:
5-
* dotClass(id, fallback) → "status-dot-{variant}" CSS class
6-
* statusClass(id, fallback) → "status-badge-{variant}" CSS class
5+
* dotClass(id, fallback) → object toggling status-dot-{variant} classes
6+
* statusClass(id, fallback) → object toggling status-badge-{variant} classes
77
* statusLabel(id, fallback) → human-readable label
88
*
9+
* Object class maps (rather than a single string) are required so Alpine
10+
* removes the previously rendered variant class when the status transitions —
11+
* otherwise the static server-rendered class lingers alongside the new one
12+
* and the later CSS rule wins.
13+
*
914
* activityDetail (detail page) — subscribes to one activity and reloads the
1015
* page on any state change so server-rendered fields (started_at, finished_at,
1116
* elapsed counter, duration, timeline dots) reflect the new state.
1217
*/
1318
document.addEventListener("alpine:init", () => {
19+
const VARIANTS = ["success", "failed", "running", "pending"];
20+
1421
function statusVariantFor(status) {
1522
if (status === "SUCCESSFUL") return "success";
1623
if (status === "FAILED") return "failed";
@@ -25,6 +32,10 @@ document.addEventListener("alpine:init", () => {
2532
return "Pending";
2633
}
2734

35+
function variantClassMap(prefix, active) {
36+
return Object.fromEntries(VARIANTS.map((v) => [prefix + v, v === active]));
37+
}
38+
2839
Alpine.data("activityStream", (streamUrl, inFlightIds) => ({
2940
updates: {},
3041
init() {
@@ -42,10 +53,10 @@ document.addEventListener("alpine:init", () => {
4253
source.onerror = () => source.close();
4354
},
4455
statusClass(id, fallback) {
45-
return "status-badge-" + statusVariantFor(this.updates[id]?.status || fallback);
56+
return variantClassMap("status-badge-", statusVariantFor(this.updates[id]?.status || fallback));
4657
},
4758
dotClass(id, fallback) {
48-
return "status-dot-" + statusVariantFor(this.updates[id]?.status || fallback);
59+
return variantClassMap("status-dot-", statusVariantFor(this.updates[id]?.status || fallback));
4960
},
5061
statusLabel(id, fallback) {
5162
const update = this.updates[id];

daiv/automation/agent/middlewares/git.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from typing import TYPE_CHECKING, Annotated, Any, cast
55

6+
from asgiref.sync import sync_to_async
67
from langchain.agents import AgentState
78
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
89
from langchain.agents.middleware.types import PrivateStateAttr
@@ -11,6 +12,7 @@
1112

1213
from automation.agent.publishers import GitChangePublisher
1314
from codebase.base import MergeRequest, Scope
15+
from codebase.clients import RepoClient
1416
from codebase.context import RuntimeCtx # noqa: TC001
1517
from codebase.utils import GitManager, get_repo_ref
1618

@@ -59,9 +61,11 @@ class GitState(AgentState):
5961
State for the git middleware.
6062
"""
6163

62-
merge_request: Annotated[MergeRequest | None, PrivateStateAttr]
64+
merge_request: MergeRequest | None
6365
"""
64-
The merge request used to commit the changes.
66+
The merge request used to commit the changes. Public on the output schema so
67+
it streams in AG-UI ``STATE_SNAPSHOT`` events — the chat UI's MR pill is wired
68+
directly to this field instead of a custom post-run event.
6569
"""
6670

6771
code_changes: Annotated[bool, PrivateStateAttr]
@@ -116,6 +120,11 @@ async def abefore_agent(self, state: GitState, runtime: Runtime[RuntimeCtx]) ->
116120
# In this case, ignore the branch name and merge request ID from the state,
117121
# and use the source branch and merge request ID from the merge request.
118122
merge_request = runtime.context.merge_request
123+
elif merge_request is None:
124+
# Surface any pre-existing open MR on the current branch so the chat
125+
# composer pill reflects reality from the very first turn. Issue-scope
126+
# runs always start on the default branch, where this lookup short-circuits.
127+
merge_request = await self._alookup_open_mr(runtime.context)
119128

120129
if merge_request and merge_request.source_branch != get_repo_ref(runtime.context.gitrepo):
121130
git_manager = GitManager(runtime.context.gitrepo)
@@ -131,6 +140,23 @@ async def abefore_agent(self, state: GitState, runtime: Runtime[RuntimeCtx]) ->
131140

132141
return {"merge_request": merge_request, "code_changes": False}
133142

143+
@staticmethod
144+
async def _alookup_open_mr(context: RuntimeCtx) -> MergeRequest | None:
145+
"""Best-effort lookup of an open MR whose source branch matches the current ref."""
146+
current_branch = get_repo_ref(context.gitrepo)
147+
if not current_branch or current_branch == context.config.default_branch:
148+
return None
149+
try:
150+
client = RepoClient.create_instance()
151+
return await sync_to_async(client.get_merge_request_by_branches)(
152+
context.repository.slug, current_branch, context.config.default_branch
153+
)
154+
except Exception:
155+
logger.exception(
156+
"Failed to look up open merge request for %s on %s", context.repository.slug, current_branch
157+
)
158+
return None
159+
134160
async def awrap_model_call(
135161
self, request: ModelRequest[RuntimeCtx], handler: Callable[[ModelRequest[RuntimeCtx]], Awaitable[ModelResponse]]
136162
) -> ModelResponse:

daiv/automation/agent/middlewares/git_platform.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,29 @@
3636
DEFAULT_MAX_OUTPUT_LINES = 2_000
3737
DEFAULT_CLI_TIMEOUT = 30
3838

39+
40+
def _truncate_cli_output(output: str, *, keep: Literal["head", "tail"]) -> str:
41+
"""
42+
Cap CLI output to ``DEFAULT_MAX_OUTPUT_LINES``, appending a sentinel when
43+
truncation occurs so the agent and the chat UI both know the slice is
44+
partial. ``keep="tail"`` is for job traces / run logs where the failing
45+
tail is the interesting part.
46+
"""
47+
# Cheap line count avoids materializing splitlines on the (common) happy path.
48+
if output.count("\n") < DEFAULT_MAX_OUTPUT_LINES:
49+
return output
50+
51+
lines = output.splitlines(keepends=True)
52+
if len(lines) <= DEFAULT_MAX_OUTPUT_LINES:
53+
return output
54+
55+
omitted = len(lines) - DEFAULT_MAX_OUTPUT_LINES
56+
sentinel = f"... (truncated, {omitted} lines omitted)\n"
57+
if keep == "tail":
58+
return sentinel + "".join(lines[-DEFAULT_MAX_OUTPUT_LINES:])
59+
return "".join(lines[:DEFAULT_MAX_OUTPUT_LINES]) + sentinel
60+
61+
3962
GITLAB_REQUESTS_TIMEOUT = 15
4063
GITLAB_PER_PAGE = "5"
4164
GITLAB_TOOL_NAME = "gitlab"
@@ -607,9 +630,9 @@ async def gitlab_tool(
607630
if resource == "project-job" and splitted_subcommand[1] == "trace":
608631
# TODO: evict the output to the file system if it's too long
609632
output = clean_job_logs(output, runtime.context.git_platform)
610-
return "".join(output.splitlines(keepends=True)[-DEFAULT_MAX_OUTPUT_LINES:])
633+
return _truncate_cli_output(output, keep="tail")
611634

612-
return "".join(output.splitlines(keepends=True)[:DEFAULT_MAX_OUTPUT_LINES])
635+
return _truncate_cli_output(output, keep="head")
613636

614637

615638
def _get_cached_github_cli_token(runtime: ToolRuntime[RuntimeCtx]) -> tuple[str, dict[str, str | float] | None]:
@@ -753,9 +776,9 @@ async def github_tool(
753776
elif resource == "run" and action == "view" and "--log" in splitted_subcommand:
754777
# TODO: evict the output to the file system if it's too long
755778
output = clean_job_logs(output, runtime.context.git_platform)
756-
output = "".join(output.splitlines(keepends=True)[-DEFAULT_MAX_OUTPUT_LINES:])
779+
output = _truncate_cli_output(output, keep="tail")
757780
else:
758-
output = "".join(output.splitlines(keepends=True)[:DEFAULT_MAX_OUTPUT_LINES])
781+
output = _truncate_cli_output(output, keep="head")
759782

760783
# Return Command with state update if token was cached/refreshed
761784
if state_update:

daiv/automation/agent/middlewares/web_fetch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ async def web_fetch_tool(
204204
"""
205205
url = _upgrade_http_to_https(url.strip())
206206
if not _is_valid_http_url(url):
207-
return "Invalid URL. Provide a fully-formed http(s) URL (e.g., https://example.com)."
207+
return "error: Invalid URL. Provide a fully-formed http(s) URL (e.g., https://example.com)."
208208

209209
prompt = prompt or ""
210210

@@ -218,12 +218,12 @@ async def web_fetch_tool(
218218
# Used for special redirect signaling.
219219
return str(e)
220220
except Exception as e:
221-
return f"Failed to fetch URL: {e}"
221+
return f"error: Failed to fetch URL: {e}"
222222

223223
# Safety guard: avoid silently truncating; ask for a narrower URL/prompt instead.
224224
if len(content) > site_settings.web_fetch_max_content_chars:
225225
return (
226-
"Page content is too large to safely analyze in one pass.\n"
226+
"error: Page content is too large to safely analyze in one pass.\n"
227227
"Provide a more specific URL (e.g. a specific section/anchor) or narrow the prompt."
228228
)
229229

daiv/automation/agent/middlewares/web_search.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3+
import json
34
import logging
4-
import textwrap
55
from typing import TYPE_CHECKING, Annotated
66

77
from django.utils import timezone
@@ -45,13 +45,18 @@
4545
- Access up-to-date information for current events and recent data
4646
- Access information beyond your knowledge cutoff
4747
48+
Result format:
49+
- The tool returns a JSON array of objects with `title`, `link`, and `content` fields.
50+
- An empty array (`[]`) means no relevant results were found — broaden the query and retry, or tell the user no results exist.
51+
- Tavily may prepend a synthesized summary as the first entry with `title="Suggested answer"` and `link=""`. Treat it as a hint, not a citable source.
52+
4853
IMPORTANT - Use the correct year in search queries:
4954
- You MUST use this year when searching for recent information, documentation, or current events.
5055
- Example: If today is {{current_year}}-07-15 and the user asks for "latest React docs", search for "React documentation {{current_year}}", NOT "React documentation {{previous_year}}".
5156
5257
CRITICAL REQUIREMENT - You MUST follow this when using web search:
5358
- After answering the user's question using web search results, you MUST include a "Sources:" section at the end of your response when the answer primarily derives from search results.
54-
- In the Sources section, list all relevant URLs from the search results as markdown hyperlinks: [Title](URL)
59+
- In the Sources section, list each relevant entry's `link` as a markdown hyperlink using its `title`: `[<title>](<link>)`. Skip entries with an empty `link` (the "Suggested answer" hint).
5560
- This is MANDATORY - never skip including sources in your response
5661
- Example format:
5762
@@ -129,19 +134,11 @@ async def web_search_tool(query: Annotated[str, "The search query."]) -> str:
129134
Tool to search the web and use the results to inform responses.
130135
""" # noqa: E501
131136

132-
if not (results := await _get_web_search_results(query)):
133-
return "No relevant results found for the given search query."
134-
135-
return "\n".join([
136-
textwrap.dedent(
137-
"""\
138-
<web_search_result title="{title}" link="{link}">
139-
{body}
140-
</web_search_result>
141-
"""
142-
).format(title=result["title"], link=result["link"], body=result["content"])
143-
for result in results
144-
])
137+
results = await _get_web_search_results(query)
138+
# `ensure_ascii=False` keeps non-ASCII titles/snippets readable for the model
139+
# (and saves tokens vs. \uXXXX escapes). An empty array is a real, valid
140+
# outcome the model is told how to handle in the system prompt.
141+
return json.dumps(results, ensure_ascii=False)
145142

146143

147144
class WebSearchMiddleware(AgentMiddleware):

daiv/chat/api/event_filter.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import json
2+
from typing import TYPE_CHECKING, Any
3+
4+
from ag_ui.core.events import EventType, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent
5+
6+
if TYPE_CHECKING:
7+
from collections.abc import AsyncIterator, Iterable
8+
9+
from ag_ui.core.events import BaseEvent
10+
11+
12+
class SubagentEventFilter:
13+
"""Reorder/suppress AGUI events so subagent frames don't leak into the parent turn.
14+
15+
Two upstream behaviors collide on ``task``-tool turns:
16+
17+
1. ag_ui_langgraph drops the parent's ``task`` TOOL_CALL_START on the
18+
text→tool_call transition chunk (the chunk that ends the parent's text
19+
stream also carries the new tool_call name, but the handler returns
20+
after emitting TEXT_MESSAGE_END). Subsequent chunks only have args, so
21+
OnChatModelStream never reaches ``is_tool_call_start_event``. The
22+
``task`` TOOL_CALL_START finally arrives from the OnToolEnd re-emit —
23+
*after* the subagent has already streamed text/tool calls to the
24+
client.
25+
26+
2. With ``stream_subgraphs=True``, every chunk emitted from inside
27+
``subagent.ainvoke()`` flows through the parent's stream with a
28+
nested ``langgraph_checkpoint_ns`` (``"tools:UUID|model:UUID"``).
29+
Without (1)'s TOOL_CALL_START there is no ``task`` segment to
30+
suppress them against.
31+
32+
This filter:
33+
34+
* captures ``task`` tool_call ids from top-level STATE_SNAPSHOT events,
35+
* synthesizes TOOL_CALL_START + ARGS + END for each on the first nested
36+
event so the chat creates the segment *before* the subagent runs,
37+
* drops every nested event (``|`` in ns),
38+
* drops the LATE OnToolEnd re-emitted START/ARGS/END for tool_calls we
39+
already synthesized (deduping by tool_call_id).
40+
41+
The parent's TOOL_CALL_RESULT for the task tool still flows through
42+
untouched — it's a top-level event with the same ``tool_call_id``, so the
43+
chat UI flips the synthesized segment to ``done`` exactly like a normal
44+
tool call.
45+
"""
46+
47+
# Tool name used by deepagents' SubAgentMiddleware to invoke a subagent.
48+
TASK_TOOL_NAME = "task"
49+
50+
def __init__(self) -> None:
51+
# Two-state lifecycle: a tool_call_id starts in ``_pending`` (synthesize
52+
# on next nested event), then moves to ``_emitted`` (drop the late
53+
# re-emit). Membership in either is enough to dedup a STATE_SNAPSHOT
54+
# rebroadcast; ``_emitted`` alone gates the late TOOL_CALL_*
55+
# re-emit drop.
56+
self._pending: dict[str, tuple[str, Any]] = {}
57+
self._emitted: set[str] = set()
58+
59+
async def apply(self, stream: AsyncIterator[BaseEvent]) -> AsyncIterator[BaseEvent]:
60+
async for event in stream:
61+
ns = self._checkpoint_ns(event)
62+
is_nested = "|" in ns
63+
64+
if not is_nested and event.type == EventType.STATE_SNAPSHOT:
65+
for tcid, name, args in self._iter_latest_task_calls(event):
66+
if tcid not in self._pending and tcid not in self._emitted:
67+
self._pending[tcid] = (name, args)
68+
69+
if is_nested:
70+
for tcid, (name, args) in self._pending.items():
71+
yield ToolCallStartEvent(type=EventType.TOOL_CALL_START, tool_call_id=tcid, tool_call_name=name)
72+
if args:
73+
# ``default=str`` so a Pydantic model / datetime / other
74+
# non-JSON-native object in args doesn't kill the entire
75+
# chat stream — better a stringified field than RUN_ERROR.
76+
delta = args if isinstance(args, str) else json.dumps(args, default=str)
77+
yield ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=tcid, delta=delta)
78+
yield ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=tcid)
79+
self._emitted.add(tcid)
80+
self._pending.clear()
81+
continue
82+
83+
if event.type in (EventType.TOOL_CALL_START, EventType.TOOL_CALL_ARGS, EventType.TOOL_CALL_END):
84+
tcid = getattr(event, "tool_call_id", None)
85+
if isinstance(tcid, str) and tcid in self._emitted:
86+
continue
87+
88+
yield event
89+
90+
@staticmethod
91+
def _checkpoint_ns(event: BaseEvent) -> str:
92+
"""Extract the LangGraph checkpoint namespace from an AGUI event's raw_event.
93+
94+
A ``|`` in the namespace means the event was emitted from a *nested*
95+
LangGraph execution — i.e. from inside a subagent invoked by the parent's
96+
``task`` tool. Top-level events have an empty ns or a single
97+
``"<node>:UUID"`` segment (e.g. ``"model:..."``, ``"tools:..."``) with no
98+
pipe.
99+
"""
100+
raw = getattr(event, "raw_event", None)
101+
if not isinstance(raw, dict):
102+
return ""
103+
md = raw.get("metadata") or {}
104+
return str(md.get("langgraph_checkpoint_ns", "") or "")
105+
106+
@classmethod
107+
def _iter_latest_task_calls(cls, event: BaseEvent) -> Iterable[tuple[str, str, Any]]:
108+
"""Yield ``(tool_call_id, name, args)`` for every ``task`` tool_call on the
109+
snapshot's latest AIMessage. Caller is responsible for dedup against
110+
already-emitted ids — this is just the per-snapshot scan.
111+
"""
112+
snap = getattr(event, "snapshot", None)
113+
if not isinstance(snap, dict):
114+
return
115+
msgs = snap.get("messages")
116+
if not isinstance(msgs, list):
117+
return
118+
# Only the latest AIMessage matters — older AIMessages were emitted on
119+
# earlier snapshots and their task ids are already in ``task_calls``.
120+
# Walking past the latest is just wasted work; the dedup map at the call
121+
# site is what guarantees no double-synthesis.
122+
for m in reversed(msgs):
123+
if cls._msg_role(m) not in ("ai", "assistant"):
124+
continue
125+
for tc in cls._msg_field(m, "tool_calls") or []:
126+
tcid = cls._msg_field(tc, "id")
127+
name = cls._msg_field(tc, "name")
128+
if name == cls.TASK_TOOL_NAME and isinstance(tcid, str):
129+
yield tcid, name, cls._msg_field(tc, "args")
130+
return
131+
132+
@staticmethod
133+
def _msg_field(message: Any, name: str, default: Any = None) -> Any:
134+
"""Read a field from a LangChain message or its dict-encoded form.
135+
136+
STATE_SNAPSHOT can carry either shape depending on whether the snapshot
137+
has been serialized yet — running through the AGUI encoder turns objects
138+
into dicts, but the filter here sits *before* the encoder.
139+
"""
140+
if isinstance(message, dict):
141+
return message.get(name, default)
142+
return getattr(message, name, default)
143+
144+
@classmethod
145+
def _msg_role(cls, message: Any) -> str:
146+
return str(cls._msg_field(message, "type", "") or cls._msg_field(message, "role", "") or "").lower()

0 commit comments

Comments
 (0)