Skip to content

Commit 4d3ba8d

Browse files
authored
🎄 merry christmas: ai sdk v6 beta + tool approval (#1361)
1 parent 6e5b883 commit 4d3ba8d

File tree

21 files changed

+429
-202
lines changed

21 files changed

+429
-202
lines changed

app/(chat)/api/chat/route.ts

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ import {
1313
type ResumableStreamContext,
1414
} from "resumable-stream";
1515
import { auth, type UserType } from "@/app/(auth)/auth";
16-
import type { VisibilityType } from "@/components/visibility-selector";
1716
import { entitlementsByUserType } from "@/lib/ai/entitlements";
18-
import type { ChatModel } from "@/lib/ai/models";
1917
import { type RequestHints, systemPrompt } from "@/lib/ai/prompts";
2018
import { getLanguageModel } from "@/lib/ai/providers";
2119
import { createDocument } from "@/lib/ai/tools/create-document";
@@ -32,6 +30,7 @@ import {
3230
saveChat,
3331
saveMessages,
3432
updateChatTitleById,
33+
updateMessage,
3534
} from "@/lib/db/queries";
3635
import type { DBMessage } from "@/lib/db/schema";
3736
import { ChatSDKError } from "@/lib/errors";
@@ -75,17 +74,8 @@ export async function POST(request: Request) {
7574
}
7675

7776
try {
78-
const {
79-
id,
80-
message,
81-
selectedChatModel,
82-
selectedVisibilityType,
83-
}: {
84-
id: string;
85-
message: ChatMessage;
86-
selectedChatModel: ChatModel["id"];
87-
selectedVisibilityType: VisibilityType;
88-
} = requestBody;
77+
const { id, message, messages, selectedChatModel, selectedVisibilityType } =
78+
requestBody;
8979

9080
const session = await auth();
9181

@@ -104,6 +94,9 @@ export async function POST(request: Request) {
10494
return new ChatSDKError("rate_limit:chat").toResponse();
10595
}
10696

97+
// Check if this is a tool approval flow (all messages sent)
98+
const isToolApprovalFlow = Boolean(messages);
99+
107100
const chat = await getChatById({ id });
108101
let messagesFromDb: DBMessage[] = [];
109102
let titlePromise: Promise<string> | null = null;
@@ -112,9 +105,11 @@ export async function POST(request: Request) {
112105
if (chat.userId !== session.user.id) {
113106
return new ChatSDKError("forbidden:chat").toResponse();
114107
}
115-
// Only fetch messages if chat already exists
116-
messagesFromDb = await getMessagesByChatId({ id });
117-
} else {
108+
// Only fetch messages if chat already exists and not tool approval
109+
if (!isToolApprovalFlow) {
110+
messagesFromDb = await getMessagesByChatId({ id });
111+
}
112+
} else if (message?.role === "user") {
118113
// Save chat immediately with placeholder title
119114
await saveChat({
120115
id,
@@ -127,7 +122,10 @@ export async function POST(request: Request) {
127122
titlePromise = generateTitleFromUserMessage({ message });
128123
}
129124

130-
const uiMessages = [...convertToUIMessages(messagesFromDb), message];
125+
// Use all messages for tool approval, otherwise DB messages + new message
126+
const uiMessages = isToolApprovalFlow
127+
? (messages as ChatMessage[])
128+
: [...convertToUIMessages(messagesFromDb), message as ChatMessage];
131129

132130
const { longitude, latitude, city, country } = geolocation(request);
133131

@@ -138,24 +136,29 @@ export async function POST(request: Request) {
138136
country,
139137
};
140138

141-
await saveMessages({
142-
messages: [
143-
{
144-
chatId: id,
145-
id: message.id,
146-
role: "user",
147-
parts: message.parts,
148-
attachments: [],
149-
createdAt: new Date(),
150-
},
151-
],
152-
});
139+
// Only save user messages to the database (not tool approval responses)
140+
if (message?.role === "user") {
141+
await saveMessages({
142+
messages: [
143+
{
144+
chatId: id,
145+
id: message.id,
146+
role: "user",
147+
parts: message.parts,
148+
attachments: [],
149+
createdAt: new Date(),
150+
},
151+
],
152+
});
153+
}
153154

154155
const streamId = generateUUID();
155156
await createStreamId({ streamId, chatId: id });
156157

157158
const stream = createUIMessageStream({
158-
execute: ({ writer: dataStream }) => {
159+
// Pass original messages for tool approval continuation
160+
originalMessages: isToolApprovalFlow ? uiMessages : undefined,
161+
execute: async ({ writer: dataStream }) => {
159162
// Handle title generation in parallel
160163
if (titlePromise) {
161164
titlePromise.then((title) => {
@@ -171,7 +174,7 @@ export async function POST(request: Request) {
171174
const result = streamText({
172175
model: getLanguageModel(selectedChatModel),
173176
system: systemPrompt({ selectedChatModel, requestHints }),
174-
messages: convertToModelMessages(uiMessages),
177+
messages: await convertToModelMessages(uiMessages),
175178
stopWhen: stepCountIs(5),
176179
experimental_activeTools: isReasoningModel
177180
? []
@@ -215,32 +218,67 @@ export async function POST(request: Request) {
215218
);
216219
},
217220
generateId: generateUUID,
218-
onFinish: async ({ messages }) => {
219-
await saveMessages({
220-
messages: messages.map((currentMessage) => ({
221-
id: currentMessage.id,
222-
role: currentMessage.role,
223-
parts: currentMessage.parts,
224-
createdAt: new Date(),
225-
attachments: [],
226-
chatId: id,
227-
})),
228-
});
221+
onFinish: async ({ messages: finishedMessages }) => {
222+
if (isToolApprovalFlow) {
223+
// For tool approval, update existing messages (tool state changed) and save new ones
224+
for (const finishedMsg of finishedMessages) {
225+
const existingMsg = uiMessages.find((m) => m.id === finishedMsg.id);
226+
if (existingMsg) {
227+
// Update existing message with new parts (tool state changed)
228+
await updateMessage({
229+
id: finishedMsg.id,
230+
parts: finishedMsg.parts,
231+
});
232+
} else {
233+
// Save new message
234+
await saveMessages({
235+
messages: [
236+
{
237+
id: finishedMsg.id,
238+
role: finishedMsg.role,
239+
parts: finishedMsg.parts,
240+
createdAt: new Date(),
241+
attachments: [],
242+
chatId: id,
243+
},
244+
],
245+
});
246+
}
247+
}
248+
} else if (finishedMessages.length > 0) {
249+
// Normal flow - save all finished messages
250+
await saveMessages({
251+
messages: finishedMessages.map((currentMessage) => ({
252+
id: currentMessage.id,
253+
role: currentMessage.role,
254+
parts: currentMessage.parts,
255+
createdAt: new Date(),
256+
attachments: [],
257+
chatId: id,
258+
})),
259+
});
260+
}
229261
},
230262
onError: () => {
231263
return "Oops, an error occurred!";
232264
},
233265
});
234266

235-
// const streamContext = getStreamContext();
267+
const streamContext = getStreamContext();
236268

237-
// if (streamContext) {
238-
// return new Response(
239-
// await streamContext.resumableStream(streamId, () =>
240-
// stream.pipeThrough(new JsonToSseTransformStream())
241-
// )
242-
// );
243-
// }
269+
if (streamContext) {
270+
try {
271+
const resumableStream = await streamContext.resumableStream(
272+
streamId,
273+
() => stream.pipeThrough(new JsonToSseTransformStream())
274+
);
275+
if (resumableStream) {
276+
return new Response(resumableStream);
277+
}
278+
} catch (error) {
279+
console.error("Failed to create resumable stream:", error);
280+
}
281+
}
244282

245283
return new Response(stream.pipeThrough(new JsonToSseTransformStream()));
246284
} catch (error) {

app/(chat)/api/chat/schema.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,24 @@ const filePartSchema = z.object({
1414

1515
const partSchema = z.union([textPartSchema, filePartSchema]);
1616

17+
const userMessageSchema = z.object({
18+
id: z.string().uuid(),
19+
role: z.enum(["user"]),
20+
parts: z.array(partSchema),
21+
});
22+
23+
// For tool approval flows, we accept all messages (more permissive schema)
24+
const messageSchema = z.object({
25+
id: z.string(),
26+
role: z.string(),
27+
parts: z.array(z.any()),
28+
});
29+
1730
export const postRequestBodySchema = z.object({
1831
id: z.string().uuid(),
19-
message: z.object({
20-
id: z.string().uuid(),
21-
role: z.enum(["user"]),
22-
parts: z.array(partSchema),
23-
}),
32+
// Either a single new message or all messages (for tool approvals)
33+
message: userMessageSchema.optional(),
34+
messages: z.array(messageSchema).optional(),
2435
selectedChatModel: z.string(),
2536
selectedVisibilityType: z.enum(["public", "private"]),
2637
});

components/ai-elements/confirmation.tsx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ export const ConfirmationRequest = ({ children }: ConfirmationRequestProps) => {
9797
const { state } = useConfirmation();
9898

9999
// Only show when approval is requested
100-
// @ts-expect-error state only available in AI SDK v6
101100
if (state !== "approval-requested") {
102101
return null;
103102
}
@@ -117,9 +116,7 @@ export const ConfirmationAccepted = ({
117116
// Only show when approved and in response states
118117
if (
119118
!approval?.approved ||
120-
// @ts-expect-error state only available in AI SDK v6
121119
(state !== "approval-responded" &&
122-
// @ts-expect-error state only available in AI SDK v6
123120
state !== "output-denied" &&
124121
state !== "output-available")
125122
) {
@@ -141,9 +138,7 @@ export const ConfirmationRejected = ({
141138
// Only show when rejected and in response states
142139
if (
143140
approval?.approved !== false ||
144-
// @ts-expect-error state only available in AI SDK v6
145141
(state !== "approval-responded" &&
146-
// @ts-expect-error state only available in AI SDK v6
147142
state !== "output-denied" &&
148143
state !== "output-available")
149144
) {
@@ -162,7 +157,6 @@ export const ConfirmationActions = ({
162157
const { state } = useConfirmation();
163158

164159
// Only show when approval is requested
165-
// @ts-expect-error state only available in AI SDK v6
166160
if (state !== "approval-requested") {
167161
return null;
168162
}

components/ai-elements/tool.tsx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
4040
const labels: Record<ToolUIPart["state"], string> = {
4141
"input-streaming": "Pending",
4242
"input-available": "Running",
43-
// @ts-expect-error state only available in AI SDK v6
4443
"approval-requested": "Awaiting Approval",
4544
"approval-responded": "Responded",
4645
"output-available": "Completed",
@@ -51,7 +50,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
5150
const icons: Record<ToolUIPart["state"], ReactNode> = {
5251
"input-streaming": <CircleIcon className="size-4" />,
5352
"input-available": <ClockIcon className="size-4 animate-pulse" />,
54-
// @ts-expect-error state only available in AI SDK v6
5553
"approval-requested": <ClockIcon className="size-4 text-yellow-600" />,
5654
"approval-responded": <CheckCircleIcon className="size-4 text-blue-600" />,
5755
"output-available": <CheckCircleIcon className="size-4 text-green-600" />,

components/artifact-messages.tsx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import type { UIArtifact } from "./artifact";
99
import { PreviewMessage, ThinkingMessage } from "./message";
1010

1111
type ArtifactMessagesProps = {
12+
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
1213
chatId: string;
1314
status: UseChatHelpers<ChatMessage>["status"];
1415
votes: Vote[] | undefined;
@@ -20,6 +21,7 @@ type ArtifactMessagesProps = {
2021
};
2122

2223
function PureArtifactMessages({
24+
addToolApprovalResponse,
2325
chatId,
2426
status,
2527
votes,
@@ -45,6 +47,7 @@ function PureArtifactMessages({
4547
>
4648
{messages.map((message, index) => (
4749
<PreviewMessage
50+
addToolApprovalResponse={addToolApprovalResponse}
4851
chatId={chatId}
4952
isLoading={status === "streaming" && index === messages.length - 1}
5053
isReadonly={isReadonly}
@@ -64,7 +67,12 @@ function PureArtifactMessages({
6467
))}
6568

6669
<AnimatePresence mode="wait">
67-
{status === "submitted" && <ThinkingMessage key="thinking" />}
70+
{status === "submitted" &&
71+
!messages.some((msg) =>
72+
msg.parts?.some(
73+
(part) => "state" in part && part.state === "approval-responded"
74+
)
75+
) && <ThinkingMessage key="thinking" />}
6876
</AnimatePresence>
6977

7078
<motion.div

components/artifact.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ export type UIArtifact = {
5353
};
5454

5555
function PureArtifact({
56+
addToolApprovalResponse,
5657
chatId,
5758
input,
5859
setInput,
@@ -69,6 +70,7 @@ function PureArtifact({
6970
selectedVisibilityType,
7071
selectedModelId,
7172
}: {
73+
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
7274
chatId: string;
7375
input: string;
7476
setInput: Dispatch<SetStateAction<string>>;
@@ -320,6 +322,7 @@ function PureArtifact({
320322

321323
<div className="flex h-full flex-col items-center justify-between">
322324
<ArtifactMessages
325+
addToolApprovalResponse={addToolApprovalResponse}
323326
artifactStatus={artifact.status}
324327
chatId={chatId}
325328
isReadonly={isReadonly}

0 commit comments

Comments
 (0)