11import { useCallback , useMemo } from 'react'
2+ import { useLocation , useNavigate } from 'react-router-dom'
23import type { UIMessage } from '@ai-sdk/react'
34import { useChat } from '@ai-sdk/react'
45import { useToast } from '@opengovsg/design-system-react'
56import { DefaultChatTransport } from 'ai'
67
8+ import * as URLS from '@/config/urls'
9+ import {
10+ deduplicateMessages ,
11+ extractTextContent ,
12+ transformMessages ,
13+ } from '@/pages/AiBuilder/helpers'
14+
715export interface Message {
16+ id : string // this is auto-generated by the AI SDK
817 text : string
9- traceId ?: string
10- generationId ?: string
18+ traceId ?: string // only assistant messages have this
1119 isUser : boolean
1220}
1321
1422// Custom message type with metadata
15- type CustomUIMessage = UIMessage < {
23+ export type CustomUIMessage = UIMessage < {
1624 traceId ?: string
1725} >
1826
@@ -22,6 +30,8 @@ export interface UseChatStreamOptions {
2230
2331export function useChatStream ( options ?: UseChatStreamOptions ) {
2432 const toast = useToast ( )
33+ const navigate = useNavigate ( )
34+ const location = useLocation ( )
2535
2636 const {
2737 messages : aiMessages ,
@@ -34,9 +44,19 @@ export function useChatStream(options?: UseChatStreamOptions) {
3444 api : '/api/chat' ,
3545 credentials : 'include' ,
3646 prepareSendMessagesRequest : ( { messages } ) => {
37- // Send all messages to maintain conversation context
47+ // Convert initialMessages to the format expected by the API
48+ const initialMsgs = ( options ?. initialMessages || [ ] ) . map ( ( msg ) => ( {
49+ id : msg . id ,
50+ role : msg . isUser ? 'user' : 'assistant' ,
51+ parts : [ { type : 'text' , text : msg . text } ] ,
52+ ...( msg . traceId && { metadata : { traceId : msg . traceId } } ) ,
53+ } ) )
54+
55+ // Prepend initial messages to maintain full conversation context
56+ const allMessages = [ ...initialMsgs , ...messages ]
57+
3858 const body = {
39- messages : messages ,
59+ messages : allMessages ,
4060 sessionId : '' ,
4161 }
4262 return { body }
@@ -51,16 +71,29 @@ export function useChatStream(options?: UseChatStreamOptions) {
5171 position : 'top' ,
5272 } )
5373 } ,
74+ onFinish : ( { messages } ) => {
75+ // transform the messages and save to location state
76+ // so that user can still access it if they refresh the page
77+ const transformedMessages = transformMessages ( messages )
78+
79+ // Combine initial messages with new messages to preserve full history
80+ const allMessages = deduplicateMessages ( [
81+ ...( options ?. initialMessages || [ ] ) ,
82+ ...transformedMessages ,
83+ ] )
84+
85+ navigate ( `${ URLS . EDITOR } /ai` , {
86+ state : {
87+ ...location . state ,
88+ isFormMode : false ,
89+ chatInput : allMessages [ allMessages . length - 1 ] . text ,
90+ chatMessages : allMessages ,
91+ } ,
92+ replace : true ,
93+ } )
94+ } ,
5495 } )
5596
56- // Helper function to extract text content from UIMessage
57- const extractTextContent = useCallback ( ( msg : CustomUIMessage ) : string => {
58- return msg . parts
59- . filter ( ( part ) => part . type === 'text' )
60- . map ( ( part ) => ( part as any ) . text )
61- . join ( '' )
62- } , [ ] )
63-
6497 // Transform AI SDK messages to our Message format
6598 const messages = useMemo < Message [ ] > ( ( ) => {
6699 const isActivelyStreaming = status === 'streaming' || status === 'submitted'
@@ -81,20 +114,13 @@ export function useChatStream(options?: UseChatStreamOptions) {
81114 }
82115 }
83116
84- const transformedMessages = messagesToTransform . map ( ( msg ) => {
85- // Extract traceId from message metadata
86- const traceId = msg . metadata ?. traceId
87-
88- return {
89- text : extractTextContent ( msg ) ,
90- isUser : msg . role === 'user' ,
91- traceId : traceId ,
92- }
93- } )
94-
95- // Combine initial messages with new messages
96- return [ ...initialMsgs , ...transformedMessages ]
97- } , [ aiMessages , extractTextContent , status , options ?. initialMessages ] )
117+ const transformedMessages = transformMessages ( messagesToTransform )
118+ const allMessages = deduplicateMessages ( [
119+ ...initialMsgs ,
120+ ...transformedMessages ,
121+ ] )
122+ return allMessages
123+ } , [ aiMessages , options ?. initialMessages , status ] )
98124
99125 // Get the current streaming response (last assistant message that's still being streamed)
100126 const currentResponse = useMemo ( ( ) => {
@@ -109,7 +135,7 @@ export function useChatStream(options?: UseChatStreamOptions) {
109135 return extractTextContent ( lastMessage )
110136 }
111137 return ''
112- } , [ aiMessages , status , extractTextContent ] )
138+ } , [ aiMessages , status ] )
113139
114140 // Wrapper for sendMessage that matches the expected signature
115141 const sendMessageWrapper = useCallback (
@@ -122,16 +148,12 @@ export function useChatStream(options?: UseChatStreamOptions) {
122148 [ sendMessage ] ,
123149 )
124150
125- const cancelStream = useCallback ( ( ) => {
126- stop ( )
127- } , [ stop ] )
128-
129151 return {
130152 messages,
131153 currentResponse,
132154 isStreaming : status === 'submitted' || status === 'streaming' ,
133155 error : aiError ?. message || null ,
134156 sendMessage : sendMessageWrapper ,
135- cancelStream,
157+ cancelStream : stop ,
136158 }
137159}
0 commit comments