Skip to content

Commit 64fb25a

Browse files
committed
refactor: preserve messages on refresh
1 parent 734986d commit 64fb25a

File tree

2 files changed

+88
-33
lines changed

2 files changed

+88
-33
lines changed

packages/frontend/src/hooks/useChatStream.ts

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
import { useCallback, useMemo } from 'react'
2+
import { useLocation, useNavigate } from 'react-router-dom'
23
import type { UIMessage } from '@ai-sdk/react'
34
import { useChat } from '@ai-sdk/react'
45
import { useToast } from '@opengovsg/design-system-react'
56
import { DefaultChatTransport } from 'ai'
67

8+
import * as URLS from '@/config/urls'
9+
import {
10+
deduplicateMessages,
11+
extractTextContent,
12+
transformMessages,
13+
} from '@/pages/AiBuilder/helpers'
14+
715
export interface Message {
16+
id: string // this is auto-generated by the AI SDK
817
text: string
9-
traceId?: string
10-
generationId?: string
18+
traceId?: string // only assistant messages have this
1119
isUser: boolean
1220
}
1321

1422
// Custom message type with metadata
15-
type CustomUIMessage = UIMessage<{
23+
export type CustomUIMessage = UIMessage<{
1624
traceId?: string
1725
}>
1826

@@ -22,6 +30,8 @@ export interface UseChatStreamOptions {
2230

2331
export function useChatStream(options?: UseChatStreamOptions) {
2432
const toast = useToast()
33+
const navigate = useNavigate()
34+
const location = useLocation()
2535

2636
const {
2737
messages: aiMessages,
@@ -34,9 +44,19 @@ export function useChatStream(options?: UseChatStreamOptions) {
3444
api: '/api/chat',
3545
credentials: 'include',
3646
prepareSendMessagesRequest: ({ messages }) => {
37-
// Send all messages to maintain conversation context
47+
// Convert initialMessages to the format expected by the API
48+
const initialMsgs = (options?.initialMessages || []).map((msg) => ({
49+
id: msg.id,
50+
role: msg.isUser ? 'user' : 'assistant',
51+
parts: [{ type: 'text', text: msg.text }],
52+
...(msg.traceId && { metadata: { traceId: msg.traceId } }),
53+
}))
54+
55+
// Prepend initial messages to maintain full conversation context
56+
const allMessages = [...initialMsgs, ...messages]
57+
3858
const body = {
39-
messages: messages,
59+
messages: allMessages,
4060
sessionId: '',
4161
}
4262
return { body }
@@ -51,16 +71,29 @@ export function useChatStream(options?: UseChatStreamOptions) {
5171
position: 'top',
5272
})
5373
},
74+
onFinish: ({ messages }) => {
75+
// transform the messages and save to location state
76+
// so that user can still access it if they refresh the page
77+
const transformedMessages = transformMessages(messages)
78+
79+
// Combine initial messages with new messages to preserve full history
80+
const allMessages = deduplicateMessages([
81+
...(options?.initialMessages || []),
82+
...transformedMessages,
83+
])
84+
85+
navigate(`${URLS.EDITOR}/ai`, {
86+
state: {
87+
...location.state,
88+
isFormMode: false,
89+
chatInput: allMessages[allMessages.length - 1].text,
90+
chatMessages: allMessages,
91+
},
92+
replace: true,
93+
})
94+
},
5495
})
5596

56-
// Helper function to extract text content from UIMessage
57-
const extractTextContent = useCallback((msg: CustomUIMessage): string => {
58-
return msg.parts
59-
.filter((part) => part.type === 'text')
60-
.map((part) => (part as any).text)
61-
.join('')
62-
}, [])
63-
6497
// Transform AI SDK messages to our Message format
6598
const messages = useMemo<Message[]>(() => {
6699
const isActivelyStreaming = status === 'streaming' || status === 'submitted'
@@ -81,20 +114,13 @@ export function useChatStream(options?: UseChatStreamOptions) {
81114
}
82115
}
83116

84-
const transformedMessages = messagesToTransform.map((msg) => {
85-
// Extract traceId from message metadata
86-
const traceId = msg.metadata?.traceId
87-
88-
return {
89-
text: extractTextContent(msg),
90-
isUser: msg.role === 'user',
91-
traceId: traceId,
92-
}
93-
})
94-
95-
// Combine initial messages with new messages
96-
return [...initialMsgs, ...transformedMessages]
97-
}, [aiMessages, extractTextContent, status, options?.initialMessages])
117+
const transformedMessages = transformMessages(messagesToTransform)
118+
const allMessages = deduplicateMessages([
119+
...initialMsgs,
120+
...transformedMessages,
121+
])
122+
return allMessages
123+
}, [aiMessages, options?.initialMessages, status])
98124

99125
// Get the current streaming response (last assistant message that's still being streamed)
100126
const currentResponse = useMemo(() => {
@@ -109,7 +135,7 @@ export function useChatStream(options?: UseChatStreamOptions) {
109135
return extractTextContent(lastMessage)
110136
}
111137
return ''
112-
}, [aiMessages, status, extractTextContent])
138+
}, [aiMessages, status])
113139

114140
// Wrapper for sendMessage that matches the expected signature
115141
const sendMessageWrapper = useCallback(
@@ -122,16 +148,12 @@ export function useChatStream(options?: UseChatStreamOptions) {
122148
[sendMessage],
123149
)
124150

125-
const cancelStream = useCallback(() => {
126-
stop()
127-
}, [stop])
128-
129151
return {
130152
messages,
131153
currentResponse,
132154
isStreaming: status === 'submitted' || status === 'streaming',
133155
error: aiError?.message || null,
134156
sendMessage: sendMessageWrapper,
135-
cancelStream,
157+
cancelStream: stop,
136158
}
137159
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,39 @@
1+
import { CustomUIMessage, Message } from '@/hooks/useChatStream'
2+
13
export const getPromptFromFormInput = (formInput: {
24
trigger: string
35
actions: string
46
}) => {
57
return `#### Start the workflow\n${formInput.trigger}\n\n#### Actions\n${formInput.actions}`
68
}
9+
10+
// deduplicate messages by id
11+
// there may be duplicates when the messages are combined
12+
export const deduplicateMessages = (messages: Message[]) => {
13+
const seen = new Set()
14+
return messages.filter((msg) => {
15+
if (!msg.id || !seen.has(msg.id)) {
16+
if (msg.id) {
17+
seen.add(msg.id)
18+
return true
19+
}
20+
}
21+
})
22+
}
23+
24+
// Helper function to extract text content from UIMessage
25+
export const extractTextContent = (msg: CustomUIMessage): string => {
26+
return msg.parts
27+
.filter((part) => part.type === 'text')
28+
.map((part) => (part as any).text)
29+
.join('')
30+
}
31+
32+
export const transformMessages = (messages: CustomUIMessage[]) => {
33+
return messages.map((msg) => ({
34+
id: msg.id,
35+
text: extractTextContent(msg),
36+
isUser: msg.role === 'user',
37+
traceId: msg.metadata?.traceId,
38+
}))
39+
}

0 commit comments

Comments
 (0)