Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions ddapm_test_agent/claude_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def __init__(self, session_id: str, trace_id: str, root_span_id: str, start_ns:
# Task tool_use_ids that have already been claimed by a SubagentStart,
# so they are not matched again when a second SubagentStart fires.
self.claimed_task_tools: Set[str] = set()
# Currently active agents keyed by span_id, for concurrent subagent resolution.
self.active_agents: Dict[str, Dict[str, Any]] = {}
self.conversation_title: str = ""
# Persists across turns so each turn's context_delta reflects growth from
# the previous turn's final context size.
Expand Down Expand Up @@ -367,6 +369,7 @@ def _finalize_interrupted_turn(self, session: SessionState) -> None:
session.pending_tools.clear()
session.deferred_agent_spans.clear()
session.claimed_task_tools.clear()
session.active_agents.clear()

# Finalize the root span
root_span: Optional[Dict[str, Any]] = getattr(session, "_root_span_ref", None)
Expand Down Expand Up @@ -416,6 +419,7 @@ def _handle_user_prompt_submit(self, session_id: str, body: Dict[str, Any]) -> N
session.pending_tools = {}
session.deferred_agent_spans = {}
session.claimed_task_tools = set()
session.active_agents = {}
# Don't reset conversation_title — it persists across turns so
# subsequent interactions on the same topic reuse the title.
# The haiku summarization call will update it when the topic changes.
Expand Down Expand Up @@ -468,7 +472,13 @@ def _handle_pre_tool_use(self, session_id: str, body: Dict[str, Any]) -> None:
session.tools_used.add(tool_name)

span_id = _format_span_id()
parent_id = self._current_parent_id(session)
# Try link tracker first: tool_use_id → LLM span → agent parent.
# This correctly resolves the parent when concurrent subagents are active.
parent_id = None
if self._link_tracker:
parent_id = self._link_tracker.get_parent_for_tool(tool_use_id)
if not parent_id:
parent_id = self._current_parent_id(session)
now_ns = int(time.time() * 1_000_000_000)

session.pending_tools[tool_use_id] = PendingToolSpan(
Expand Down Expand Up @@ -638,7 +648,6 @@ def _handle_subagent_start(self, session_id: str, body: Dict[str, Any]) -> None:
"""
session = self._get_or_create_session(session_id)
span_id = _format_span_id()
parent_id = self._current_parent_id(session)
now_ns = int(time.time() * 1_000_000_000)

agent_name = body.get("agent_type", body.get("agent_name", "subagent"))
Expand All @@ -648,13 +657,20 @@ def _handle_subagent_start(self, session_id: str, body: Dict[str, Any]) -> None:
# when multiple Task tools are pending, each subagent gets its own.
task_tool_use_id = ""
task_tool_input: Any = None
task_pending: Optional[PendingToolSpan] = None
for tid, pending in session.pending_tools.items():
if pending.tool_name == "Task" and tid not in session.claimed_task_tools:
task_tool_use_id = tid
task_tool_input = pending.tool_input
task_pending = pending
session.claimed_task_tools.add(tid)
break

# Use the parent captured at PreToolUse time (before any SubagentStart fired).
# This is correct for concurrent siblings — they all share the same parent
# from when they were dispatched, not the stack top which may have changed.
parent_id = task_pending.parent_id if task_pending else self._current_parent_id(session)

# Enrich agent name with the Task tool's description if available
task_desc = ""
if isinstance(task_tool_input, dict):
Expand Down Expand Up @@ -693,17 +709,22 @@ def _handle_subagent_start(self, session_id: str, body: Dict[str, Any]) -> None:
}
self._assembled_spans.append(preliminary_span)

session.agent_span_stack.append(
{
"span_id": span_id,
"parent_id": parent_id,
"name": agent_name,
"start_ns": now_ns,
"task_tool_use_id": task_tool_use_id,
"task_tool_input": task_tool_input,
"_span_ref": preliminary_span,
}
)
task_prompt = ""
if isinstance(task_tool_input, dict):
task_prompt = task_tool_input.get("prompt", "")

agent_entry = {
"span_id": span_id,
"parent_id": parent_id,
"name": agent_name,
"start_ns": now_ns,
"task_tool_use_id": task_tool_use_id,
"task_tool_input": task_tool_input,
"task_prompt": task_prompt,
"_span_ref": preliminary_span,
}
session.agent_span_stack.append(agent_entry)
session.active_agents[span_id] = agent_entry

def _handle_subagent_stop(self, session_id: str, body: Dict[str, Any]) -> None:
"""Handle SubagentStop hook event — pops the agent stack.
Expand All @@ -720,6 +741,8 @@ def _handle_subagent_stop(self, session_id: str, body: Dict[str, Any]) -> None:
return

agent_info = session.agent_span_stack.pop()
# Remove from active_agents map
session.active_agents.pop(str(agent_info["span_id"]), None)
duration = now_ns - agent_info["start_ns"]

task_tool_use_id = agent_info.get("task_tool_use_id", "")
Expand Down
17 changes: 17 additions & 0 deletions ddapm_test_agent/claude_link_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ClaudeLinkTracker:

def __init__(self) -> None:
self._tool_calls: Dict[str, TrackedToolCall] = {}
self._llm_span_parents: Dict[str, str] = {} # llm_span_id → parent agent span_id

def on_llm_tool_choice(
self,
Expand Down Expand Up @@ -145,3 +146,19 @@ def on_tool_call_output_used(self, tool_use_id: str) -> Tuple[List[SpanLink], Op
)
]
return links, tc.tool_parent_id

def set_llm_parent(self, llm_span_id: str, parent_span_id: str) -> None:
    """Associate an LLM span with the agent span that owns it.

    The pair is stored in the tracker's LLM-span -> parent map; a later call
    with the same ``llm_span_id`` overwrites the earlier entry.
    """
    parents_by_llm_span = self._llm_span_parents
    parents_by_llm_span[llm_span_id] = parent_span_id

def get_parent_for_tool(self, tool_use_id: str) -> Optional[str]:
    """Return the agent span that owns the tool call ``tool_use_id``, if known.

    Resolution takes two hops: the tracked tool-call entry for ``tool_use_id``
    carries the id of the LLM span it is linked to, and the LLM-span -> parent
    map (populated by ``set_llm_parent``) gives that LLM span's agent.

    Returns ``None`` when the tool call is not tracked or when no parent has
    been recorded for its LLM span.
    """
    tracked_call = self._tool_calls.get(tool_use_id)
    return self._llm_span_parents.get(tracked_call.llm_span_id) if tracked_call else None
48 changes: 48 additions & 0 deletions ddapm_test_agent/claude_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,43 @@ def _extract_conversation_title(session: SessionState, content_blocks: List[Dict
log.info("Conversation title: %s", title)
return

@staticmethod
def _match_agent_by_prompt(session: SessionState, request_body: Dict[str, Any]) -> Optional[str]:
"""Match an LLM request to an active agent by checking if the agent's task_prompt appears in the request.

When multiple concurrent subagents are active and there are no tool_result
hints (first LLM call for each subagent), we fall back to matching the
task prompt that launched each agent against the system/user messages in
the API request.
"""
# Extract text from the request messages
request_text = ""
for msg in request_body.get("messages", []):
content = msg.get("content", "")
if isinstance(content, str):
request_text += content + "\n"
elif isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
request_text += block.get("text", "") + "\n"
# Also check system prompt
system = request_body.get("system", "")
if isinstance(system, str) and system:
request_text += system + "\n"
elif isinstance(system, list):
for block in system:
if isinstance(block, dict) and block.get("type") == "text":
request_text += block.get("text", "") + "\n"
if not request_text:
return None

for span_id, agent_entry in session.active_agents.items():
task_prompt = agent_entry.get("task_prompt", "")
if task_prompt and task_prompt in request_text:
return span_id

return None

def _create_llm_span(
self,
session: Optional[SessionState],
Expand Down Expand Up @@ -383,6 +420,12 @@ def _create_llm_span(
# All tool_results in a single LLM request should come from the same parent.
# Use the first hint; if they disagree, still better than the stack guess.
parent_id = parent_hints[0]
elif session and len(session.active_agents) > 1:
# First LLM call for a subagent — no tool_results to hint from.
# Match by checking which agent's task_prompt appears in the request.
resolved = self._match_agent_by_prompt(session, request_body)
if resolved:
parent_id = resolved

# LLM.output -> Tool.input linking: register tool_use blocks from the response
tool_uses = _extract_tool_uses_from_response(content_blocks)
Expand All @@ -395,6 +438,11 @@ def _create_llm_span(
llm_trace_id=trace_id,
)

# Record which agent this LLM span belongs to, so tools from its
# tool_use blocks can resolve their parent via the link tracker.
if parent_id and parent_id != "undefined":
self._link_tracker.set_llm_parent(span_id, parent_id)

input_messages = _format_input_messages(request_body)
output_messages = _format_output_messages(content_blocks)

Expand Down
153 changes: 153 additions & 0 deletions tests/test_claude_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,156 @@ async def test_hook_sessions_endpoint(agent):
body = await resp.json()
session_ids = [s["session_id"] for s in body["sessions"]]
assert session_id in session_ids


async def test_concurrent_subagents_parent_correctly(agent):
    """Sibling subagents spawned by two pending Task tools must both hang off the root span."""
    session_id = "sess-concurrent"

    async def fire(event_name, **fields):
        # Every hook payload starts with the session id and event name;
        # event-specific fields follow in the order they are passed.
        await _post_hook(agent, {"session_id": session_id, "hook_event_name": event_name, **fields})

    await fire("SessionStart")
    await fire("UserPromptSubmit", user_prompt="Run two tasks concurrently")

    # Both Task tools are dispatched before any SubagentStart, simulating
    # concurrent dispatch.
    task_dispatches = (
        ("task-A", {"description": "search code", "prompt": "Search the codebase for foo"}),
        ("task-B", {"description": "read docs", "prompt": "Read the documentation for bar"}),
    )
    for task_id, task_input in task_dispatches:
        await fire("PreToolUse", tool_name="Task", tool_use_id=task_id, tool_input=task_input)

    # First SubagentStart claims task-A, the second claims task-B.
    for _ in range(2):
        await fire("SubagentStart", agent_type="explore-agent")

    # One tool call inside each agent: agent1's Grep, then agent2's Read.
    inner_tool_calls = (
        ("Grep", "tool-in-A", {"pattern": "foo"}, "found foo"),
        ("Read", "tool-in-B", {"file_path": "/docs/bar.md"}, "bar docs"),
    )
    for tool_name, tool_id, tool_input, tool_response in inner_tool_calls:
        await fire("PreToolUse", tool_name=tool_name, tool_use_id=tool_id, tool_input=tool_input)
        await fire("PostToolUse", tool_name=tool_name, tool_use_id=tool_id, tool_response=tool_response)

    # Stop agent2 (top of stack) first, then agent1.
    for _ in range(2):
        await fire("SubagentStop")

    # Both Task tools complete after their subagents have stopped.
    for task_id, task_response in (("task-A", "search results"), ("task-B", "docs content")):
        await fire("PostToolUse", tool_name="Task", tool_use_id=task_id, tool_response=task_response)

    await fire("Stop")

    resp = await agent.get("/claude/hooks/spans")
    payload = await resp.json()

    # Restrict to the spans emitted for this session.
    own_spans = [span for span in payload["spans"] if span.get("session_id") == session_id]

    roots = [span for span in own_spans if span["parent_id"] == "undefined"]
    assert len(roots) == 1
    root = roots[0]

    subagents = [
        span for span in own_spans if span["meta"]["span"]["kind"] == "agent" and span["parent_id"] != "undefined"
    ]
    assert len(subagents) == 2, f"Expected 2 subagent spans, got {len(subagents)}"

    # Each subagent must point at the root, never at its sibling.
    for sub in subagents:
        assert sub["parent_id"] == root["span_id"], (
            f"Subagent {sub['name']} has parent_id={sub['parent_id']} "
            f"but expected root span_id={root['span_id']}"
        )
Loading