Skip to content

Commit 14d0b06

Browse files
committed
Fix replay MCP app payload handling
Signed-off-by: Andrew Harvard <aharvard@squareup.com>
1 parent fc93cd2 commit 14d0b06

5 files changed

Lines changed: 271 additions & 105 deletions

File tree

ui/goose2/src/shared/api/__tests__/acpNotificationHandler.test.ts

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { waitFor } from "@testing-library/react";
22
import { beforeEach, describe, expect, it } from "vitest";
3+
import {
4+
clearReplayBuffer,
5+
getReplayBuffer,
6+
} from "@/features/chat/hooks/replayBuffer";
37
import { useChatStore } from "@/features/chat/stores/chatStore";
48
import type { McpAppPayload } from "@/shared/types/messages";
59
import {
@@ -30,6 +34,8 @@ function createMcpAppPayload(): McpAppPayload {
3034
describe("acpNotificationHandler", () => {
3135
beforeEach(() => {
3236
clearMessageTracking();
37+
clearReplayBuffer("local-session");
38+
clearReplayBuffer("goose-session");
3339
useChatStore.setState({
3440
messagesBySession: {},
3541
sessionStateById: {},
@@ -142,4 +148,146 @@ describe("acpNotificationHandler", () => {
142148
.streamingMessageId,
143149
).toBe("assistant-1");
144150
});
151+
152+
it("replay keeps tool and MCP app content on an assistant message when tool events arrive before text", async () => {
153+
const replaySessionId = "replay-goose-session";
154+
useChatStore.setState({
155+
loadingSessionIds: new Set<string>([replaySessionId]),
156+
});
157+
158+
await handleSessionNotification({
159+
sessionId: replaySessionId,
160+
update: {
161+
sessionUpdate: "user_message_chunk",
162+
messageId: "user-1",
163+
content: {
164+
type: "text",
165+
text: "run the app bench",
166+
},
167+
},
168+
} as never);
169+
170+
await handleSessionNotification({
171+
sessionId: replaySessionId,
172+
update: {
173+
sessionUpdate: "tool_call",
174+
toolCallId: "tool-1",
175+
title: "mcp_app_bench__inspect_host_info",
176+
},
177+
} as never);
178+
179+
await handleSessionNotification({
180+
sessionId: replaySessionId,
181+
update: {
182+
sessionUpdate: "tool_call_update",
183+
toolCallId: "tool-1",
184+
status: "completed",
185+
content: [
186+
{
187+
type: "content",
188+
content: {
189+
type: "text",
190+
text: "Opened the Host Info inspector.",
191+
},
192+
},
193+
],
194+
_meta: {
195+
goose: {
196+
mcpApp: {
197+
toolName: "mcp_app_bench__inspect_host_info",
198+
extensionName: "mcp_app_bench",
199+
resourceUri: "ui://inspect-host-info",
200+
},
201+
},
202+
},
203+
},
204+
} as never);
205+
206+
await handleSessionNotification({
207+
sessionId: replaySessionId,
208+
update: {
209+
sessionUpdate: "agent_message_chunk",
210+
messageId: "assistant-1",
211+
content: {
212+
type: "text",
213+
text: "The Host Info inspector is now open.",
214+
},
215+
},
216+
} as never);
217+
218+
const buffer = getReplayBuffer(replaySessionId);
219+
expect(buffer).toHaveLength(2);
220+
expect(buffer?.[0]).toMatchObject({
221+
id: "user-1",
222+
role: "user",
223+
content: [{ type: "text", text: "run the app bench" }],
224+
});
225+
expect(
226+
buffer?.[0]?.content.some((block) => block.type === "toolRequest"),
227+
).toBe(false);
228+
229+
expect(buffer?.[1]?.id).toBe("assistant-1");
230+
expect(buffer?.[1]?.role).toBe("assistant");
231+
expect(buffer?.[1]?.content.map((block) => block.type)).toEqual([
232+
"toolRequest",
233+
"toolResponse",
234+
"mcpApp",
235+
"text",
236+
]);
237+
expect(buffer?.[1]?.content[2]).toMatchObject({
238+
type: "mcpApp",
239+
id: "tool-1",
240+
payload: {
241+
...createMcpAppPayload(),
242+
sessionId: replaySessionId,
243+
gooseSessionId: replaySessionId,
244+
},
245+
});
246+
});
247+
248+
it("replay preserves gooseSessionId in MCP app payloads before tracker registration", async () => {
249+
const replaySessionId = "replay-goose-session-2";
250+
useChatStore.setState({
251+
loadingSessionIds: new Set<string>([replaySessionId]),
252+
});
253+
254+
await handleSessionNotification({
255+
sessionId: replaySessionId,
256+
update: {
257+
sessionUpdate: "tool_call",
258+
toolCallId: "tool-1",
259+
title: "mcp_app_bench__inspect_host_info",
260+
},
261+
} as never);
262+
263+
await handleSessionNotification({
264+
sessionId: replaySessionId,
265+
update: {
266+
sessionUpdate: "tool_call_update",
267+
toolCallId: "tool-1",
268+
status: "completed",
269+
_meta: {
270+
goose: {
271+
mcpApp: {
272+
toolName: "mcp_app_bench__inspect_host_info",
273+
extensionName: "mcp_app_bench",
274+
resourceUri: "ui://inspect-host-info",
275+
},
276+
},
277+
},
278+
},
279+
} as never);
280+
281+
const buffer = getReplayBuffer(replaySessionId);
282+
const assistant = buffer?.[0];
283+
const mcpAppBlock = assistant?.content.find(
284+
(block) => block.type === "mcpApp",
285+
);
286+
expect(mcpAppBlock).toMatchObject({
287+
type: "mcpApp",
288+
payload: expect.objectContaining({
289+
gooseSessionId: replaySessionId,
290+
}),
291+
});
292+
});
145293
});

ui/goose2/src/shared/api/acpNotificationHandler.ts

Lines changed: 44 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -17,58 +17,19 @@ import type { AcpNotificationHandler } from "./acpConnection";
1717
import {
1818
attachMcpAppPayload,
1919
extractToolResultText,
20-
findMessageInReplayBuffer,
2120
findReplayMessageWithToolCall,
2221
} from "./acpToolCallContent";
2322
import {
24-
getLocalSessionId,
25-
subscribeToSessionRegistration,
26-
} from "./acpSessionTracker";
23+
clearReplayAssistantMessage,
24+
clearReplayAssistantTracking,
25+
ensureReplayAssistantMessage,
26+
getTrackedReplayAssistantMessageId,
27+
} from "./acpReplayAssistant";
28+
import { getLocalSessionId } from "./acpSessionTracker";
2729
import { perfLog } from "@/shared/lib/perfLog";
2830

2931
// Pre-set message ID for the next live stream per goose session
3032
const presetMessageIds = new Map<string, string>();
31-
const pendingUsageUpdates = new Map<string, SessionUpdate[]>();
32-
33-
function shouldBufferPendingUpdate(update: SessionUpdate): boolean {
34-
return update.sessionUpdate === "usage_update";
35-
}
36-
37-
function queuePendingUsageUpdate(
38-
gooseSessionId: string,
39-
update: SessionUpdate,
40-
): void {
41-
const pending = pendingUsageUpdates.get(gooseSessionId);
42-
if (pending) {
43-
pending.push(update);
44-
return;
45-
}
46-
pendingUsageUpdates.set(gooseSessionId, [update]);
47-
}
48-
49-
function flushPendingUsageUpdates(
50-
localSessionId: string,
51-
gooseSessionId: string,
52-
): void {
53-
const pending = pendingUsageUpdates.get(gooseSessionId);
54-
if (!pending?.length) {
55-
return;
56-
}
57-
58-
pendingUsageUpdates.delete(gooseSessionId);
59-
60-
for (const update of pending) {
61-
if (useChatStore.getState().loadingSessionIds.has(localSessionId)) {
62-
handleReplay(localSessionId, update);
63-
} else {
64-
handleLive(localSessionId, gooseSessionId, update);
65-
}
66-
}
67-
}
68-
69-
subscribeToSessionRegistration((localSessionId, gooseSessionId) => {
70-
flushPendingUsageUpdates(localSessionId, gooseSessionId);
71-
});
7233

7334
// Per-session perf counters for replay/live streaming.
7435
interface ReplayPerf {
@@ -117,32 +78,22 @@ export async function handleSessionNotification(
11778
notification: SessionNotification,
11879
): Promise<void> {
11980
const gooseSessionId = notification.sessionId;
81+
const sessionId = getLocalSessionId(gooseSessionId) ?? gooseSessionId;
12082
const { update } = notification;
121-
const localSessionId = getLocalSessionId(gooseSessionId);
122-
123-
if (!localSessionId) {
124-
if (shouldBufferPendingUpdate(update)) {
125-
queuePendingUsageUpdate(gooseSessionId, update);
126-
}
127-
return;
128-
}
129-
130-
const isReplay = useChatStore
131-
.getState()
132-
.loadingSessionIds.has(localSessionId);
83+
const isReplay = useChatStore.getState().loadingSessionIds.has(sessionId);
13384

13485
if (isReplay) {
135-
const sid = localSessionId.slice(0, 8);
136-
let perf = replayPerf.get(localSessionId);
86+
const sid = sessionId.slice(0, 8);
87+
let perf = replayPerf.get(sessionId);
13788
const now = performance.now();
13889
if (!perf) {
13990
perf = { firstAt: now, lastAt: now, count: 0 };
140-
replayPerf.set(localSessionId, perf);
91+
replayPerf.set(sessionId, perf);
14192
perfLog(`[perf:replay] ${sid} first notification received`);
14293
}
14394
perf.lastAt = now;
14495
perf.count += 1;
145-
handleReplay(localSessionId, update);
96+
handleReplay(sessionId, gooseSessionId, update);
14697
} else {
14798
const perf = livePerf.get(gooseSessionId);
14899
if (perf && update.sessionUpdate === "agent_message_chunk") {
@@ -155,7 +106,7 @@ export async function handleSessionNotification(
155106
);
156107
}
157108
}
158-
handleLive(localSessionId, gooseSessionId, update);
109+
handleLive(sessionId, gooseSessionId, update);
159110
}
160111
}
161112

@@ -171,25 +122,17 @@ export function clearReplayPerf(sessionId: string): void {
171122
replayPerf.delete(sessionId);
172123
}
173124

174-
function handleReplay(sessionId: string, update: SessionUpdate): void {
125+
function handleReplay(
126+
sessionId: string,
127+
gooseSessionId: string,
128+
update: SessionUpdate,
129+
): void {
175130
switch (update.sessionUpdate) {
176131
case "agent_message_chunk": {
177-
const messageId = update.messageId ?? crypto.randomUUID();
178-
const buffer = ensureReplayBuffer(sessionId);
179-
if (!getBufferedMessage(sessionId, messageId)) {
180-
buffer.push({
181-
id: messageId,
182-
role: "assistant",
183-
created: Date.now(),
184-
content: [],
185-
metadata: {
186-
userVisible: true,
187-
agentVisible: true,
188-
completionStatus: "inProgress",
189-
},
190-
});
191-
}
192-
const msg = getBufferedMessage(sessionId, messageId);
132+
const msg = ensureReplayAssistantMessage(
133+
sessionId,
134+
update.messageId ?? null,
135+
);
193136
if (msg && update.content.type === "text" && "text" in update.content) {
194137
const last = msg.content[msg.content.length - 1];
195138
if (last?.type === "text") {
@@ -202,6 +145,7 @@ function handleReplay(sessionId: string, update: SessionUpdate): void {
202145
}
203146

204147
case "user_message_chunk": {
148+
clearReplayAssistantMessage(sessionId);
205149
const messageId = update.messageId ?? crypto.randomUUID();
206150
const buffer = ensureReplayBuffer(sessionId);
207151
const existing = getBufferedMessage(sessionId, messageId);
@@ -233,22 +177,25 @@ function handleReplay(sessionId: string, update: SessionUpdate): void {
233177
}
234178

235179
case "tool_call": {
236-
const msg = findMessageInReplayBuffer(sessionId);
237-
if (msg) {
238-
msg.content.push({
239-
type: "toolRequest",
240-
id: update.toolCallId,
241-
name: update.title,
242-
arguments: {},
243-
status: "executing",
244-
startedAt: Date.now(),
245-
});
246-
}
180+
const msg = ensureReplayAssistantMessage(sessionId);
181+
msg.content.push({
182+
type: "toolRequest",
183+
id: update.toolCallId,
184+
name: update.title,
185+
arguments: {},
186+
status: "executing",
187+
startedAt: Date.now(),
188+
});
247189
break;
248190
}
249191

250192
case "tool_call_update": {
251-
const msg = findReplayMessageWithToolCall(sessionId, update.toolCallId);
193+
const replayMessageId = getTrackedReplayAssistantMessageId(sessionId);
194+
const msg =
195+
findReplayMessageWithToolCall(sessionId, update.toolCallId) ??
196+
(replayMessageId
197+
? getBufferedMessage(sessionId, replayMessageId)
198+
: undefined);
252199
if (msg) {
253200
if (update.title) {
254201
const tc = msg.content.find(
@@ -286,6 +233,10 @@ function handleReplay(sessionId: string, update: SessionUpdate): void {
286233
(tc as ToolRequestContent)?.name ?? update.title ?? "",
287234
update,
288235
true,
236+
{
237+
gooseSessionId,
238+
replayMessageId,
239+
},
289240
);
290241
}
291242
}
@@ -459,6 +410,7 @@ function handleShared(sessionId: string, update: SessionUpdate): void {
459410
currentModelId;
460411

461412
const sessionStore = useChatSessionStore.getState();
413+
sessionStore.setSessionModels(sessionId, availableModels);
462414
sessionStore.updateSession(
463415
sessionId,
464416
{ modelId: currentModelId, modelName: currentModelName },
@@ -533,6 +485,7 @@ function ensureLiveAssistantMessage(
533485

534486
export function clearMessageTracking(): void {
535487
presetMessageIds.clear();
488+
clearReplayAssistantTracking();
536489
}
537490

538491
const handler: AcpNotificationHandler = {

0 commit comments

Comments
 (0)