Skip to content

Commit 9de9e2c

Browse files
committed
Review improvements.
1 parent 62f5ef8 commit 9de9e2c

32 files changed

Lines changed: 667 additions & 168 deletions

daiv/activity/migrations/0009_activity_thread_id.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,12 @@ class Migration(migrations.Migration):
1919
unique=True,
2020
verbose_name="thread ID",
2121
),
22-
)
22+
),
23+
migrations.AddConstraint(
24+
model_name="activity",
25+
constraint=models.CheckConstraint(
26+
condition=models.Q(("thread_id__isnull", True)) | models.Q(("thread_id", ""), _negated=True),
27+
name="activity_thread_id_nonempty",
28+
),
29+
),
2330
]

daiv/activity/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ class Meta:
185185
condition=models.Q(external_username__gt=""),
186186
),
187187
]
188+
constraints = [
189+
# ``thread_id`` is unique=True; "" would collide on the second insert
190+
# under Postgres (which treats NULLs as distinct from one another but "" as a real
191+
# value). Forbid the empty-string sentinel so callers must use NULL.
192+
models.CheckConstraint(
193+
condition=models.Q(thread_id__isnull=True) | ~models.Q(thread_id=""), name="activity_thread_id_nonempty"
194+
)
195+
]
188196

189197
def __str__(self) -> str:
190198
return f"{self.get_trigger_type_display()} on {self.repo_id} ({self.status})"

