1+ from __future__ import annotations
2+
13import logging
4+ import time
25from dataclasses import dataclass , fields , is_dataclass
36from typing import TYPE_CHECKING , Any
47
58from ag_ui .core .events import EventType , RunErrorEvent
6- from ag_ui .encoder import EventEncoder # noqa: TC002
79from copilotkit import LangGraphAGUIAgent
810from langgraph .store .memory import InMemoryStore
911
1820from .threads import ChatThreadService
1921
2022if TYPE_CHECKING :
23+ from collections .abc import AsyncIterator
24+
2125 from ag_ui .core import RunAgentInput
26+ from ag_ui .encoder import EventEncoder
2227
2328 from codebase .base import MergeRequest
29+ from codebase .context import RuntimeCtx
2430
logger = logging.getLogger("daiv.chat")

# GitState fields that survive the ag-ui output-schema filter and reach the
# chat client through STATE_SNAPSHOT events. ``get_schema_keys`` appends these
# names to the ``output`` key list it returns.
STREAMED_STATE_KEYS = ("merge_request",)

# Bump ``last_active_at`` at most this often while the stream is alive: the
# minimum number of seconds (on the ``time.monotonic`` clock) between two
# ``ChatThreadService.heartbeat`` calls issued from the event stream.
HEARTBEAT_INTERVAL_S = 5.0
39+
3240
3341class RuntimeContextLangGraphAGUIAgent (LangGraphAGUIAgent ):
34- """Default LangGraph's typed ``context=`` kwarg to the daiv RuntimeCtx dataclass.
42+ """Inject the daiv RuntimeCtx dataclass into upstream's stream kwargs .
3543
36- Upstream's ``get_stream_kwargs`` only accepts dict-shaped contexts (it merges via
37- ``dict.update``), but our graph declares ``context_schema=RuntimeCtx`` and expects
38- the frozen dataclass itself. We use ``setdefault`` so an upstream-provided context
39- still wins if one ever appears; today nothing populates it.
44+ Upstream's ``get_stream_kwargs`` only accepts dict-shaped contexts, but our graph
45+ declares ``context_schema=RuntimeCtx`` and expects the frozen dataclass itself.
4046 """
4147
def __init__(self, *, runtime_context: RuntimeCtx, **kwargs: Any):
    """Remember the runtime context and delegate the rest to the parent class.

    Args:
        runtime_context: Context object stored on ``self._runtime_context``.
            Keyword-only.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    self._runtime_context = runtime_context
