Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 88 additions & 50 deletions app/(chat)/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import {
type ResumableStreamContext,
} from "resumable-stream";
import { auth, type UserType } from "@/app/(auth)/auth";
import type { VisibilityType } from "@/components/visibility-selector";
import { entitlementsByUserType } from "@/lib/ai/entitlements";
import type { ChatModel } from "@/lib/ai/models";
import { type RequestHints, systemPrompt } from "@/lib/ai/prompts";
import { getLanguageModel } from "@/lib/ai/providers";
import { createDocument } from "@/lib/ai/tools/create-document";
Expand All @@ -32,6 +30,7 @@ import {
saveChat,
saveMessages,
updateChatTitleById,
updateMessage,
} from "@/lib/db/queries";
import type { DBMessage } from "@/lib/db/schema";
import { ChatSDKError } from "@/lib/errors";
Expand Down Expand Up @@ -75,17 +74,8 @@ export async function POST(request: Request) {
}

try {
const {
id,
message,
selectedChatModel,
selectedVisibilityType,
}: {
id: string;
message: ChatMessage;
selectedChatModel: ChatModel["id"];
selectedVisibilityType: VisibilityType;
} = requestBody;
const { id, message, messages, selectedChatModel, selectedVisibilityType } =
requestBody;

const session = await auth();

Expand All @@ -104,6 +94,9 @@ export async function POST(request: Request) {
return new ChatSDKError("rate_limit:chat").toResponse();
}

// Check if this is a tool approval flow (all messages sent)
const isToolApprovalFlow = Boolean(messages);

const chat = await getChatById({ id });
let messagesFromDb: DBMessage[] = [];
let titlePromise: Promise<string> | null = null;
Expand All @@ -112,9 +105,11 @@ export async function POST(request: Request) {
if (chat.userId !== session.user.id) {
return new ChatSDKError("forbidden:chat").toResponse();
}
// Only fetch messages if chat already exists
messagesFromDb = await getMessagesByChatId({ id });
} else {
// Only fetch messages if chat already exists and not tool approval
if (!isToolApprovalFlow) {
messagesFromDb = await getMessagesByChatId({ id });
}
} else if (message?.role === "user") {
// Save chat immediately with placeholder title
await saveChat({
id,
Expand All @@ -127,7 +122,10 @@ export async function POST(request: Request) {
titlePromise = generateTitleFromUserMessage({ message });
}

const uiMessages = [...convertToUIMessages(messagesFromDb), message];
// Use all messages for tool approval, otherwise DB messages + new message
const uiMessages = isToolApprovalFlow
? (messages as ChatMessage[])
: [...convertToUIMessages(messagesFromDb), message as ChatMessage];

const { longitude, latitude, city, country } = geolocation(request);

Expand All @@ -138,24 +136,29 @@ export async function POST(request: Request) {
country,
};

await saveMessages({
messages: [
{
chatId: id,
id: message.id,
role: "user",
parts: message.parts,
attachments: [],
createdAt: new Date(),
},
],
});
// Only save user messages to the database (not tool approval responses)
if (message?.role === "user") {
await saveMessages({
messages: [
{
chatId: id,
id: message.id,
role: "user",
parts: message.parts,
attachments: [],
createdAt: new Date(),
},
],
});
}

const streamId = generateUUID();
await createStreamId({ streamId, chatId: id });

const stream = createUIMessageStream({
execute: ({ writer: dataStream }) => {
// Pass original messages for tool approval continuation
originalMessages: isToolApprovalFlow ? uiMessages : undefined,
execute: async ({ writer: dataStream }) => {
// Handle title generation in parallel
if (titlePromise) {
titlePromise.then((title) => {
Expand All @@ -171,7 +174,7 @@ export async function POST(request: Request) {
const result = streamText({
model: getLanguageModel(selectedChatModel),
system: systemPrompt({ selectedChatModel, requestHints }),
messages: convertToModelMessages(uiMessages),
messages: await convertToModelMessages(uiMessages),
stopWhen: stepCountIs(5),
experimental_activeTools: isReasoningModel
? []
Expand Down Expand Up @@ -215,32 +218,67 @@ export async function POST(request: Request) {
);
},
generateId: generateUUID,
onFinish: async ({ messages }) => {
await saveMessages({
messages: messages.map((currentMessage) => ({
id: currentMessage.id,
role: currentMessage.role,
parts: currentMessage.parts,
createdAt: new Date(),
attachments: [],
chatId: id,
})),
});
onFinish: async ({ messages: finishedMessages }) => {
if (isToolApprovalFlow) {
// For tool approval, update existing messages (tool state changed) and save new ones
for (const finishedMsg of finishedMessages) {
const existingMsg = uiMessages.find((m) => m.id === finishedMsg.id);
if (existingMsg) {
// Update existing message with new parts (tool state changed)
await updateMessage({
id: finishedMsg.id,
parts: finishedMsg.parts,
});
} else {
// Save new message
await saveMessages({
messages: [
{
id: finishedMsg.id,
role: finishedMsg.role,
parts: finishedMsg.parts,
createdAt: new Date(),
attachments: [],
chatId: id,
},
],
});
}
}
} else if (finishedMessages.length > 0) {
// Normal flow - save all finished messages
await saveMessages({
messages: finishedMessages.map((currentMessage) => ({
id: currentMessage.id,
role: currentMessage.role,
parts: currentMessage.parts,
createdAt: new Date(),
attachments: [],
chatId: id,
})),
});
}
},
onError: () => {
return "Oops, an error occurred!";
},
});

// const streamContext = getStreamContext();
const streamContext = getStreamContext();

// if (streamContext) {
// return new Response(
// await streamContext.resumableStream(streamId, () =>
// stream.pipeThrough(new JsonToSseTransformStream())
// )
// );
// }
if (streamContext) {
try {
const resumableStream = await streamContext.resumableStream(
streamId,
() => stream.pipeThrough(new JsonToSseTransformStream())
);
if (resumableStream) {
return new Response(resumableStream);
}
} catch (error) {
console.error("Failed to create resumable stream:", error);
}
}

return new Response(stream.pipeThrough(new JsonToSseTransformStream()));
} catch (error) {
Expand Down
21 changes: 16 additions & 5 deletions app/(chat)/api/chat/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@ const filePartSchema = z.object({

const partSchema = z.union([textPartSchema, filePartSchema]);

const userMessageSchema = z.object({
id: z.string().uuid(),
role: z.enum(["user"]),
parts: z.array(partSchema),
});

// For tool approval flows, we accept all messages (more permissive schema)
const messageSchema = z.object({
id: z.string(),
role: z.string(),
parts: z.array(z.any()),
});

export const postRequestBodySchema = z.object({
id: z.string().uuid(),
message: z.object({
id: z.string().uuid(),
role: z.enum(["user"]),
parts: z.array(partSchema),
}),
// Either a single new message or all messages (for tool approvals)
message: userMessageSchema.optional(),
messages: z.array(messageSchema).optional(),
selectedChatModel: z.string(),
selectedVisibilityType: z.enum(["public", "private"]),
});
Expand Down
6 changes: 0 additions & 6 deletions components/ai-elements/confirmation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ export const ConfirmationRequest = ({ children }: ConfirmationRequestProps) => {
const { state } = useConfirmation();

// Only show when approval is requested
// @ts-expect-error state only available in AI SDK v6
if (state !== "approval-requested") {
return null;
}
Expand All @@ -117,9 +116,7 @@ export const ConfirmationAccepted = ({
// Only show when approved and in response states
if (
!approval?.approved ||
// @ts-expect-error state only available in AI SDK v6
(state !== "approval-responded" &&
// @ts-expect-error state only available in AI SDK v6
state !== "output-denied" &&
state !== "output-available")
) {
Expand All @@ -141,9 +138,7 @@ export const ConfirmationRejected = ({
// Only show when rejected and in response states
if (
approval?.approved !== false ||
// @ts-expect-error state only available in AI SDK v6
(state !== "approval-responded" &&
// @ts-expect-error state only available in AI SDK v6
state !== "output-denied" &&
state !== "output-available")
) {
Expand All @@ -162,7 +157,6 @@ export const ConfirmationActions = ({
const { state } = useConfirmation();

// Only show when approval is requested
// @ts-expect-error state only available in AI SDK v6
if (state !== "approval-requested") {
return null;
}
Expand Down
2 changes: 0 additions & 2 deletions components/ai-elements/tool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
const labels: Record<ToolUIPart["state"], string> = {
"input-streaming": "Pending",
"input-available": "Running",
// @ts-expect-error state only available in AI SDK v6
"approval-requested": "Awaiting Approval",
"approval-responded": "Responded",
"output-available": "Completed",
Expand All @@ -51,7 +50,6 @@ const getStatusBadge = (status: ToolUIPart["state"]) => {
const icons: Record<ToolUIPart["state"], ReactNode> = {
"input-streaming": <CircleIcon className="size-4" />,
"input-available": <ClockIcon className="size-4 animate-pulse" />,
// @ts-expect-error state only available in AI SDK v6
"approval-requested": <ClockIcon className="size-4 text-yellow-600" />,
"approval-responded": <CheckCircleIcon className="size-4 text-blue-600" />,
"output-available": <CheckCircleIcon className="size-4 text-green-600" />,
Expand Down
10 changes: 9 additions & 1 deletion components/artifact-messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type { UIArtifact } from "./artifact";
import { PreviewMessage, ThinkingMessage } from "./message";

type ArtifactMessagesProps = {
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
chatId: string;
status: UseChatHelpers<ChatMessage>["status"];
votes: Vote[] | undefined;
Expand All @@ -20,6 +21,7 @@ type ArtifactMessagesProps = {
};

function PureArtifactMessages({
addToolApprovalResponse,
chatId,
status,
votes,
Expand All @@ -45,6 +47,7 @@ function PureArtifactMessages({
>
{messages.map((message, index) => (
<PreviewMessage
addToolApprovalResponse={addToolApprovalResponse}
chatId={chatId}
isLoading={status === "streaming" && index === messages.length - 1}
isReadonly={isReadonly}
Expand All @@ -64,7 +67,12 @@ function PureArtifactMessages({
))}

<AnimatePresence mode="wait">
{status === "submitted" && <ThinkingMessage key="thinking" />}
{status === "submitted" &&
!messages.some((msg) =>
msg.parts?.some(
(part) => "state" in part && part.state === "approval-responded"
)
) && <ThinkingMessage key="thinking" />}
</AnimatePresence>

<motion.div
Expand Down
3 changes: 3 additions & 0 deletions components/artifact.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export type UIArtifact = {
};

function PureArtifact({
addToolApprovalResponse,
chatId,
input,
setInput,
Expand All @@ -69,6 +70,7 @@ function PureArtifact({
selectedVisibilityType,
selectedModelId,
}: {
addToolApprovalResponse: UseChatHelpers<ChatMessage>["addToolApprovalResponse"];
chatId: string;
input: string;
setInput: Dispatch<SetStateAction<string>>;
Expand Down Expand Up @@ -320,6 +322,7 @@ function PureArtifact({

<div className="flex h-full flex-col items-center justify-between">
<ArtifactMessages
addToolApprovalResponse={addToolApprovalResponse}
artifactStatus={artifact.status}
chatId={chatId}
isReadonly={isReadonly}
Expand Down
Loading