daiv/automation/agent/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def get_model_kwargs(
220220
else:
221221
# `enabled: true` is the universal switch on OpenRouter; some providers
222222
# (notably z.ai's GLM family) ignore `effort` and require the explicit flag.
223-
_kwargs["extra_body"] = {"reasoning": {"enabled": True, "effort": thinking_level.value}}
223+
_kwargs["extra_body"] = {"reasoning": {"enabled": True, "effort": thinking_level}}
224224

225225
elif _kwargs["model"].startswith("anthropic") and "max_tokens" not in _kwargs:
226226
# Avoid rate limiting by setting a fair max_tokens value

daiv/automation/agent/middlewares/git.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import logging
44
from typing import TYPE_CHECKING, Annotated, Any, cast
55

6+
import httpx
67
from asgiref.sync import sync_to_async
8+
from github import GithubException
9+
from gitlab.exceptions import GitlabError
710
from langchain.agents import AgentState
811
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
912
from langchain.agents.middleware.types import PrivateStateAttr
@@ -16,6 +19,11 @@
1619
from codebase.context import RuntimeCtx # noqa: TC001
1720
from codebase.utils import GitManager, get_repo_ref
1821

22+
# Platform / transport errors that warrant a soft "no MR" fallback. Bugs
23+
# (KeyError, AttributeError, etc.) propagate so the run fails loudly rather
24+
# than producing a duplicate MR downstream.
25+
_MR_LOOKUP_PLATFORM_ERRORS: tuple[type[BaseException], ...] = (GitlabError, GithubException, httpx.HTTPError)
26+
1927
if TYPE_CHECKING:
2028
from collections.abc import Awaitable, Callable
2129

@@ -134,15 +142,20 @@ async def abefore_agent(self, state: GitState, runtime: Runtime[RuntimeCtx]) ->
134142
try:
135143
git_manager.checkout(merge_request.source_branch)
136144
except ValueError as e:
137-
# The branch does not exist in the repository, so we need to create it.
145+
# Branch from the MR no longer exists locally; treat as no MR
146+
# and let the publisher decide whether to recreate it.
138147
logger.warning("[%s] Failed to checkout to branch '%s': %s", self.name, merge_request.source_branch, e)
139148
merge_request = None
140149

141150
return {"merge_request": merge_request, "code_changes": False}
142151

143152
@staticmethod
144153
async def _alookup_open_mr(context: RuntimeCtx) -> MergeRequest | None:
145-
"""Best-effort lookup of an open MR whose source branch matches the current ref."""
154+
"""Best-effort lookup of an open MR whose source branch matches the current ref.
155+
156+
Soft-fails on platform/transport errors so the agent can still run — the
157+
publisher will create a fresh MR if needed. Programming bugs propagate.
158+
"""
146159
current_branch = get_repo_ref(context.gitrepo)
147160
if not current_branch or current_branch == context.config.default_branch:
148161
return None
@@ -151,7 +164,7 @@ async def _alookup_open_mr(context: RuntimeCtx) -> MergeRequest | None:
151164
return await sync_to_async(client.get_merge_request_by_branches)(
152165
context.repository.slug, current_branch, context.config.default_branch
153166
)
154-
except Exception:
167+
except _MR_LOOKUP_PLATFORM_ERRORS:
155168
logger.exception(
156169
"Failed to look up open merge request for %s on %s", context.repository.slug, current_branch
157170
)

daiv/automation/agent/middlewares/sandbox.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,8 @@ async def abefore_agent(self, state: StateT, runtime: Runtime[RuntimeCtx]) -> di
360360
dict[str, str] | None: The state updates with the sandbox session ID.
361361
"""
362362
if not self.close_session and "session_id" in state:
363-
# If the session is not being closed, don't start a new one, reuse the existing one.
364-
# Also, avoid reusing the session_id if it is already set from a previous run that failed to close
365-
# the session.
363+
# Subagent path: the parent already started a session and owns its
364+
# lifecycle. Skip starting a duplicate.
366365
return None
367366

368367
session_id = await DAIVSandboxClient().start_session(

daiv/automation/agent/middlewares/web_fetch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ async def web_fetch_tool(
208208

209209
prompt = prompt or ""
210210

211-
# Cache the final response for a given (url, prompt, model).
211+
# Cache key is (url, prompt). The summarisation model is intentionally NOT
212+
# part of the key today — if the active model is rotated and you want fresh
213+
# answers, also bump ``_cache_key_for_response``.
212214
if prompt.strip() and (cached := _get_cached_response(url=url, prompt=prompt)) is not None:
213215
return str(cached)
214216

daiv/chat/api/event_filter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import json
24
from typing import TYPE_CHECKING, Any
35

daiv/chat/api/streaming.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import logging
4+
import time
25
from dataclasses import dataclass, fields, is_dataclass
36
from typing import TYPE_CHECKING, Any
47

58
from ag_ui.core.events import EventType, RunErrorEvent
6-
from ag_ui.encoder import EventEncoder # noqa: TC002
79
from copilotkit import LangGraphAGUIAgent
810
from langgraph.store.memory import InMemoryStore
911

@@ -18,28 +20,32 @@
1820
from .threads import ChatThreadService
1921

2022
if TYPE_CHECKING:
23+
from collections.abc import AsyncIterator
24+
2125
from ag_ui.core import RunAgentInput
26+
from ag_ui.encoder import EventEncoder
2227

2328
from codebase.base import MergeRequest
29+
from codebase.context import RuntimeCtx
2430

2531
logger = logging.getLogger("daiv.chat")
2632

2733
# GitState fields that survive the ag-ui output-schema filter and reach the
28-
# chat client through STATE_SNAPSHOT events. ``merge_request`` drives the
29-
# composer MR pill; extend this list when adding new streamable state.
34+
# chat client through STATE_SNAPSHOT events.
3035
STREAMED_STATE_KEYS = ("merge_request",)
3136

37+
# Bump ``last_active_at`` at most this often while the stream is alive.
38+
HEARTBEAT_INTERVAL_S = 5.0
39+
3240

3341
class RuntimeContextLangGraphAGUIAgent(LangGraphAGUIAgent):
34-
"""Default LangGraph's typed ``context=`` kwarg to the daiv RuntimeCtx dataclass.
42+
"""Inject the daiv RuntimeCtx dataclass into upstream's stream kwargs.
3543
36-
Upstream's ``get_stream_kwargs`` only accepts dict-shaped contexts (it merges via
37-
``dict.update``), but our graph declares ``context_schema=RuntimeCtx`` and expects
38-
the frozen dataclass itself. We use ``setdefault`` so an upstream-provided context
39-
still wins if one ever appears; today nothing populates it.
44+
Upstream's ``get_stream_kwargs`` only accepts dict-shaped contexts, but our graph
45+
declares ``context_schema=RuntimeCtx`` and expects the frozen dataclass itself.
4046
"""
4147

42-
def __init__(self, *, runtime_context: Any, **kwargs: Any):
48+
def __init__(self, *, runtime_context: RuntimeCtx, **kwargs: Any):
4349
super().__init__(**kwargs)
4450
self._runtime_context = runtime_context
4551

@@ -50,14 +56,17 @@ def get_stream_kwargs(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
5056

5157
def get_schema_keys(self, config: Any) -> dict[str, list[str]]:
5258
# Upstream calls ``graph.config_schema().schema()`` which recurses into
53-
# ``context_schema=RuntimeCtx``. RuntimeCtx holds a ``git.Repo`` field that pydantic
54-
# cannot turn into JSON schema, so the call raises PydanticInvalidForJsonSchema.
55-
# Derive context keys from the dataclass directly and keep the rest of the shape
56-
# matching upstream's contract. ``output`` is the filter applied to every
57-
# ``STATE_SNAPSHOT`` payload (``filter_object_by_schema_keys``), so fields we
58-
# want streamed to the chat UI must be listed here explicitly.
59+
# ``context_schema=RuntimeCtx``. RuntimeCtx holds a ``git.Repo`` field that
60+
# pydantic cannot turn into JSON schema, so the call raises
61+
# PydanticInvalidForJsonSchema. Derive context keys from the dataclass directly.
5962
ctx_schema = getattr(self.graph, "context_schema", None)
60-
context_keys = [f.name for f in fields(ctx_schema)] if is_dataclass(ctx_schema) else []
63+
if not is_dataclass(ctx_schema):
64+
logger.warning(
65+
"chat: context_schema %r is not a dataclass; STATE_SNAPSHOT context keys will be empty", ctx_schema
66+
)
67+
context_keys: list[str] = []
68+
else:
69+
context_keys = [f.name for f in fields(ctx_schema)]
6170
constant = list(self.constant_schema_keys)
6271
return {"input": constant, "output": [*constant, *STREAMED_STATE_KEYS], "config": [], "context": context_keys}
6372

@@ -76,8 +85,18 @@ class ChatRunStreamer:
7685
input_data: RunAgentInput
7786
encoder: EventEncoder
7887

79-
async def events(self):
88+
def __post_init__(self) -> None:
89+
# The view passes thread_id/run_id alongside input_data; a future refactor
90+
# could desync them silently. Pin the invariant here.
91+
if self.thread_id != self.input_data.thread_id:
92+
raise ValueError(f"thread_id mismatch: {self.thread_id!r} vs input_data {self.input_data.thread_id!r}")
93+
if self.run_id != self.input_data.run_id:
94+
raise ValueError(f"run_id mismatch: {self.run_id!r} vs input_data {self.input_data.run_id!r}")
95+
96+
async def events(self) -> AsyncIterator[str]:
8097
last_mr: MergeRequest | None = None
98+
clean_run = False
99+
last_heartbeat = time.monotonic()
81100
try:
82101
async with (
83102
open_checkpointer() as checkpointer,
@@ -103,22 +122,35 @@ async def events(self):
103122
if isinstance(snap, dict) and "merge_request" in snap:
104123
last_mr = snap["merge_request"]
105124
yield self.encoder.encode(event)
106-
except Exception as exc:
125+
126+
now = time.monotonic()
127+
if now - last_heartbeat >= HEARTBEAT_INTERVAL_S:
128+
last_heartbeat = now
129+
try:
130+
await ChatThreadService.heartbeat(self.thread_id, self.run_id)
131+
except Exception:
132+
logger.exception("chat: heartbeat failed for thread_id=%s", self.thread_id)
133+
clean_run = True
134+
except Exception:
107135
logger.exception("Chat run failed for thread_id=%s run_id=%s", self.thread_id, self.run_id)
108136
yield self.encoder.encode(
109-
RunErrorEvent(type=EventType.RUN_ERROR, message=f"{type(exc).__name__}: {exc}", code="run_failed")
137+
RunErrorEvent(
138+
type=EventType.RUN_ERROR, message="Run failed. Check server logs for details.", code="run_failed"
139+
)
110140
)
111141
finally:
112142
# Both cleanup steps are wrapped: a post-stream DB hiccup must not
113143
# retroactively paint a clean run as RUN_ERROR, and a release_run
114144
# failure must not leave the per-thread slot permanently claimed.
115-
# Durable copy of the source_branch so reloads land on it; the pill
116-
# itself updates client-side from the same STATE_SNAPSHOT stream.
117-
try:
118-
await ChatThreadService.persist_ref(self.thread_id, self.ref, last_mr)
119-
except Exception:
120-
logger.exception("chat: failed to persist thread ref for thread_id=%s", self.thread_id)
145+
# ref is only persisted on a clean finish — a partial run could have
146+
# checked out a branch without committing, and pinning it would
147+
# silently retarget reloads at half-built state.
148+
if clean_run:
149+
try:
150+
await ChatThreadService.persist_ref(self.thread_id, self.ref, last_mr)
151+
except Exception:
152+
logger.exception("chat: failed to persist thread ref for thread_id=%s", self.thread_id)
121153
try:
122-
await ChatThreadService.release_run(self.thread_id)
154+
await ChatThreadService.release_run(self.thread_id, self.run_id)
123155
except Exception:
124156
logger.exception("chat: failed to release run slot for thread_id=%s", self.thread_id)

daiv/chat/api/threads.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from __future__ import annotations
2+
3+
from datetime import timedelta
14
from typing import TYPE_CHECKING
25

6+
from django.db.models import Q
37
from django.utils import timezone
48

59
from chat.models import ChatThread
@@ -11,18 +15,26 @@
1115
from codebase.base import MergeRequest
1216

1317

14-
def _extract_first_user_message(input_data: RunAgentInput) -> str:
15-
return next((c for m in input_data.messages if isinstance(c := getattr(m, "content", ""), str) and c.strip()), "")
18+
# A claim that hasn't bumped last_active_at within this window is considered
19+
# orphaned (worker crashed / OOM-killed before the streamer's finally ran) and
20+
# can be taken over by a fresh claim. Live runs heartbeat well within this
21+
# window via ``ChatThreadService.heartbeat``.
22+
STALE_RUN_MINUTES = 30
1623

1724

18-
class ChatThreadService:
19-
"""Encapsulates ``ChatThread`` row operations needed by the chat API.
25+
def _first_text_in_content(content: object) -> str:
    """Return the first non-blank text carried by a message ``content`` value.

    Handles a plain string as well as list-shaped (multi-part) content.
    NOTE(review): parts are assumed to be strings, dicts with a ``"text"`` key,
    or objects exposing a ``text`` attribute — confirm against the message
    schema in use. Anything else (e.g. binary parts) is ignored.

    Returns:
        The text exactly as found (not stripped), or ``""`` when there is none.
    """
    if isinstance(content, str):
        return content if content.strip() else ""
    if isinstance(content, (list, tuple)):
        for part in content:
            if isinstance(part, str):
                text = part
            elif isinstance(part, dict):
                text = part.get("text")
            else:
                text = getattr(part, "text", None)
            if isinstance(text, str) and text.strip():
                return text
    return ""


def _extract_first_user_message(input_data: RunAgentInput) -> str:
    """Return the first non-empty content from a human/user role message.

    The role is read from ``role`` and, failing that, ``type``; it is matched
    case-insensitively against ``"user"`` and ``"human"``. Messages with a
    missing or non-string role are skipped instead of raising. String content
    is returned unchanged; for list-shaped content the first non-blank text
    part is returned.

    Args:
        input_data: Run input whose ``messages`` are scanned in order.

    Returns:
        The first non-blank user text, or ``""`` when no message qualifies.
    """
    for message in input_data.messages:
        role = getattr(message, "role", None) or getattr(message, "type", None)
        # A non-string role has no ``.lower()``; treat it as "not a user message".
        if not isinstance(role, str) or role.lower() not in ("user", "human"):
            continue
        if text := _first_text_in_content(getattr(message, "content", "")):
            return text
    return ""
2035

21-
The view stays out of the model directly — every read/write goes through
22-
this service so the per-thread run-slot protocol (``aget_or_create`` →
23-
conditional ``UPDATE`` claim → ``UPDATE`` release) lives in one place.
24-
"""
2536

37+
class ChatThreadService:
2638
@staticmethod
2739
async def get_or_create_for_user(
2840
*, user: User, thread_id: str, repo_id: str, ref: str, input_data: RunAgentInput
@@ -43,24 +55,51 @@ async def get_or_create_for_user(
4355

4456
@staticmethod
4557
async def try_claim_run(thread_id: str, run_id: str) -> bool:
46-
"""Atomic claim: only succeeds if the slot is currently free. Avoids TOCTOU
47-
between a "is it free?" read and a "claim it" write when two tabs fire
48-
simultaneously.
58+
"""Atomic claim: succeeds if the slot is free OR its heartbeat is stale.
59+
60+
Why: a worker crash (OOM, SIGKILL, ASGI transport error before the streaming
61+
body iterates) skips the streamer's ``finally`` so ``release_run`` never fires.
62+
Without the stale-takeover branch the thread would be unrecoverable forever.
4963
"""
50-
claimed = await ChatThread.objects.filter(thread_id=thread_id, active_run_id="").aupdate(
64+
stale_cutoff = timezone.now() - timedelta(minutes=STALE_RUN_MINUTES)
65+
free_or_stale = Q(active_run_id__isnull=True) | Q(last_active_at__lt=stale_cutoff)
66+
claimed = await ChatThread.objects.filter(Q(thread_id=thread_id) & free_or_stale).aupdate(
5167
active_run_id=run_id, last_active_at=timezone.now()
5268
)
5369
return bool(claimed)
5470

5571
@staticmethod
56-
async def release_run(thread_id: str) -> None:
57-
await ChatThread.objects.filter(thread_id=thread_id).aupdate(active_run_id="", last_active_at=timezone.now())
72+
async def heartbeat(thread_id: str, run_id: str) -> None:
73+
"""Bump ``last_active_at`` while the slot is still ours.
74+
75+
Filtered on ``active_run_id=run_id`` so a delayed heartbeat from a previous
76+
run cannot keep a stale slot alive after another run took it over.
77+
"""
78+
await ChatThread.objects.filter(thread_id=thread_id, active_run_id=run_id).aupdate(
79+
last_active_at=timezone.now()
80+
)
5881

5982
@staticmethod
60-
async def persist_ref(thread_id: str, original_ref: str, mr: MergeRequest | None) -> None:
61-
"""Sync ``ChatThread.ref`` with the agent's final ``merge_request`` (captured
62-
from the live STATE_SNAPSHOT stream — no second checkpoint read needed).
83+
async def release_run(thread_id: str, run_id: str) -> None:
84+
"""Clear the slot only if we still hold it.
85+
86+
The ``active_run_id=run_id`` guard prevents a delayed cleanup from stomping
87+
a freshly-claimed slot taken over via the stale path.
88+
"""
89+
await ChatThread.objects.filter(thread_id=thread_id, active_run_id=run_id).aupdate(
90+
active_run_id=None, last_active_at=timezone.now()
91+
)
92+
93+
@staticmethod
94+
async def persist_ref(thread_id: str, original_ref: str, mr: MergeRequest | dict | None) -> None:
95+
"""Sync ``ChatThread.ref`` with the agent's final ``merge_request``.
96+
97+
Accepts both a live ``MergeRequest`` instance and a dict (the snapshot
98+
gets rehydrated through the checkpointer as a plain dict, so resumed
99+
runs land here in dict shape).
63100
"""
64-
new_ref = mr.source_branch if mr else None
101+
if mr is None:
102+
return
103+
new_ref = mr.get("source_branch") if isinstance(mr, dict) else getattr(mr, "source_branch", None)
65104
if new_ref and new_ref != original_ref:
66105
await ChatThread.objects.filter(thread_id=thread_id).aupdate(ref=new_ref)

0 commit comments

Comments
 (0)