4551
@@ -50,14 +56,17 @@ def get_stream_kwargs(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
5056
def get_schema_keys(self, config: Any) -> dict[str, list[str]]:
    """Build the schema key lists without asking pydantic for a JSON schema.

    Upstream calls ``graph.config_schema().schema()`` which recurses into
    ``context_schema=RuntimeCtx``. RuntimeCtx holds a ``git.Repo`` field that
    pydantic cannot turn into JSON schema, so the call raises
    PydanticInvalidForJsonSchema. The context keys are therefore derived from
    the dataclass directly.

    Args:
        config: Not read by this override.

    Returns:
        A dict with four lists: ``input`` (the constant schema keys),
        ``output`` (the constant keys followed by ``STREAMED_STATE_KEYS``),
        ``config`` (always empty) and ``context`` (the dataclass field names
        of the graph's ``context_schema``, or empty when it is not a
        dataclass).
    """
    ctx_schema = getattr(self.graph, "context_schema", None)
    if not is_dataclass(ctx_schema):
        # Also reached when the graph has no ``context_schema`` at all
        # (the ``getattr`` default is None); surfaced as a warning, not an error.
        logger.warning(
            "chat: context_schema %r is not a dataclass; STATE_SNAPSHOT context keys will be empty", ctx_schema
        )
        context_keys: list[str] = []
    else:
        context_keys = [f.name for f in fields(ctx_schema)]
    constant = list(self.constant_schema_keys)
    # NOTE(review): an earlier comment on this method said ``output`` is the
    # filter applied to every STATE_SNAPSHOT payload, so state meant for the
    # chat UI has to be listed explicitly -- confirm this still holds upstream.
    return {"input": constant, "output": [*constant, *STREAMED_STATE_KEYS], "config": [], "context": context_keys}
6372
@@ -76,8 +85,18 @@ class ChatRunStreamer:
7685 input_data : RunAgentInput
7786 encoder : EventEncoder
7887
79- async def events (self ):
def __post_init__(self) -> None:
    """Refuse to build a streamer whose ids disagree with ``input_data``.

    Raises:
        ValueError: If ``thread_id`` or ``run_id`` differs from the attribute
            of the same name on ``input_data``. ``thread_id`` is checked first.
    """
    # The view passes thread_id/run_id alongside input_data; a future refactor
    # could desync them silently. Pin the invariant here.
    if self.thread_id != self.input_data.thread_id:
        raise ValueError(f"thread_id mismatch: {self.thread_id!r} vs input_data {self.input_data.thread_id!r}")
    if self.run_id != self.input_data.run_id:
        raise ValueError(f"run_id mismatch: {self.run_id!r} vs input_data {self.input_data.run_id!r}")
95+
96+ async def events (self ) -> AsyncIterator [str ]:
8097 last_mr : MergeRequest | None = None
98+ clean_run = False
99+ last_heartbeat = time .monotonic ()
81100 try :
82101 async with (
83102 open_checkpointer () as checkpointer ,
@@ -103,22 +122,35 @@ async def events(self):
103122 if isinstance (snap , dict ) and "merge_request" in snap :
104123 last_mr = snap ["merge_request" ]
105124 yield self .encoder .encode (event )
106- except Exception as exc :
125+
126+ now = time .monotonic ()
127+ if now - last_heartbeat >= HEARTBEAT_INTERVAL_S :
128+ last_heartbeat = now
129+ try :
130+ await ChatThreadService .heartbeat (self .thread_id , self .run_id )
131+ except Exception :
132+ logger .exception ("chat: heartbeat failed for thread_id=%s" , self .thread_id )
133+ clean_run = True
134+ except Exception :
107135 logger .exception ("Chat run failed for thread_id=%s run_id=%s" , self .thread_id , self .run_id )
108136 yield self .encoder .encode (
109- RunErrorEvent (type = EventType .RUN_ERROR , message = f"{ type (exc ).__name__ } : { exc } " , code = "run_failed" )
137+ RunErrorEvent (
138+ type = EventType .RUN_ERROR , message = "Run failed. Check server logs for details." , code = "run_failed"
139+ )
110140 )
111141 finally :
112142 # Both cleanup steps are wrapped: a post-stream DB hiccup must not
113143 # retroactively paint a clean run as RUN_ERROR, and a release_run
114144 # failure must not leave the per-thread slot permanently claimed.
115- # Durable copy of the source_branch so reloads land on it; the pill
116- # itself updates client-side from the same STATE_SNAPSHOT stream.
117- try :
118- await ChatThreadService .persist_ref (self .thread_id , self .ref , last_mr )
119- except Exception :
120- logger .exception ("chat: failed to persist thread ref for thread_id=%s" , self .thread_id )
145+ # ref is only persisted on a clean finish — a partial run could have
146+ # checked out a branch without committing, and pinning it would
147+ # silently retarget reloads at half-built state.
148+ if clean_run :
149+ try :
150+ await ChatThreadService .persist_ref (self .thread_id , self .ref , last_mr )
151+ except Exception :
152+ logger .exception ("chat: failed to persist thread ref for thread_id=%s" , self .thread_id )
121153 try :
122- await ChatThreadService .release_run (self .thread_id )
154+ await ChatThreadService .release_run (self .thread_id , self . run_id )
123155 except Exception :
124156 logger .exception ("chat: failed to release run slot for thread_id=%s" , self .thread_id )
0 commit comments