@@ -13,9 +13,7 @@ import {
1313 type ResumableStreamContext ,
1414} from "resumable-stream" ;
1515import { auth , type UserType } from "@/app/(auth)/auth" ;
16- import type { VisibilityType } from "@/components/visibility-selector" ;
1716import { entitlementsByUserType } from "@/lib/ai/entitlements" ;
18- import type { ChatModel } from "@/lib/ai/models" ;
1917import { type RequestHints , systemPrompt } from "@/lib/ai/prompts" ;
2018import { getLanguageModel } from "@/lib/ai/providers" ;
2119import { createDocument } from "@/lib/ai/tools/create-document" ;
@@ -32,6 +30,7 @@ import {
3230 saveChat ,
3331 saveMessages ,
3432 updateChatTitleById ,
33+ updateMessage ,
3534} from "@/lib/db/queries" ;
3635import type { DBMessage } from "@/lib/db/schema" ;
3736import { ChatSDKError } from "@/lib/errors" ;
@@ -75,17 +74,8 @@ export async function POST(request: Request) {
7574 }
7675
7776 try {
78- const {
79- id,
80- message,
81- selectedChatModel,
82- selectedVisibilityType,
83- } : {
84- id : string ;
85- message : ChatMessage ;
86- selectedChatModel : ChatModel [ "id" ] ;
87- selectedVisibilityType : VisibilityType ;
88- } = requestBody ;
77+ const { id, message, messages, selectedChatModel, selectedVisibilityType } =
78+ requestBody ;
8979
9080 const session = await auth ( ) ;
9181
@@ -104,6 +94,9 @@ export async function POST(request: Request) {
10494 return new ChatSDKError ( "rate_limit:chat" ) . toResponse ( ) ;
10595 }
10696
97+ // Check if this is a tool approval flow (all messages sent)
98+ const isToolApprovalFlow = Boolean ( messages ) ;
99+
107100 const chat = await getChatById ( { id } ) ;
108101 let messagesFromDb : DBMessage [ ] = [ ] ;
109102 let titlePromise : Promise < string > | null = null ;
@@ -112,9 +105,11 @@ export async function POST(request: Request) {
112105 if ( chat . userId !== session . user . id ) {
113106 return new ChatSDKError ( "forbidden:chat" ) . toResponse ( ) ;
114107 }
115- // Only fetch messages if chat already exists
116- messagesFromDb = await getMessagesByChatId ( { id } ) ;
117- } else {
108+ // Only fetch messages if chat already exists and not tool approval
109+ if ( ! isToolApprovalFlow ) {
110+ messagesFromDb = await getMessagesByChatId ( { id } ) ;
111+ }
112+ } else if ( message ?. role === "user" ) {
118113 // Save chat immediately with placeholder title
119114 await saveChat ( {
120115 id,
@@ -127,7 +122,10 @@ export async function POST(request: Request) {
127122 titlePromise = generateTitleFromUserMessage ( { message } ) ;
128123 }
129124
130- const uiMessages = [ ...convertToUIMessages ( messagesFromDb ) , message ] ;
125+ // Use all messages for tool approval, otherwise DB messages + new message
126+ const uiMessages = isToolApprovalFlow
127+ ? ( messages as ChatMessage [ ] )
128+ : [ ...convertToUIMessages ( messagesFromDb ) , message as ChatMessage ] ;
131129
132130 const { longitude, latitude, city, country } = geolocation ( request ) ;
133131
@@ -138,24 +136,29 @@ export async function POST(request: Request) {
138136 country,
139137 } ;
140138
141- await saveMessages ( {
142- messages : [
143- {
144- chatId : id ,
145- id : message . id ,
146- role : "user" ,
147- parts : message . parts ,
148- attachments : [ ] ,
149- createdAt : new Date ( ) ,
150- } ,
151- ] ,
152- } ) ;
139+ // Only save user messages to the database (not tool approval responses)
140+ if ( message ?. role === "user" ) {
141+ await saveMessages ( {
142+ messages : [
143+ {
144+ chatId : id ,
145+ id : message . id ,
146+ role : "user" ,
147+ parts : message . parts ,
148+ attachments : [ ] ,
149+ createdAt : new Date ( ) ,
150+ } ,
151+ ] ,
152+ } ) ;
153+ }
153154
154155 const streamId = generateUUID ( ) ;
155156 await createStreamId ( { streamId, chatId : id } ) ;
156157
157158 const stream = createUIMessageStream ( {
158- execute : ( { writer : dataStream } ) => {
159+ // Pass original messages for tool approval continuation
160+ originalMessages : isToolApprovalFlow ? uiMessages : undefined ,
161+ execute : async ( { writer : dataStream } ) => {
159162 // Handle title generation in parallel
160163 if ( titlePromise ) {
161164 titlePromise . then ( ( title ) => {
@@ -171,7 +174,7 @@ export async function POST(request: Request) {
171174 const result = streamText ( {
172175 model : getLanguageModel ( selectedChatModel ) ,
173176 system : systemPrompt ( { selectedChatModel, requestHints } ) ,
174- messages : convertToModelMessages ( uiMessages ) ,
177+ messages : await convertToModelMessages ( uiMessages ) ,
175178 stopWhen : stepCountIs ( 5 ) ,
176179 experimental_activeTools : isReasoningModel
177180 ? [ ]
@@ -215,32 +218,67 @@ export async function POST(request: Request) {
215218 ) ;
216219 } ,
217220 generateId : generateUUID ,
218- onFinish : async ( { messages } ) => {
219- await saveMessages ( {
220- messages : messages . map ( ( currentMessage ) => ( {
221- id : currentMessage . id ,
222- role : currentMessage . role ,
223- parts : currentMessage . parts ,
224- createdAt : new Date ( ) ,
225- attachments : [ ] ,
226- chatId : id ,
227- } ) ) ,
228- } ) ;
221+ onFinish : async ( { messages : finishedMessages } ) => {
222+ if ( isToolApprovalFlow ) {
223+ // For tool approval, update existing messages (tool state changed) and save new ones
224+ for ( const finishedMsg of finishedMessages ) {
225+ const existingMsg = uiMessages . find ( ( m ) => m . id === finishedMsg . id ) ;
226+ if ( existingMsg ) {
227+ // Update existing message with new parts (tool state changed)
228+ await updateMessage ( {
229+ id : finishedMsg . id ,
230+ parts : finishedMsg . parts ,
231+ } ) ;
232+ } else {
233+ // Save new message
234+ await saveMessages ( {
235+ messages : [
236+ {
237+ id : finishedMsg . id ,
238+ role : finishedMsg . role ,
239+ parts : finishedMsg . parts ,
240+ createdAt : new Date ( ) ,
241+ attachments : [ ] ,
242+ chatId : id ,
243+ } ,
244+ ] ,
245+ } ) ;
246+ }
247+ }
248+ } else if ( finishedMessages . length > 0 ) {
249+ // Normal flow - save all finished messages
250+ await saveMessages ( {
251+ messages : finishedMessages . map ( ( currentMessage ) => ( {
252+ id : currentMessage . id ,
253+ role : currentMessage . role ,
254+ parts : currentMessage . parts ,
255+ createdAt : new Date ( ) ,
256+ attachments : [ ] ,
257+ chatId : id ,
258+ } ) ) ,
259+ } ) ;
260+ }
229261 } ,
230262 onError : ( ) => {
231263 return "Oops, an error occurred!" ;
232264 } ,
233265 } ) ;
234266
235- // const streamContext = getStreamContext();
267+ const streamContext = getStreamContext ( ) ;
236268
237- // if (streamContext) {
238- // return new Response(
239- // await streamContext.resumableStream(streamId, () =>
240- // stream.pipeThrough(new JsonToSseTransformStream())
241- // )
242- // );
243- // }
269+ if ( streamContext ) {
270+ try {
271+ const resumableStream = await streamContext . resumableStream (
272+ streamId ,
273+ ( ) => stream . pipeThrough ( new JsonToSseTransformStream ( ) )
274+ ) ;
275+ if ( resumableStream ) {
276+ return new Response ( resumableStream ) ;
277+ }
278+ } catch ( error ) {
279+ console . error ( "Failed to create resumable stream:" , error ) ;
280+ }
281+ }
244282
245283 return new Response ( stream . pipeThrough ( new JsonToSseTransformStream ( ) ) ) ;
246284 } catch ( error ) {
0 commit comments