@@ -14,7 +14,6 @@ import {
1414} from "resumable-stream" ;
1515import type { VisibilityType } from "@/components/visibility-selector" ;
1616import { entitlementsByUserType } from "@/lib/ai/entitlements" ;
17- import type { ChatModel } from "@/lib/ai/models" ;
1817import { type RequestHints , systemPrompt } from "@/lib/ai/prompts" ;
1918import { getLanguageModel } from "@/lib/ai/providers" ;
2019import { createDocument } from "@/lib/ai/tools/create-document" ;
@@ -32,6 +31,7 @@ import {
3231 saveChat ,
3332 saveMessages ,
3433 updateChatTitleById ,
34+ updateMessage ,
3535} from "@/lib/db/queries" ;
3636import type { DBMessage } from "@/lib/db/schema" ;
3737import { ChatSDKError } from "@/lib/errors" ;
@@ -78,13 +78,9 @@ export async function POST(request: Request) {
7878 const {
7979 id,
8080 message,
81+ messages,
8182 selectedChatModel,
8283 selectedVisibilityType,
83- } : {
84- id : string ;
85- message : ChatMessage ;
86- selectedChatModel : ChatModel [ "id" ] ;
87- selectedVisibilityType : VisibilityType ;
8884 } = requestBody ;
8985
9086 const session = await getSession ( ) ;
@@ -104,6 +100,10 @@ export async function POST(request: Request) {
104100 return new ChatSDKError ( "rate_limit:chat" ) . toResponse ( ) ;
105101 }
106102
103+ const isToolApprovalFlow = Boolean ( messages ) ;
104+ if ( ! isToolApprovalFlow && ! message ) {
105+ return new ChatSDKError ( "bad_request:api" ) . toResponse ( ) ;
106+ }
107107 const chat = await getChatById ( { id } ) ;
108108 let messagesFromDb : DBMessage [ ] = [ ] ;
109109 let titlePromise : Promise < string > | null = null ;
@@ -112,23 +112,25 @@ export async function POST(request: Request) {
112112 if ( chat . userId !== session . user . id ) {
113113 return new ChatSDKError ( "forbidden:chat" ) . toResponse ( ) ;
114114 }
115- // Only fetch messages if chat already exists
116- messagesFromDb = await getMessagesByChatId ( { id } ) ;
117- } else {
118- // Save chat immediately with a placeholder title
115+ if ( ! isToolApprovalFlow ) {
116+ messagesFromDb = await getMessagesByChatId ( { id } ) ;
117+ }
118+ } else if ( message ?. role === "user" ) {
119119 await saveChat ( {
120120 id,
121121 userId : session . user . id ,
122122 title : "New chat" ,
123123 visibility : selectedVisibilityType ,
124124 } ) ;
125125
126- // Kick off title generation in the background
127126 titlePromise = generateTitleFromUserMessage ( { message } ) ;
128- // New chat - no need to fetch messages, it's empty
127+ } else {
128+ return new ChatSDKError ( "bad_request:chat" ) . toResponse ( ) ;
129129 }
130130
131- const uiMessages = [ ...convertToUIMessages ( messagesFromDb ) , message ] ;
131+ const uiMessages : ChatMessage [ ] = isToolApprovalFlow
132+ ? ( messages as ChatMessage [ ] )
133+ : [ ...convertToUIMessages ( messagesFromDb ) , message as ChatMessage ] ;
132134
133135 const { longitude, latitude, city, country } = geolocation ( request ) ;
134136
@@ -139,24 +141,27 @@ export async function POST(request: Request) {
139141 country,
140142 } ;
141143
142- await saveMessages ( {
143- messages : [
144- {
145- chatId : id ,
146- id : message . id ,
147- role : "user" ,
148- parts : message . parts ,
149- attachments : [ ] ,
150- createdAt : new Date ( ) ,
151- } ,
152- ] ,
153- } ) ;
144+ if ( message ?. role === "user" ) {
145+ await saveMessages ( {
146+ messages : [
147+ {
148+ chatId : id ,
149+ id : message . id ,
150+ role : "user" ,
151+ parts : message . parts ,
152+ attachments : [ ] ,
153+ createdAt : new Date ( ) ,
154+ } ,
155+ ] ,
156+ } ) ;
157+ }
154158
155159 const streamId = generateUUID ( ) ;
156160 await createStreamId ( { streamId, chatId : id } ) ;
157161
158162 const stream = createUIMessageStream ( {
159- execute : ( { writer : dataStream } ) => {
163+ originalMessages : isToolApprovalFlow ? uiMessages : undefined ,
164+ execute : async ( { writer : dataStream } ) => {
160165 // Handle async title generation in parallel for new chats
161166 if ( titlePromise ) {
162167 titlePromise . then ( ( title ) => {
@@ -172,7 +177,7 @@ export async function POST(request: Request) {
172177 const result = streamText ( {
173178 model : getLanguageModel ( selectedChatModel ) ,
174179 system : systemPrompt ( { selectedChatModel, requestHints } ) ,
175- messages : convertToModelMessages ( uiMessages ) ,
180+ messages : await convertToModelMessages ( uiMessages ) ,
176181 stopWhen : stepCountIs ( 5 ) ,
177182 experimental_activeTools : isReasoningModel
178183 ? [ ]
@@ -216,32 +221,63 @@ export async function POST(request: Request) {
216221 ) ;
217222 } ,
218223 generateId : generateUUID ,
219- onFinish : async ( { messages } ) => {
220- await saveMessages ( {
221- messages : messages . map ( ( currentMessage ) => ( {
222- id : currentMessage . id ,
223- role : currentMessage . role ,
224- parts : currentMessage . parts ,
225- createdAt : new Date ( ) ,
226- attachments : [ ] ,
227- chatId : id ,
228- } ) ) ,
229- } ) ;
224+ onFinish : async ( { messages : finishedMessages } ) => {
225+ if ( isToolApprovalFlow ) {
226+ for ( const finishedMsg of finishedMessages ) {
227+ const existingMsg = uiMessages . find ( ( m ) => m . id === finishedMsg . id ) ;
228+ if ( existingMsg ) {
229+ await updateMessage ( {
230+ id : finishedMsg . id ,
231+ parts : finishedMsg . parts ,
232+ } ) ;
233+ } else {
234+ await saveMessages ( {
235+ messages : [
236+ {
237+ id : finishedMsg . id ,
238+ role : finishedMsg . role ,
239+ parts : finishedMsg . parts ,
240+ createdAt : new Date ( ) ,
241+ attachments : [ ] ,
242+ chatId : id ,
243+ } ,
244+ ] ,
245+ } ) ;
246+ }
247+ }
248+ } else if ( finishedMessages . length > 0 ) {
249+ await saveMessages ( {
250+ messages : finishedMessages . map ( ( currentMessage ) => ( {
251+ id : currentMessage . id ,
252+ role : currentMessage . role ,
253+ parts : currentMessage . parts ,
254+ createdAt : new Date ( ) ,
255+ attachments : [ ] ,
256+ chatId : id ,
257+ } ) ) ,
258+ } ) ;
259+ }
230260 } ,
231261 onError : ( ) => {
232262 return "Oops, an error occurred!" ;
233263 } ,
234264 } ) ;
235265
236- // const streamContext = getStreamContext();
266+ const streamContext = getStreamContext ( ) ;
237267
238- // if (streamContext) {
239- // return new Response(
240- // await streamContext.resumableStream(streamId, () =>
241- // stream.pipeThrough(new JsonToSseTransformStream())
242- // )
243- // );
244- // }
268+ if ( streamContext ) {
269+ try {
270+ const resumableStream = await streamContext . resumableStream (
271+ streamId ,
272+ ( ) => stream . pipeThrough ( new JsonToSseTransformStream ( ) )
273+ ) ;
274+ if ( resumableStream ) {
275+ return new Response ( resumableStream ) ;
276+ }
277+ } catch ( error ) {
278+ console . error ( "Failed to create resumable stream:" , error ) ;
279+ }
280+ }
245281
246282 return new Response ( stream . pipeThrough ( new JsonToSseTransformStream ( ) ) ) ;
247283 } catch ( error ) {
0 commit comments