diff --git a/langgraphics-web/public/icons/chat_model.svg b/langgraphics-web/public/icons/chat_model.svg new file mode 100644 index 0000000..1c7a632 --- /dev/null +++ b/langgraphics-web/public/icons/chat_model.svg @@ -0,0 +1,11 @@ + + + + + + + + + + \ No newline at end of file diff --git a/langgraphics-web/public/icons/function.svg b/langgraphics-web/public/icons/parser.svg similarity index 100% rename from langgraphics-web/public/icons/function.svg rename to langgraphics-web/public/icons/parser.svg diff --git a/langgraphics-web/public/icons/agent.svg b/langgraphics-web/public/icons/prompt.svg similarity index 100% rename from langgraphics-web/public/icons/agent.svg rename to langgraphics-web/public/icons/prompt.svg diff --git a/langgraphics-web/public/icons/runnable.svg b/langgraphics-web/public/icons/runnable.svg deleted file mode 100644 index bf584ab..0000000 --- a/langgraphics-web/public/icons/runnable.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - \ No newline at end of file diff --git a/langgraphics-web/src/components/InspectPanel.tsx b/langgraphics-web/src/components/InspectPanel.tsx index bc2ac4a..0b97d17 100644 --- a/langgraphics-web/src/components/InspectPanel.tsx +++ b/langgraphics-web/src/components/InspectPanel.tsx @@ -1,131 +1,69 @@ import Tree from "antd/es/tree"; -import {type ReactNode} from "react"; -import type {Node} from "@xyflow/react"; -import type {GraphMessage, NodeData, NodeOutputEntry, NodeStepEntry} from "../types"; -import {useInspectTree} from "../hooks/useInspectTree.tsx"; +import type {TreeDataNode} from "antd"; +import {useCallback, useEffect, useMemo, useState} from "react"; +import type {NodeEntry} from "../types"; -interface InspectPanelProps { - topology: GraphMessage | null; - nodes: Node[]; - nodeOutputLog: NodeOutputEntry[]; - nodeStepLog: NodeStepEntry[]; -} - -function DetailSection({title, children}: { title: string; children: ReactNode }) { - return ( -
- {title} - {children} -
- ); -} +export function InspectPanel({nodeEntries}: { nodeEntries: NodeEntry[] }) { + const [selectedKey, setSelectedKey] = useState(""); -function NodeDetail({entry, isStart, isEnd, stepStart, stepEnd}: { - entry: NodeOutputEntry | null; - isStart: boolean; - isEnd: boolean; - stepStart: NodeStepEntry | null; - stepEnd: NodeStepEntry | null; -}) { - if (!entry) return null; + const expandedKeys = useMemo(() => { + return nodeEntries.map(({run_id}) => run_id); + }, [nodeEntries]); - if (stepStart !== null) { - let input = stepStart.data; - let output = stepEnd !== null ? stepEnd.data : stepEnd; - const toString = (d: any) => typeof d === "string" ? d : JSON.stringify(d, null, 2); - if (typeof stepStart.data === "object") { - const messages = stepStart.data.messages; - input = Array.isArray(messages) ? messages[messages.length - 1].content : stepStart.data; - } - if (stepEnd !== null && typeof stepEnd.data === "object") { - const messages = stepEnd.data.messages; - output = Array.isArray(messages) ? messages[messages.length - 1].content : stepEnd.data; - } - return ( - <> - -
{toString(input)}
-
- {stepEnd !== null && ( - -
{toString(output)}
-
- )} - - ); - } + const selectedEntry = useMemo(() => { + return nodeEntries.find(({run_id}) => run_id === selectedKey); + }, [nodeEntries, selectedKey]); - if (isStart) { - const allMessages = entry.data.messages ?? []; - const promptMsg = allMessages.find((m) => m.type === "system") ?? allMessages[0]; - return ( - - {promptMsg - ?
{promptMsg.content as string}
- :
{JSON.stringify(entry.data, null, 2)}
- } -
- ); - } + const getChildren = useCallback((parent: NodeEntry) => { + return nodeEntries.filter(({parent_run_id}) => parent_run_id === parent.run_id).map(child => { + const children: TreeDataNode[] = getChildren(child); + return { + children, + selectable: true, + key: child.run_id, + isLeaf: children.length === 0, + title: ( + + {child.node_kind + ? {child.node_kind} + : + } + {child.node_id ?? "step"} + + ), + } + }) + }, [nodeEntries]) - if (isEnd) { - const allMessages = entry.data.messages ?? []; - const lastMsg = allMessages.length > 0 ? allMessages[allMessages.length - 1] : null; - return ( - - {lastMsg - ?
{lastMsg.content as string}
- :
{JSON.stringify(entry.data, null, 2)}
- } -
- ); - } + const treeData = useMemo((): TreeDataNode[] => { + return nodeEntries.filter(({parent_run_id}) => !parent_run_id).map(entry => ({ + key: entry.run_id, + children: getChildren(entry), + title: ( + + {entry.node_kind && {entry.node_kind}/} + {entry.node_id} + + ), + })) + }, [nodeEntries, getChildren]); - const outputMessages = entry.data.messages ?? []; - const inputMessages = entry.input?.messages ?? []; - const lastInput = inputMessages.length > 0 ? inputMessages.slice(-1) : null; - - return ( - <> - {entry.input !== null && ( - - {lastInput - ? lastInput.map((msg, i) =>
{msg.content as string}
) - :
{JSON.stringify(entry.input, null, 2)}
- } -
- )} - - {outputMessages.length > 0 - ? outputMessages.map((msg, i) => ( -
- {msg.content as string} -
- )) - : Object.keys(entry.data).length > 0 - ?
{JSON.stringify(entry.data, null, 2)}
- : null - } -
- - ); -} - -export function InspectPanel({topology, nodes, nodeOutputLog, nodeStepLog}: InspectPanelProps) { - const { - treeData, expandedKeys, visibleLog, - selectedKey, setSelectedKey, - selectedEntry, selectedMeta, - stepStart, stepEnd, - } = useInspectTree(topology, nodes, nodeOutputLog, nodeStepLog); + useEffect(() => { + if (nodeEntries.length > 0 && !selectedKey) { + setSelectedKey(nodeEntries.find(e => !e.parent_run_id)?.run_id ?? nodeEntries[0].run_id); + } + }, [nodeEntries, selectedKey]); return (
Trace Inspector
- {visibleLog.length !== 0 && ( + {nodeEntries.length !== 0 && ( } onSelect={([key]) => key && setSelectedKey(key as string)} @@ -138,13 +76,22 @@ export function InspectPanel({topology, nodes, nodeOutputLog, nodeStepLog}: Insp )}
- + {selectedEntry && ( + <> + {selectedEntry.input && ( +
+ Input +
{selectedEntry.input}
+
+ )} + {selectedEntry.output && ( +
+ Output +
{selectedEntry.output}
+
+ )} + + )}
diff --git a/langgraphics-web/src/hooks/useFocus.ts b/langgraphics-web/src/hooks/useFocus.ts index b4a7f52..02fc613 100644 --- a/langgraphics-web/src/hooks/useFocus.ts +++ b/langgraphics-web/src/hooks/useFocus.ts @@ -81,7 +81,9 @@ export function useFocus({nodes, edges, activeNodeId, rankDir = "TB"}: UseFocusO if (mode !== "auto") return; - if (activeNodeId && activeNodeId !== prevFocusId.current) { + if (nodes.some((n) => n.className === "error")) { + fitView({duration: FIT_VIEW_DURATION}).then(); + } else if (activeNodeId && activeNodeId !== prevFocusId.current) { prevFocusId.current = activeNodeId; const activeNode = nodes.find((n) => n.id === activeNodeId); diff --git a/langgraphics-web/src/hooks/useInspectTree.tsx b/langgraphics-web/src/hooks/useInspectTree.tsx deleted file mode 100644 index 1031158..0000000 --- a/langgraphics-web/src/hooks/useInspectTree.tsx +++ /dev/null @@ -1,215 +0,0 @@ -import {useMemo, useState} from "react"; -import type {TreeDataNode} from "antd"; -import type {Node} from "@xyflow/react"; -import type {GraphMessage, NodeData, NodeKind, NodeOutputEntry, NodeStepEntry} from "../types"; - -export interface NodeMeta { - depth: number; - kind: NodeKind | null; - isStart?: boolean; - isEnd?: boolean; -} - -export function computeDepthMap(topology: GraphMessage): Map { - const adj = new Map(); - for (const n of topology.nodes) adj.set(n.id, []); - for (const e of topology.edges) adj.get(e.source)?.push(e.target); - - const startNode = topology.nodes.find((n) => n.node_type === "start"); - if (!startNode) return new Map(); - - const rank = new Map(); - for (let q = [startNode.id], r = 0; q.length > 0; r++) { - const next: string[] = []; - for (const id of q) - if (!rank.has(id)) { - rank.set(id, r); - next.push(...(adj.get(id) ?? [])); - } - q = next; - } - - const onCycle = new Set(); - for (const [src, tgts] of adj) - for (const t of tgts) - if ((rank.get(t) ?? 0) <= (rank.get(src) ?? 0)) onCycle.add(src); - for (let changed = true; changed;) { - changed = false; - for (const [src, tgts] of adj) - if (!onCycle.has(src)) - for (const t of tgts) - if ((rank.get(t) ?? 0) > (rank.get(src) ?? 0) && onCycle.has(t)) { - onCycle.add(src); - changed = true; - break; - } - } - - for (const [id, tgts] of adj) { - const sr = rank.get(id) ?? 0; - adj.set(id, [...tgts].sort((a, b) => { - const ra = rank.get(a) ?? 0, rb = rank.get(b) ?? 0; - const ba = ra <= sr ? 1 : 0, bb = rb <= sr ? 1 : 0; - return (ba - bb) || ((onCycle.has(a) ? 0 : 1) - (onCycle.has(b) ? 0 : 1)) || (ra - rb); - })); - } - - const nodeInfo = new Map(topology.nodes.map((n) => [n.id, n])); - const result = new Map(); - result.set(startNode.id, {depth: 0, kind: null, isStart: true}); - const endNode = topology.nodes.find((n) => n.node_type === "end"); - if (endNode) result.set(endNode.id, {depth: 0, kind: null, isEnd: true}); - - const path: string[] = []; - const pathSet = new Set(); - const outDepth = new Map(); - const parent = new Map(); - - (function dfs(id: string, par: string | null, d: number) { - if (outDepth.has(id)) return; - parent.set(id, par); - path.push(id); - pathSet.add(id); - - let eff = d; - for (const t of adj.get(id) ?? []) - if (pathSet.has(t)) { - const p = parent.get(t); - const pd = p !== null ? (outDepth.get(p!) ?? 0) : 0; - if (pd < eff) eff = Math.max(0, pd); - } - outDepth.set(id, eff); - - const info = nodeInfo.get(id); - if (info?.node_type === "node") - result.set(id, {depth: eff, kind: info.node_kind ?? null}); - - for (const t of adj.get(id) ?? []) - if (pathSet.has(t)) { - const bd = eff + 1; - for (let i = path.length - 2; path[i] !== t; i--) { - if ((outDepth.get(path[i]) ?? 0) > bd) { - outDepth.set(path[i], bd); - const m = result.get(path[i]); - if (m) result.set(path[i], {...m, depth: bd}); - } - } - } - - for (const t of adj.get(id) ?? []) - if (!pathSet.has(t)) dfs(t, id, eff + 1); - - path.pop(); - pathSet.delete(id); - })(startNode.id, null, -1); - - for (const [id, meta] of result) - if (!onCycle.has(id) && meta.depth !== 0 && !meta.isStart) - result.set(id, {...meta, depth: 0}); - - return result; -} - -export function useInspectTree( - topology: GraphMessage | null, - nodes: Node[], - nodeOutputLog: NodeOutputEntry[], - nodeStepLog: NodeStepEntry[], -) { - const [selectedKey, setSelectedKey] = useState("log-0"); - - const depthMap = useMemo( - () => topology ? computeDepthMap(topology) : new Map(), - [topology]); - - const nodeDataMap = useMemo( - () => new Map(nodes.map((n) => [n.id, n.data])), - [nodes]); - - const visibleLog = useMemo( - () => nodeOutputLog.filter((e) => depthMap.has(e.nodeId)), - [nodeOutputLog, depthMap]); - - const stepsByParent = useMemo(() => { - const map = new Map(); - for (const s of nodeStepLog) - if (s.event === "start") { - const arr = map.get(s.parentRunId) ?? []; - arr.push(s); - map.set(s.parentRunId, arr); - } - return map; - }, [nodeStepLog]); - - const stepEndMap = useMemo(() => { - const map = new Map(); - for (const s of nodeStepLog) - if (s.event === "end") map.set(s.runId, s); - return map; - }, [nodeStepLog]); - - const treeData = useMemo(() => { - const root: TreeDataNode[] = []; - const stack: { d: number; c: TreeDataNode[] }[] = [{d: -1, c: root}]; - - for (let i = 0; i < visibleLog.length; i++) { - const entry = visibleLog[i]; - const meta = depthMap.get(entry.nodeId)!; - const key = `log-${i}`; - const label = nodeDataMap.get(entry.nodeId)?.label ?? entry.nodeId; - const steps = entry.runId ? (stepsByParent.get(entry.runId) ?? []) : []; - - const children: TreeDataNode[] = steps.map((step, si) => ({ - key: `${key}-step-${si}`, isLeaf: true, selectable: true, - title: {step.name ?? "step"}, - })); - - const node: TreeDataNode = { - key, children, - title: ( - - {meta.kind && {meta.kind}/} - {label} - - ), - }; - - while (stack.length > 1 && stack[stack.length - 1].d >= meta.depth) stack.pop(); - stack[stack.length - 1].c.push(node); - stack.push({d: meta.depth, c: children}); - } - - return root; - }, [visibleLog, depthMap, nodeDataMap, stepsByParent]); - - const expandedKeys = useMemo( - () => visibleLog.map((_, idx) => `log-${idx}`), - [visibleLog]); - - const selectedParts = selectedKey?.split("-") ?? []; - const logIdx = selectedKey ? parseInt(selectedParts[1], 10) : null; - const selectedEntry = logIdx !== null ? (visibleLog[logIdx] ?? null) : null; - const selectedMeta = selectedEntry ? depthMap.get(selectedEntry.nodeId) : null; - - let stepStart: NodeStepEntry | null = null; - let stepEnd: NodeStepEntry | null = null; - - if (selectedParts.length === 4 && selectedParts[2] === "step" && selectedEntry?.runId) { - const stepIdx = parseInt(selectedParts[3], 10); - const steps = stepsByParent.get(selectedEntry.runId) ?? []; - stepStart = steps[stepIdx] ?? null; - if (stepStart) stepEnd = stepEndMap.get(stepStart.runId) ?? null; - } - - return { - treeData, - expandedKeys, - visibleLog, - selectedKey, - setSelectedKey, - selectedEntry, - selectedMeta: selectedMeta ?? null, - stepStart, - stepEnd, - }; -} diff --git a/langgraphics-web/src/hooks/useWebSocket.ts b/langgraphics-web/src/hooks/useWebSocket.ts index 9161789..1a45f0a 100644 --- a/langgraphics-web/src/hooks/useWebSocket.ts +++ b/langgraphics-web/src/hooks/useWebSocket.ts @@ -1,13 +1,12 @@ import {useEffect, useRef, useState} from "react"; -import type {ExecutionEvent, GraphMessage, NodeOutputEntry, NodeStepEntry, WsMessage} from "../types"; +import type {ExecutionEvent, GraphMessage, NodeEntry, WsMessage} from "../types"; const RECONNECT_INTERVAL = 500; const CONNECTION_TIMEOUT = 500; export function useWebSocket(url: string) { const [events, setEvents] = useState([]); - const [nodeStepLog, setNodeStepLog] = useState([]); - const [nodeOutputLog, setNodeOutputLog] = useState([]); + const [nodeEntries, setNodeEntries] = useState([]); const [topology, setTopology] = useState(null); const wsRef = useRef(null); const timerRef = useRef | null>(null); @@ -42,19 +41,16 @@ export function useWebSocket(url: string) { const msg: WsMessage = JSON.parse(event.data); if (msg.type === "graph") { runDone = false; - setTopology(msg); setEvents([]); - setNodeStepLog([]); - setNodeOutputLog([]); + setTopology(msg); + setNodeEntries([]); } else if (msg.type === "run_start") { runDone = false; setEvents([msg]); - setNodeStepLog([]); - setNodeOutputLog([]); + setNodeEntries([]); } else if (msg.type === "node_output") { - setNodeOutputLog((prev) => [...prev, {nodeId: msg.node, data: msg.data, input: msg.input ?? null, runId: msg.run_id ?? null}]); - } else if (msg.type === "node_step") { - setNodeStepLog((prev) => [...prev, {runId: msg.run_id, parentRunId: msg.parent_run_id, name: msg.name, event: msg.event, data: msg.data}]); + const {type: _, ...entry} = msg; + setNodeEntries((prev) => [...prev, entry]); } else { if (msg.type === "run_end" || msg.type === "error") runDone = true; setEvents((prev) => [...prev, msg as ExecutionEvent]); @@ -78,5 +74,5 @@ export function useWebSocket(url: string) { }; }, [url]); - return {topology, events, nodeStepLog, nodeOutputLog}; + return {topology, events, nodeEntries}; } diff --git a/langgraphics-web/src/index.css b/langgraphics-web/src/index.css index 170fd53..7611076 100644 --- a/langgraphics-web/src/index.css +++ b/langgraphics-web/src/index.css @@ -223,10 +223,10 @@ html, body, #root { } .inspect-tree-pane { - width: 220px; - flex: 0 0 220px; - overflow-y: auto; + flex: 0 0 auto; + min-width: 220px; border-right: var(--xy-node-border-default); + overflow-y: auto; } .inspect-detail-pane { @@ -272,9 +272,42 @@ html, body, #root { } .inspect-step-label { + gap: 5px; + display: flex; font-size: 11px; font-weight: 400; font-style: italic; + align-items: center; +} + +.inspect-step-status { + width: 6px; + height: 6px; + flex-shrink: 0; + border-radius: 50%; + background: #c8c8c8; +} + +.inspect-step-status.error { + background: #ef4444; +} + +.ant-tree-node-content-wrapper:not(.ant-tree-node-selected) .inspect-step-status:not(.error) { + background: #3b82f6; +} + +.inspect-step-name { + min-width: 0; + flex: 1 1 auto; + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; +} + +.inspect-step-icon { + width: 12px; + height: 12px; + flex-shrink: 0; } .inspect-detail-section { diff --git a/langgraphics-web/src/layout.ts b/langgraphics-web/src/layout.ts index 4776bfd..9dde8df 100644 --- a/langgraphics-web/src/layout.ts +++ b/langgraphics-web/src/layout.ts @@ -111,7 +111,7 @@ export function computeLayout(topology: GraphMessage, rankDir: RankDir = "TB"): x: (nodeX.get(n.id) ?? pos.x) - w / 2, y: (nodeY.get(n.id) ?? pos.y) - h / 2, }, - data: {label: n.name, nodeType: n.node_type, nodeKind: n.node_kind ?? null, status: "idle" as const, handles}, + data: {label: n.name, nodeType: n.node_type, status: "idle" as const, handles}, }; }); diff --git a/langgraphics-web/src/main.tsx b/langgraphics-web/src/main.tsx index 5eac5b8..03b4347 100644 --- a/langgraphics-web/src/main.tsx +++ b/langgraphics-web/src/main.tsx @@ -26,7 +26,7 @@ const {colorMode: initialColorMode, rankDir: initialRankDir} = parseParams(); function Index() { const [rankDir, setRankDir] = useState(initialRankDir); - const {topology, events, nodeOutputLog, nodeStepLog} = useWebSocket(WS_URL); + const {topology, events, nodeEntries} = useWebSocket(WS_URL); const {nodes, edges, activeNodeId} = useGraphState(topology, events, rankDir); return ( @@ -35,17 +35,10 @@ function Index() { nodes={nodes} edges={edges} activeNodeId={activeNodeId} - initialColorMode={initialColorMode} - initialRankDir={initialRankDir} onRankDirChange={setRankDir} - inspect={ - - } + initialRankDir={initialRankDir} + initialColorMode={initialColorMode} + inspect={} /> ); diff --git a/langgraphics-web/src/types.ts b/langgraphics-web/src/types.ts index 7385c5c..5bbb09f 100644 --- a/langgraphics-web/src/types.ts +++ b/langgraphics-web/src/types.ts @@ -10,12 +10,13 @@ export interface NodeHandle { style: { top?: string; left?: string; transform: string }; } +export type NodeKind = "llm" | "chain" | "tool" | "retriever" | "embedding" | "prompt" | "parser" | "chat_model"; + export interface NodeData extends Record { label: string; - nodeType: "start" | "end" | "node"; - nodeKind: NodeKind | null; status: NodeStatus; handles: NodeHandle[]; + nodeType: "start" | "end" | "node"; } export interface EdgeData extends Record { @@ -24,14 +25,9 @@ export interface EdgeData extends Record { status: EdgeStatus; } -export type NodeKind = - | "tool" | "llm" | "embedding" | "retriever" - | "agent" | "chain" | "function" | "runnable" | "unknown"; - export interface ProtocolNode { id: string; name: string; - node_kind: NodeKind | null; node_type: "start" | "end" | "node"; } @@ -85,50 +81,23 @@ export interface ErrorMessage { edge_id: string | null; } -export interface SerializedMessage { - content: string; - type: string; - - [key: string]: unknown; -} - -export interface NodeOutputMessage { +export interface NodeMessage { type: "node_output"; - node: string; - data: { messages?: SerializedMessage[]; [key: string]: unknown }; - input: { messages?: SerializedMessage[]; [key: string]: unknown } | null; - run_id: string | null; -} - -export interface NodeOutputEntry { - nodeId: string; - data: NodeOutputMessage["data"]; - input: NodeOutputMessage["input"]; - runId: string | null; -} - -export interface NodeStepMessage { - type: "node_step"; run_id: string; - parent_run_id: string; - name: string | null; - event: "start" | "end"; - data: { [key: string]: unknown }; + node_id: string; + node_kind: NodeKind | null; + parent_run_id?: string | null; + status?: "ok" | "error"; + input?: string | null; + output?: string | null; } -export interface NodeStepEntry { - runId: string; - parentRunId: string; - name: string | null; - event: "start" | "end"; - data: NodeStepMessage["data"]; -} +export type NodeEntry = Omit; export type WsMessage = - | GraphMessage | RunStartMessage | RunEndMessage - | NodeStartMessage | NodeEndMessage | EdgeActiveMessage - | ErrorMessage | NodeOutputMessage | NodeStepMessage; + | GraphMessage | RunStartMessage | RunEndMessage | NodeStartMessage + | NodeEndMessage | EdgeActiveMessage | ErrorMessage | NodeMessage; export type ExecutionEvent = - | RunStartMessage | RunEndMessage | NodeStartMessage - | NodeEndMessage | EdgeActiveMessage | ErrorMessage | NodeOutputMessage | NodeStepMessage; + | RunStartMessage | RunEndMessage | NodeStartMessage | NodeEndMessage + | EdgeActiveMessage | ErrorMessage | NodeMessage; diff --git a/langgraphics-web/tsconfig.app.json b/langgraphics-web/tsconfig.app.json index a9b5a59..11f8105 100644 --- a/langgraphics-web/tsconfig.app.json +++ b/langgraphics-web/tsconfig.app.json @@ -1,9 +1,9 @@ { "compilerOptions": { "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo", - "target": "ES2022", + "target": "ESNext", "useDefineForClassFields": true, - "lib": ["ES2022", "DOM", "DOM.Iterable"], + "lib": ["ESNext", "DOM", "DOM.Iterable"], "module": "ESNext", "types": ["vite/client"], "skipLibCheck": true, diff --git a/langgraphics/streamer.py b/langgraphics/streamer.py index a40ccbe..bf15035 100644 --- a/langgraphics/streamer.py +++ b/langgraphics/streamer.py @@ -1,182 +1,138 @@ import asyncio import json import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from socketserver import TCPServer from typing import Any -from langchain_core.callbacks import AsyncCallbackHandler - - -def _serialize_state(state: Any) -> Any: - if isinstance(state, dict): - return {k: _serialize_state(v) for k, v in state.items()} - if isinstance(state, (list, tuple)): - return [_serialize_state(item) for item in state] - if hasattr(state, "model_dump"): - return state.model_dump() - try: - json.dumps(state) - return state - except (TypeError, ValueError): - return str(state) - - -def _merge_state(base: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]: - merged = dict(base) - for k, v in update.items(): - if isinstance(v, list) and isinstance(merged.get(k), list): - merged[k] = merged[k] + v - else: - merged[k] = v - return merged - - -class SubStepCallbackHandler(AsyncCallbackHandler): - def __init__(self, broadcast: Any, node_names: set[str]) -> None: - self.broadcast = broadcast - self.node_names = node_names - self.id_to_name: dict[str, str] = {} - self.node_run_ids: dict[str, str] = {} - - def _register(self, run_id: Any, parent_run_id: Any, name: str) -> None: - rid = str(run_id) - self.id_to_name[rid] = name - if name in self.node_names: - parent_name = ( - self.id_to_name.get(str(parent_run_id)) if parent_run_id else None - ) - if parent_name not in self.node_names: - self.node_run_ids[name] = rid - - def _parent_is_node(self, parent_run_id: Any) -> bool: - if not parent_run_id: - return False - return self.id_to_name.get(str(parent_run_id)) in self.node_names +from langchain_core.tracers.base import AsyncBaseTracer +from langchain_core.tracers.schemas import Run + + +def parse_message(msg: Any) -> dict | list: + msg = dict(msg) + msg = msg.get("kwargs", msg) + if tool_calls := msg.get("tool_calls", []): + tool_call = tool_calls[0] + return { + "content": tool_call.get("name", "") + + "(" + + str(tool_call.get("args", {})) + + ")", + "type": tool_call.get("type", ""), + } + for k in ("input", "output", "summary", "answer", "result"): + if res := msg.get(k): + if isinstance(res, list) and res: + try: + return parse_message(res[0].update) + except AttributeError: + return parse_message(res[0]) + elif isinstance(res, dict): + return res + elif res.__class__.__module__ != "builtins": + return parse_message(res) + else: + return {k: res} + return { + k: msg[k] + for k in ("content", "question", "prompt", "query", "type") + if k in msg + } + + +def preview(inputs: Any) -> str: + messages = [] + inputs = inputs or {} + if "generations" in inputs: + inputs = inputs["generations"][-1][-1]["message"] + inputs = inputs.get("messages", inputs) + if isinstance(inputs, list): + if inputs: + messages = inputs + if isinstance(inputs[0], list): + messages = inputs[0] + elif isinstance(inputs, dict): + inputs = inputs.get("state", inputs) + messages = inputs.get("messages", inputs) + if isinstance(messages, dict): + messages = [messages] + return json.dumps(list(map(parse_message, messages)), indent=4) + + +def error_output(error: str) -> str: + return json.dumps([{"content": error, "type": "error"}], indent=4) + + +class BroadcastingTracer(AsyncBaseTracer): + def __init__(self, viewport: "Viewport") -> None: + super().__init__(_schema_format="original+chat") + self.viewport = viewport + + async def _persist_run(self, run: Run) -> None: + pass + + async def _emit_end(self, run: Run) -> None: + node_run_id = str(run.parent_run_id) if run.parent_run_id else None + if node_run_id is None: + return + await self.viewport.broadcast( + { + "type": "node_output", + "run_id": str(run.id), + "parent_run_id": node_run_id, + "node_id": run.name, + "node_kind": run.run_type, + "status": "error" if run.error else "ok", + "input": preview(run.inputs), + "output": error_output(run.error) if run.error else preview(run.outputs), + } + ) - async def _emit_start( - self, run_id: Any, parent_run_id: Any, name: str, data: Any - ) -> None: - if self._parent_is_node(parent_run_id): - await self.broadcast( - { - "name": name, - "event": "start", - "type": "node_step", - "run_id": str(run_id), - "data": _serialize_state(data), - "parent_run_id": str(parent_run_id), - } - ) + async def _on_chain_start(self, run: Run) -> None: + if run.name in self.viewport.node_names: + self.viewport.node_current = run.name - async def _emit_end(self, run_id: Any, parent_run_id: Any, data: Any) -> None: - if self._parent_is_node(parent_run_id): - await self.broadcast( - { - "name": None, - "event": "end", - "type": "node_step", - "run_id": str(run_id), - "data": _serialize_state(data), - "parent_run_id": str(parent_run_id), - } + async def _on_chain_end(self, run: Run) -> None: + if run.name in self.viewport.node_names: + parent = ( + self.run_map.get(str(run.parent_run_id)) if run.parent_run_id else None ) + if parent is None or parent.name not in self.viewport.node_names: + await self.viewport.broadcast( + { + "type": "node_output", + "node_id": run.name, + "run_id": str(run.id), + "node_kind": run.run_type, + "status": "error" if run.error else "ok", + "input": preview(run.inputs), + "output": error_output(run.error) if run.error else preview(run.outputs), + } + ) + else: + await self._emit_end(run) - async def on_chain_start( - self, - serialized: Any, - inputs: Any, - *, - run_id: Any, - parent_run_id: Any = None, - **kwargs: Any, - ) -> None: - name = kwargs.get("name") or (serialized or {}).get("name") or "chain" - self._register(run_id, parent_run_id, name) - await self._emit_start(run_id, parent_run_id, name, inputs) - - async def on_chain_end( - self, outputs: Any, *, run_id: Any, parent_run_id: Any = None, **kwargs: Any - ) -> None: - await self._emit_end(run_id, parent_run_id, outputs) - - async def on_llm_start( - self, - serialized: Any, - prompts: Any, - *, - run_id: Any, - parent_run_id: Any = None, - **kwargs: Any, - ) -> None: - name = kwargs.get("name") or (serialized or {}).get("name") or "llm" - self._register(run_id, parent_run_id, name) - await self._emit_start(run_id, parent_run_id, name, {"prompts": prompts}) + async def _on_chain_error(self, run: Run) -> None: + await self._on_chain_end(run) - async def on_chat_model_start( - self, - serialized: Any, - messages: Any, - *, - run_id: Any, - parent_run_id: Any = None, - **kwargs: Any, - ) -> None: - name = kwargs.get("name") or (serialized or {}).get("name") or "chat_model" - self._register(run_id, parent_run_id, name) - await self._emit_start( - run_id, - parent_run_id, - name, - {"messages": _serialize_state([[m for m in batch] for batch in messages])}, - ) + async def _on_llm_end(self, run: Run) -> None: + await self._emit_end(run) - async def on_llm_end( - self, response: Any, *, run_id: Any, parent_run_id: Any = None, **kwargs: Any - ) -> None: - try: - text = response.generations[0][0].text - except Exception: - text = str(response) - await self._emit_end(run_id, parent_run_id, {"output": text}) + async def _on_llm_error(self, run: Run) -> None: + await self._emit_end(run) - async def on_tool_start( - self, - serialized: Any, - input_str: Any, - *, - run_id: Any, - parent_run_id: Any = None, - **kwargs: Any, - ) -> None: - name = (serialized or {}).get("name") or kwargs.get("name") or "tool" - self._register(run_id, parent_run_id, name) - await self._emit_start(run_id, parent_run_id, name, {"input": input_str}) + async def _on_tool_end(self, run: Run) -> None: + await self._emit_end(run) - async def on_tool_end( - self, output: Any, *, run_id: Any, parent_run_id: Any = None, **kwargs: Any - ) -> None: - await self._emit_end(run_id, parent_run_id, {"output": output}) + async def _on_tool_error(self, run: Run) -> None: + await self._emit_end(run) - async def on_retriever_start( - self, - serialized: Any, - query: Any, - *, - run_id: Any, - parent_run_id: Any = None, - **kwargs: Any, - ) -> None: - name = (serialized or {}).get("name") or kwargs.get("name") or "retriever" - self._register(run_id, parent_run_id, name) - await self._emit_start(run_id, parent_run_id, name, {"query": query}) + async def _on_retriever_end(self, run: Run) -> None: + await self._emit_end(run) - async def on_retriever_end( - self, documents: Any, *, run_id: Any, parent_run_id: Any = None, **kwargs: Any - ) -> None: - await self._emit_end( - run_id, parent_run_id, {"documents": _serialize_state(list(documents))} - ) + async def _on_retriever_error(self, run: Run) -> None: + await self._emit_end(run) class Viewport: @@ -189,19 +145,19 @@ def __init__( ) -> None: self.ws = ws self.graph = graph + self.node_current = None self.edge_lookup = edge_lookup self.http_server = http_server - node_names: set[str] = set() + self.node_names: set[str] = set() for src, tgt in edge_lookup: - node_names.add(src) - node_names.add(tgt) - node_names -= {"__start__", "__end__"} - self._node_names = node_names + self.node_names.add(src) + self.node_names.add(tgt) + self.node_names -= {"__start__", "__end__"} def __getattr__(self, name: str) -> Any: return getattr(self.graph, name) - async def _broadcast(self, message: dict[str, Any]) -> None: + async def broadcast(self, message: dict[str, Any]) -> None: message_str = json.dumps(message) self.ws.record(message_str) if self.ws.loop is None: @@ -218,7 +174,7 @@ async def _broadcast(self, message: dict[str, Any]) -> None: async def _emit_edge(self, source: str, target: str) -> None: edge_id = self.edge_lookup.get((source, target)) if edge_id: - await self._broadcast( + await self.broadcast( { "type": "edge_active", "source": source, @@ -228,24 +184,23 @@ async def _emit_edge(self, source: str, target: str) -> None: ) async def _emit_error(self, last_node: str) -> None: - target = last_node - edge_id = None for (src, tgt), eid in self.edge_lookup.items(): - if src == last_node: - target = tgt - edge_id = eid + if src == last_node and tgt == self.node_current: + await self.broadcast( + { + "type": "error", + "edge_id": eid, + "source": last_node, + "target": self.node_current, + } + ) break - await self._broadcast( - {"type": "error", "source": last_node, "target": target, "edge_id": edge_id} - ) - def _make_config( - self, config: Any - ) -> tuple[dict[str, Any], SubStepCallbackHandler]: - handler = SubStepCallbackHandler(self._broadcast, self._node_names) + def _make_config(self, config: Any) -> dict[str, Any]: + tracer = BroadcastingTracer(self) merged: dict[str, Any] = dict(config or {}) - merged["callbacks"] = list(merged.get("callbacks") or []) + [handler] - return merged, handler + merged["callbacks"] = list(merged.get("callbacks") or []) + [tracer] + return merged async def shutdown(self) -> None: await self.ws.shutdown() @@ -253,25 +208,11 @@ async def shutdown(self) -> None: async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: run_id = uuid.uuid4().hex[:8] - await self._broadcast({"type": "run_start", "run_id": run_id}) - - serialized_input = _serialize_state(input) - await self._broadcast( - { - "node": "__start__", - "type": "node_output", - "data": serialized_input - if isinstance(serialized_input, dict) - else {"input": serialized_input}, - } - ) + await self.broadcast({"type": "run_start", "run_id": run_id}) result: Any = None last_node = "__start__" - merged_config, handler = self._make_config(config) - accumulated_state: dict[str, Any] = ( - dict(input) if isinstance(input, dict) else {} - ) + merged_config = self._make_config(config) try: async for chunk in self.graph.astream( @@ -282,34 +223,11 @@ async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: if node_name == "__metadata__": continue await self._emit_edge(last_node, node_name) - node_output = chunk[node_name] - await self._broadcast( - { - "node": node_name, - "type": "node_output", - "data": _serialize_state(node_output), - "input": _serialize_state(accumulated_state), - "run_id": handler.node_run_ids.get(node_name), - } - ) - if isinstance(node_output, dict): - accumulated_state = _merge_state( - accumulated_state, node_output - ) last_node = node_name - result = node_output + result = chunk[node_name] await self._emit_edge(last_node, "__end__") - await self._broadcast( - { - "node": "__end__", - "type": "node_output", - "data": _serialize_state(accumulated_state), - "input": _serialize_state(accumulated_state), - } - ) - await asyncio.sleep(1) - await self._broadcast({"type": "run_end", "run_id": run_id}) + await self.broadcast({"type": "run_end", "run_id": run_id}) except Exception: await self._emit_error(last_node) raise @@ -318,29 +236,18 @@ async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: return result + def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: + return asyncio.run(self.ainvoke(input, config=config, **kwargs)) + async def astream( self, input: Any, config: Any = None, **kwargs: Any ) -> AsyncIterator: run_id = uuid.uuid4().hex[:8] - await self._broadcast({"type": "run_start", "run_id": run_id}) - - serialized_input = _serialize_state(input) - await self._broadcast( - { - "node": "__start__", - "type": "node_output", - "data": serialized_input - if isinstance(serialized_input, dict) - else {"input": serialized_input}, - } - ) + await self.broadcast({"type": "run_start", "run_id": run_id}) last_node = "__start__" - merged_config, handler = self._make_config(config) + merged_config = self._make_config(config) stream_mode = kwargs.get("stream_mode", "values") - accumulated_state: dict[str, Any] = ( - dict(input) if isinstance(input, dict) else {} - ) try: async for chunk in self.graph.astream( @@ -351,38 +258,25 @@ async def astream( if node_name == "__metadata__": continue await self._emit_edge(last_node, node_name) - node_output = chunk[node_name] - await self._broadcast( - { - "node": node_name, - "type": "node_output", - "data": _serialize_state(node_output), - "input": _serialize_state(accumulated_state), - "run_id": handler.node_run_ids.get(node_name), - } - ) - if isinstance(node_output, dict): - accumulated_state = _merge_state( - accumulated_state, node_output - ) last_node = node_name yield chunk if last_node != "__start__": await self._emit_edge(last_node, "__end__") - await self._broadcast( - { - "node": "__end__", - "type": "node_output", - "data": _serialize_state(accumulated_state), - "input": _serialize_state(accumulated_state), - } - ) - await self._broadcast({"type": "run_end", "run_id": run_id}) + await self.broadcast({"type": "run_end", "run_id": run_id}) except Exception: await self._emit_error(last_node) raise - def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any: - return asyncio.run(self.ainvoke(input, config=config, **kwargs)) + def stream(self, input: Any, config: Any = None, **kwargs: Any) -> Iterator: + loop = asyncio.new_event_loop() + ait = self.astream(input, config=config, **kwargs).__aiter__() + try: + while True: + try: + yield loop.run_until_complete(ait.__anext__()) + except StopAsyncIteration: + break + finally: + loop.close() diff --git a/langgraphics/topology.py b/langgraphics/topology.py index 468efc8..3a17f76 100644 --- a/langgraphics/topology.py +++ b/langgraphics/topology.py @@ -1,32 +1,6 @@ from typing import Any -def classify_node(data: Any) -> str: - checks = [ - ("langchain_core.tools", "BaseTool", "tool"), - ("langchain_core.language_models", "BaseLanguageModel", "llm"), - ("langchain_core.embeddings", "Embeddings", "embedding"), - ("langchain_core.retrievers", "BaseRetriever", "retriever"), - ("langchain_core.agents", "BaseMultiActionAgent", "agent"), - ("langchain_core.agents", "BaseSingleActionAgent", "agent"), - ("langchain_core.runnables.base", "RunnableSequence", "chain"), - ("langchain_core.runnables.base", "RunnableParallel", "chain"), - ("langgraph._internal._runnable", "RunnableCallable", "function"), - ("langchain_core.runnables", "Runnable", "runnable"), - ] - - for module_path, class_name, label in checks: - try: - module = __import__(module_path, fromlist=[class_name]) - cls = getattr(module, class_name) - if isinstance(data, cls): - return label - except (ImportError, AttributeError): - continue - - return "unknown" - - def extract(graph: Any) -> dict[str, Any]: raw = graph.get_graph() @@ -40,9 +14,6 @@ def extract(graph: Any) -> dict[str, Any]: "__end__": "end", "__start__": "start", }.get(node.name, "node"), - "node_kind": classify_node(node.data) - if node.name not in ("__start__", "__end__") - else None, } for node_id, node in raw.nodes.items() ], diff --git a/tests/lib/test_classify.py b/tests/lib/test_classify.py deleted file mode 100644 index 5a4a222..0000000 --- a/tests/lib/test_classify.py +++ /dev/null @@ -1,176 +0,0 @@ -import pytest - -from langgraphics.topology import classify_node - - -def test_tool_classified(): - from langchain_core.tools import tool - - @tool - def dummy(x: str) -> str: - """test""" - return x - - assert classify_node(dummy) == "tool" - - -def test_tool_takes_priority_over_runnable(): - from langchain_core.tools import tool - - @tool - def my_tool(q: str) -> str: - """test""" - return q - - assert classify_node(my_tool) == "tool" - - -def test_agent_classified(): - from typing import Annotated, TypedDict - - from langgraph.graph import END, StateGraph - from langgraph.graph.message import add_messages - - class AgentState(TypedDict): - messages: Annotated[list, add_messages] - - def agent_node(state: AgentState) -> dict: - return state - - def tools_node(state: AgentState) -> dict: - return state - - builder = StateGraph(AgentState) - builder.add_node("agent", agent_node) - builder.add_node("tools", tools_node) - builder.set_entry_point("agent") - builder.add_edge("agent", "tools") - builder.add_edge("tools", END) - agent = builder.compile() - assert classify_node(agent) == "runnable" - - -def test_create_agent_classified(): - from langchain.agents import create_agent - from langchain_core.language_models.fake_chat_models import ( - FakeMessagesListChatModel, - ) - from langchain_core.messages import AIMessage - - model = FakeMessagesListChatModel(responses=[AIMessage(content="")]) - agent = create_agent(model=model, tools=None) - assert classify_node(agent) == "runnable" - - -def test_embedding_classified(): - from langchain_core.embeddings import FakeEmbeddings - - assert classify_node(FakeEmbeddings(size=1)) == "embedding" - - -def test_retriever_classified(): - from langchain_core.documents import Document - from langchain_core.retrievers import BaseRetriever - - class SimpleRetriever(BaseRetriever): - def _get_relevant_documents(self, query: str) -> list[Document]: - return [] - - assert classify_node(SimpleRetriever()) == "retriever" - - -def test_llm_classified(): - from langchain_core.language_models.fake_chat_models import ( - FakeMessagesListChatModel, - ) - from langchain_core.messages import AIMessage - - model = FakeMessagesListChatModel(responses=[AIMessage(content="")]) - assert classify_node(model) == "llm" - - -def test_function_classified(): - from langgraph._internal._runnable import RunnableCallable - - r = RunnableCallable(func=lambda x: x) - assert classify_node(r) == "function" - - -def test_chain_classified(): - from langchain_core.runnables import RunnableLambda - - chain = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x) - assert classify_node(chain) == "chain" - - -def test_agent_legacy_classified(): - try: - from langchain_core.agents import AgentFinish, BaseSingleActionAgent - except (ImportError, AttributeError): - pytest.skip("legacy agent classes not in this langchain_core") - - class StubAgent(BaseSingleActionAgent): - def plan(self, intermediate_steps, **kwargs): - return AgentFinish(return_values={}, log="") - - @property - def input_keys(self): - return [] - - assert classify_node(StubAgent()) == "agent" - - -def test_unknown_classified(): - assert classify_node("foo") == "unknown" - assert classify_node(42) == "unknown" - - -def test_runnable_sequence_classified_as_runnable(): - from langchain_core.runnables import RunnableLambda - - chain = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x) - assert classify_node(chain) == "chain" - - -def test_runnable_parallel_classified_as_runnable(): - from langchain_core.runnables import RunnableLambda, RunnableParallel - - par = RunnableParallel(a=RunnableLambda(lambda x: x)) - assert classify_node(par) == "chain" - - -def test_runnable_lambda_classified_as_runnable(): - from langchain_core.runnables import RunnableLambda - - r = RunnableLambda(lambda x: x) - assert classify_node(r) == "runnable" - - -def test_classify_via_extract(simple_graph): - from langgraphics.topology import extract - - topology = extract(simple_graph) - kinds = {n["name"]: n["node_kind"] for n in topology["nodes"]} - - assert kinds["__start__"] is None - assert kinds["__end__"] is None - assert kinds["step_a"] is not None - assert kinds["step_b"] is not None - - -def test_start_end_nodes_have_null_kind(simple_graph): - from langgraphics.topology import extract - - topology = extract(simple_graph) - for node in topology["nodes"]: - if node["node_type"] in ("start", "end"): - assert node["node_kind"] is None - - -def test_regular_nodes_have_kind(simple_graph): - from langgraphics.topology import extract - - topology = extract(simple_graph) - for node in topology["nodes"]: - if node["node_type"] == "node": - assert node["node_kind"] is not None diff --git a/tests/web/depthMap.test.ts b/tests/web/depthMap.test.ts deleted file mode 100644 index ff4e476..0000000 --- a/tests/web/depthMap.test.ts +++ /dev/null @@ -1,216 +0,0 @@ -import {describe, expect, it} from "vitest"; -import {computeDepthMap} from "../../langgraphics-web/src/hooks/useInspectTree"; -import type {GraphMessage} from "../../langgraphics-web/src/types"; - -function topo(nodes: GraphMessage["nodes"], edges: GraphMessage["edges"]): GraphMessage { - return {type: "graph", nodes, edges}; -} - -function n(id: string, type: "start" | "end" | "node" = "node"): GraphMessage["nodes"][0] { - return {id, name: id, node_type: type, node_kind: type === "node" ? "function" : null}; -} - -function e(id: string, source: string, target: string, conditional = false): GraphMessage["edges"][0] { - return {id, source, target, conditional, label: null}; -} - -describe("computeDepthMap", () => { - it("returns empty map when no start node", () => { - const result = computeDepthMap(topo( - [n("A"), n("B")], - [e("e0", "A", "B")], - )); - expect(result.size).toBe(0); - }); - - it("marks start node with isStart and depth 0", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("__end__", "end")], - [e("e0", "__start__", "A"), e("e1", "A", "__end__")], - )); - const start = result.get("__start__"); - expect(start).toBeDefined(); - expect(start!.isStart).toBe(true); - expect(start!.depth).toBe(0); - }); - - it("marks end node with isEnd and depth 0", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("__end__", "end")], - [e("e0", "__start__", "A"), e("e1", "A", "__end__")], - )); - const end = result.get("__end__"); - expect(end).toBeDefined(); - expect(end!.isEnd).toBe(true); - expect(end!.depth).toBe(0); - }); - - it("assigns depth 0 to linear chain nodes", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("B"), n("C"), n("__end__", "end")], - [e("e0", "__start__", "A"), e("e1", "A", "B"), e("e2", "B", "C"), e("e3", "C", "__end__")], - )); - expect(result.get("A")!.depth).toBe(0); - expect(result.get("B")!.depth).toBe(0); - expect(result.get("C")!.depth).toBe(0); - }); - - it("nests cycle nodes under their parent depth", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("decide"), n("A"), n("B"), n("__end__", "end")], - [ - e("e0", "__start__", "decide"), - e("e1", "decide", "A"), - e("e2", "A", "B"), - e("e3", "B", "decide"), - e("e4", "decide", "__end__"), - ], - )); - expect(result.get("decide")!.depth).toBeGreaterThanOrEqual(0); - expect(result.get("A")!.depth).toBeGreaterThan(result.get("decide")!.depth); - expect(result.get("B")!.depth).toBe(0); - }); - - it("keeps non-cycle nodes at depth 0 even if adjacent to cycle", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("loop_entry"), n("loop_body"), n("exit"), n("__end__", "end")], - [ - e("e0", "__start__", "loop_entry"), - e("e1", "loop_entry", "loop_body"), - e("e2", "loop_body", "loop_entry"), - e("e3", "loop_entry", "exit"), - e("e4", "exit", "__end__"), - ], - )); - expect(result.get("exit")!.depth).toBe(0); - }); - - it("preserves node_kind from topology", () => { - const result = computeDepthMap(topo( - [ - n("__start__", "start"), - {id: "A", name: "A", node_type: "node", node_kind: "llm"}, - {id: "B", name: "B", node_type: "node", node_kind: "tool"}, - n("__end__", "end"), - ], - [e("e0", "__start__", "A"), e("e1", "A", "B"), e("e2", "B", "__end__")], - )); - expect(result.get("A")!.kind).toBe("llm"); - expect(result.get("B")!.kind).toBe("tool"); - }); - - it("sets kind to null for start and end nodes", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("__end__", "end")], - [e("e0", "__start__", "A"), e("e1", "A", "__end__")], - )); - expect(result.get("__start__")!.kind).toBeNull(); - expect(result.get("__end__")!.kind).toBeNull(); - }); - - it("includes only node-type nodes in result (not start/end as regular)", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("__end__", "end")], - [e("e0", "__start__", "A"), e("e1", "A", "__end__")], - )); - expect(result.size).toBe(3); - expect(result.has("A")).toBe(true); - }); - - it("handles two-node cycle", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("B"), n("__end__", "end")], - [ - e("e0", "__start__", "A"), - e("e1", "A", "B"), - e("e2", "B", "A"), - e("e3", "A", "__end__"), - ], - )); - expect(result.has("A")).toBe(true); - expect(result.has("B")).toBe(true); - expect(result.get("B")!.depth).toBe(0); - }); - - it("handles diamond graph without cycles", () => { - const result = computeDepthMap(topo( - [n("__start__", "start"), n("A"), n("B"), n("C"), n("D"), n("__end__", "end")], - [ - e("e0", "__start__", "A"), - e("e1", "A", "B"), - e("e2", "A", "C"), - e("e3", "B", "D"), - e("e4", "C", "D"), - e("e5", "D", "__end__"), - ], - )); - expect(result.get("A")!.depth).toBe(0); - expect(result.get("B")!.depth).toBe(0); - expect(result.get("C")!.depth).toBe(0); - expect(result.get("D")!.depth).toBe(0); - }); - - it("handles complex cycle with conditional exit", () => { - const result = computeDepthMap(topo( - [ - n("__start__", "start"), - n("plan"), n("select_tool"), n("call_tool"), n("check"), n("finalize"), - n("__end__", "end"), - ], - [ - e("e0", "__start__", "plan"), - e("e1", "plan", "select_tool"), - e("e2", "select_tool", "call_tool"), - e("e3", "call_tool", "check"), - e("e4", "check", "select_tool"), - e("e5", "check", "finalize"), - e("e6", "finalize", "__end__"), - ], - )); - expect(result.get("finalize")!.depth).toBe(0); - expect(result.get("select_tool")!.depth).toBeGreaterThan(0); - expect(result.get("call_tool")!.depth).toBeGreaterThan(0); - expect(result.get("check")!.depth).toBe(0); - }); - - it("replicates reactmini_agent network and matches inspect tree depths", () => { - const nodes: GraphMessage["nodes"] = [ - n("__start__", "start"), - n("plan"), - n("select_tool"), - {id: "call_tool", name: "call_tool", node_type: "node", node_kind: "tool"}, - n("reflect"), - n("revise_plan"), - {id: "check_progress", name: "check_progress", node_type: "node", node_kind: "retriever"}, - n("integrate"), - n("final_answer"), - n("__end__", "end"), - ]; - const edges: GraphMessage["edges"] = [ - e("e0", "__start__", "plan"), - e("e1", "plan", "select_tool"), - e("e2", "select_tool", "call_tool"), - e("e3", "call_tool", "check_progress"), - e("e4", "check_progress", "select_tool", true), - e("e5", "check_progress", "integrate", true), - e("e6", "check_progress", "reflect", true), - e("e7", "reflect", "revise_plan"), - e("e8", "revise_plan", "check_progress"), - e("e9", "integrate", "final_answer"), - e("e10", "final_answer", "__end__"), - ]; - const result = computeDepthMap(topo(nodes, edges)); - expect(result.get("__start__")!.depth).toBe(0); - expect(result.get("__end__")!.depth).toBe(0); - expect(result.get("plan")!.depth).toBe(0); - expect(result.get("select_tool")!.depth).toBe(1); - expect(result.get("call_tool")!.depth).toBe(1); - expect(result.get("check_progress")!.depth).toBe(0); - expect(result.get("reflect")!.depth).toBe(1); - expect(result.get("revise_plan")!.depth).toBe(1); - expect(result.get("integrate")!.depth).toBe(0); - expect(result.get("final_answer")!.depth).toBe(0); - expect(result.get("call_tool")!.kind).toBe("tool"); - expect(result.get("check_progress")!.kind).toBe("retriever"); - }); -});