Skip to content

Commit 54d11b2

Browse files
committed
chore: upgrade ai sdk v6 + tool approval (vercel#1361)
1 parent c0de3e2 commit 54d11b2

File tree

19 files changed

+426
-211
lines changed

19 files changed

+426
-211
lines changed

app/(chat)/api/chat/route.ts

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import {
1414
} from "resumable-stream";
1515
import type { VisibilityType } from "@/components/visibility-selector";
1616
import { entitlementsByUserType } from "@/lib/ai/entitlements";
17-
import type { ChatModel } from "@/lib/ai/models";
1817
import { type RequestHints, systemPrompt } from "@/lib/ai/prompts";
1918
import { getLanguageModel } from "@/lib/ai/providers";
2019
import { createDocument } from "@/lib/ai/tools/create-document";
@@ -32,6 +31,7 @@ import {
3231
saveChat,
3332
saveMessages,
3433
updateChatTitleById,
34+
updateMessage,
3535
} from "@/lib/db/queries";
3636
import type { DBMessage } from "@/lib/db/schema";
3737
import { ChatSDKError } from "@/lib/errors";
@@ -78,13 +78,9 @@ export async function POST(request: Request) {
7878
const {
7979
id,
8080
message,
81+
messages,
8182
selectedChatModel,
8283
selectedVisibilityType,
83-
}: {
84-
id: string;
85-
message: ChatMessage;
86-
selectedChatModel: ChatModel["id"];
87-
selectedVisibilityType: VisibilityType;
8884
} = requestBody;
8985

9086
const session = await getSession();
@@ -104,6 +100,10 @@ export async function POST(request: Request) {
104100
return new ChatSDKError("rate_limit:chat").toResponse();
105101
}
106102

103+
const isToolApprovalFlow = Boolean(messages);
104+
if (!isToolApprovalFlow && !message) {
105+
return new ChatSDKError("bad_request:api").toResponse();
106+
}
107107
const chat = await getChatById({ id });
108108
let messagesFromDb: DBMessage[] = [];
109109
let titlePromise: Promise<string> | null = null;
@@ -112,23 +112,25 @@ export async function POST(request: Request) {
112112
if (chat.userId !== session.user.id) {
113113
return new ChatSDKError("forbidden:chat").toResponse();
114114
}
115-
// Only fetch messages if chat already exists
116-
messagesFromDb = await getMessagesByChatId({ id });
117-
} else {
118-
// Save chat immediately with a placeholder title
115+
if (!isToolApprovalFlow) {
116+
messagesFromDb = await getMessagesByChatId({ id });
117+
}
118+
} else if (message?.role === "user") {
119119
await saveChat({
120120
id,
121121
userId: session.user.id,
122122
title: "New chat",
123123
visibility: selectedVisibilityType,
124124
});
125125

126-
// Kick off title generation in the background
127126
titlePromise = generateTitleFromUserMessage({ message });
128-
// New chat - no need to fetch messages, it's empty
127+
} else {
128+
return new ChatSDKError("bad_request:chat").toResponse();
129129
}
130130

131-
const uiMessages = [...convertToUIMessages(messagesFromDb), message];
131+
const uiMessages: ChatMessage[] = isToolApprovalFlow
132+
? (messages as ChatMessage[])
133+
: [...convertToUIMessages(messagesFromDb), message as ChatMessage];
132134

133135
const { longitude, latitude, city, country } = geolocation(request);
134136

@@ -139,24 +141,27 @@ export async function POST(request: Request) {
139141
country,
140142
};
141143

142-
await saveMessages({
143-
messages: [
144-
{
145-
chatId: id,
146-
id: message.id,
147-
role: "user",
148-
parts: message.parts,
149-
attachments: [],
150-
createdAt: new Date(),
151-
},
152-
],
153-
});
144+
if (message?.role === "user") {
145+
await saveMessages({
146+
messages: [
147+
{
148+
chatId: id,
149+
id: message.id,
150+
role: "user",
151+
parts: message.parts,
152+
attachments: [],
153+
createdAt: new Date(),
154+
},
155+
],
156+
});
157+
}
154158

155159
const streamId = generateUUID();
156160
await createStreamId({ streamId, chatId: id });
157161

158162
const stream = createUIMessageStream({
159-
execute: ({ writer: dataStream }) => {
163+
originalMessages: isToolApprovalFlow ? uiMessages : undefined,
164+
execute: async ({ writer: dataStream }) => {
160165
// Handle async title generation in parallel for new chats
161166
if (titlePromise) {
162167
titlePromise.then((title) => {
@@ -172,7 +177,7 @@ export async function POST(request: Request) {
172177
const result = streamText({
173178
model: getLanguageModel(selectedChatModel),
174179
system: systemPrompt({ selectedChatModel, requestHints }),
175-
messages: convertToModelMessages(uiMessages),
180+
messages: await convertToModelMessages(uiMessages),
176181
stopWhen: stepCountIs(5),
177182
experimental_activeTools: isReasoningModel
178183
? []
@@ -216,32 +221,63 @@ export async function POST(request: Request) {
216221
);
217222
},
218223
generateId: generateUUID,
219-
onFinish: async ({ messages }) => {
220-
await saveMessages({
221-
messages: messages.map((currentMessage) => ({
222-
id: currentMessage.id,
223-
role: currentMessage.role,
224-
parts: currentMessage.parts,
225-
createdAt: new Date(),
226-
attachments: [],
227-
chatId: id,
228-
})),
229-
});
224+
onFinish: async ({ messages: finishedMessages }) => {
225+
if (isToolApprovalFlow) {
226+
for (const finishedMsg of finishedMessages) {
227+
const existingMsg = uiMessages.find((m) => m.id === finishedMsg.id);
228+
if (existingMsg) {
229+
await updateMessage({
230+
id: finishedMsg.id,
231+
parts: finishedMsg.parts,
232+
});
233+
} else {
234+
await saveMessages({
235+
messages: [
236+
{
237+
id: finishedMsg.id,
238+
role: finishedMsg.role,
239+
parts: finishedMsg.parts,
240+
createdAt: new Date(),
241+
attachments: [],
242+
chatId: id,
243+
},
244+
],
245+
});
246+
}
247+
}
248+
} else if (finishedMessages.length > 0) {
249+
await saveMessages({
250+
messages: finishedMessages.map((currentMessage) => ({
251+
id: currentMessage.id,
252+
role: currentMessage.role,
253+
parts: currentMessage.parts,
254+
createdAt: new Date(),
255+
attachments: [],
256+
chatId: id,
257+
})),
258+
});
259+
}
230260
},
231261
onError: () => {
232262
return "Oops, an error occurred!";
233263
},
234264
});
235265

236-
// const streamContext = getStreamContext();
266+
const streamContext = getStreamContext();
237267

238-
// if (streamContext) {
239-
// return new Response(
240-
// await streamContext.resumableStream(streamId, () =>
241-
// stream.pipeThrough(new JsonToSseTransformStream())
242-
// )
243-
// );
244-
// }
268+
if (streamContext) {
269+
try {
270+
const resumableStream = await streamContext.resumableStream(
271+
streamId,
272+
() => stream.pipeThrough(new JsonToSseTransformStream())
273+
);
274+
if (resumableStream) {
275+
return new Response(resumableStream);
276+
}
277+
} catch (error) {
278+
console.error("Failed to create resumable stream:", error);
279+
}
280+
}
245281

246282
return new Response(stream.pipeThrough(new JsonToSseTransformStream()));
247283
} catch (error) {

app/(chat)/api/chat/schema.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,24 @@ const filePartSchema = z.object({
1414

1515
const partSchema = z.union([textPartSchema, filePartSchema]);
1616

17+
const userMessageSchema = z.object({
18+
id: z.string().uuid(),
19+
role: z.enum(["user"]),
20+
parts: z.array(partSchema),
21+
});
22+
23+
// For tool approval flows, we accept all messages (more permissive schema)
24+
const messageSchema = z.object({
25+
id: z.string(),
26+
role: z.string(),
27+
parts: z.array(z.any()),
28+
});
29+
1730
export const postRequestBodySchema = z.object({
1831
id: z.string().uuid(),
19-
message: z.object({
20-
id: z.string().uuid(),
21-
role: z.enum(["user"]),
22-
parts: z.array(partSchema),
23-
}),
32+
// Either a single new message or all messages (for tool approvals)
33+
message: userMessageSchema.optional(),
34+
messages: z.array(messageSchema).optional(),
2435
selectedChatModel: z.string(),
2536
selectedVisibilityType: z.enum(["public", "private"]),
2637
});

components/ai-elements/confirmation.tsx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ export const ConfirmationRequest = ({ children }: ConfirmationRequestProps) => {
9797
const { state } = useConfirmation();
9898

9999
// Only show when approval is requested
100-
// @ts-expect-error state only available in AI SDK v6
101100
if (state !== "approval-requested") {
102101
return null;
103102
}
@@ -117,9 +116,7 @@ export const ConfirmationAccepted = ({
117116
// Only show when approved and in response states
118117
if (
119118
!approval?.approved ||
120-
// @ts-expect-error state only available in AI SDK v6
121119
(state !== "approval-responded" &&
122-
// @ts-expect-error state only available in AI SDK v6
123120
state !== "output-denied" &&
124121
state !== "output-available")
125122
) {
@@ -141,9 +138,7 @@ export const ConfirmationRejected = ({
141138
// Only show when rejected and in response states
142139
if (
143140
approval?.approved !== false ||
144-
// @ts-expect-error state only available in AI SDK v6
145141
(state !== "approval-responded" &&
146-
// @ts-expect-error state only available in AI SDK v6
147142
state !== "output-denied" &&
148143
state !== "output-available")
149144
) {
@@ -162,7 +157,6 @@ export const ConfirmationActions = ({
162157
const { state } = useConfirmation();
163158

164159
// Only show when approval is requested
165-
// @ts-expect-error state only available in AI SDK v6
166160
if (state !== "approval-requested") {
167161
return null;
168162
}

components/ai-elements/tool.tsx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
4040
const labels: Record<ToolUIPart["state"], string> = {
4141
"input-streaming": "Pending",
4242
"input-available": "Running",
43-
// @ts-expect-error state only available in AI SDK v6
4443
"approval-requested": "Awaiting Approval",
4544
"approval-responded": "Responded",
4645
"output-available": "Completed",
@@ -51,7 +50,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
5150
const icons: Record<ToolUIPart["state"], ReactNode> = {
5251
"input-streaming": <CircleIcon className="size-4" />,
5352
"input-available": <ClockIcon className="size-4 animate-pulse" />,
54-
// @ts-expect-error state only available in AI SDK v6
5553
"approval-requested": <ClockIcon className="size-4 text-yellow-600" />,
5654
"approval-responded": <CheckCircleIcon className="size-4 text-blue-600" />,
5755
"output-available": <CheckCircleIcon className="size-4 text-green-600" />,

components/artifact-messages.tsx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import type { UIArtifact } from "./artifact";
99
import { PreviewMessage, ThinkingMessage } from "./message";
1010

1111
type ArtifactMessagesProps = {
12+
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
1213
chatId: string;
1314
status: UseChatHelpers<ChatMessage>["status"];
1415
votes: Vote[] | undefined;
@@ -20,6 +21,7 @@ type ArtifactMessagesProps = {
2021
};
2122

2223
function PureArtifactMessages({
24+
addToolApprovalResponse,
2325
chatId,
2426
status,
2527
votes,
@@ -45,6 +47,7 @@ function PureArtifactMessages({
4547
>
4648
{messages.map((message, index) => (
4749
<PreviewMessage
50+
addToolApprovalResponse={addToolApprovalResponse}
4851
chatId={chatId}
4952
isLoading={status === "streaming" && index === messages.length - 1}
5053
isReadonly={isReadonly}
@@ -64,7 +67,12 @@ function PureArtifactMessages({
6467
))}
6568

6669
<AnimatePresence mode="wait">
67-
{status === "submitted" && <ThinkingMessage key="thinking" />}
70+
{status === "submitted" &&
71+
!messages.some((msg) =>
72+
msg.parts?.some(
73+
(part) => "state" in part && part.state === "approval-responded"
74+
)
75+
) && <ThinkingMessage key="thinking" />}
6876
</AnimatePresence>
6977

7078
<motion.div

components/artifact.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ export type UIArtifact = {
5353
};
5454

5555
function PureArtifact({
56+
addToolApprovalResponse,
5657
chatId,
5758
input,
5859
setInput,
@@ -69,6 +70,7 @@ function PureArtifact({
6970
selectedVisibilityType,
7071
selectedModelId,
7172
}: {
73+
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
7274
chatId: string;
7375
input: string;
7476
setInput: Dispatch<SetStateAction<string>>;
@@ -320,6 +322,7 @@ function PureArtifact({
320322

321323
<div className="flex h-full flex-col items-center justify-between">
322324
<ArtifactMessages
325+
addToolApprovalResponse={addToolApprovalResponse}
323326
artifactStatus={artifact.status}
324327
chatId={chatId}
325328
isReadonly={isReadonly}

0 commit comments

Comments
 (0)