Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

96 changes: 64 additions & 32 deletions packages/frontend/src/hooks/useChatStream.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
import { useCallback, useMemo } from 'react'
import { useLocation, useNavigate } from 'react-router-dom'
import type { UIMessage } from '@ai-sdk/react'
import { useChat } from '@ai-sdk/react'
import { useToast } from '@opengovsg/design-system-react'
import { DefaultChatTransport } from 'ai'

import * as URLS from '@/config/urls'
import {
deduplicateMessages,
extractTextContent,
transformMessages,
} from '@/pages/AiBuilder/helpers'

export interface Message {
id: string // this is auto-generated by the AI SDK
text: string
traceId?: string
generationId?: string
traceId?: string // only assistant messages have this
isUser: boolean
}

// Custom message type with metadata
type CustomUIMessage = UIMessage<{
export type CustomUIMessage = UIMessage<{
traceId?: string
}>

export function useChatStream() {
export interface UseChatStreamOptions {
initialMessages?: Message[]
}

export function useChatStream(options?: UseChatStreamOptions) {
const toast = useToast()
const navigate = useNavigate()
const location = useLocation()

const {
messages: aiMessages,
Expand All @@ -30,9 +44,19 @@ export function useChatStream() {
api: '/api/chat',
credentials: 'include',
prepareSendMessagesRequest: ({ messages }) => {
// Send all messages to maintain conversation context
// Convert initialMessages to the format expected by the API
const initialMsgs = (options?.initialMessages || []).map((msg) => ({
id: msg.id,
role: msg.isUser ? 'user' : 'assistant',
parts: [{ type: 'text', text: msg.text }],
...(msg.traceId && { metadata: { traceId: msg.traceId } }),
}))

// Prepend initial messages to maintain full conversation context
const allMessages = [...initialMsgs, ...messages]

const body = {
messages: messages,
messages: allMessages,
sessionId: '',
}
return { body }
Expand All @@ -47,21 +71,37 @@ export function useChatStream() {
position: 'top',
})
},
onFinish: ({ messages }) => {
// transform the messages and save to location state
// so that user can still access it if they refresh the page
const transformedMessages = transformMessages(messages)

// Combine initial messages with new messages to preserve full history
const allMessages = deduplicateMessages([
...(options?.initialMessages || []),
...transformedMessages,
])

navigate(`${URLS.EDITOR}/ai`, {
state: {
...location.state,
isFormMode: false,
chatInput: allMessages[allMessages.length - 1].text,
chatMessages: allMessages,
},
replace: true,
})
},
})

// Helper function to extract text content from UIMessage
const extractTextContent = useCallback((msg: CustomUIMessage): string => {
return msg.parts
.filter((part) => part.type === 'text')
.map((part) => (part as any).text)
.join('')
}, [])

// Transform AI SDK messages to our Message format
const messages = useMemo<Message[]>(() => {
const isActivelyStreaming = status === 'streaming' || status === 'submitted'

// Filter user and assistant messages
// Start with initial messages if provided
const initialMsgs = options?.initialMessages || []

// Filter user and assistant messages from AI SDK
let messagesToTransform = aiMessages.filter(
(msg) => msg.role === 'user' || msg.role === 'assistant',
)
Expand All @@ -74,17 +114,13 @@ export function useChatStream() {
}
}

return messagesToTransform.map((msg) => {
// Extract traceId from message metadata
const traceId = msg.metadata?.traceId

return {
text: extractTextContent(msg),
isUser: msg.role === 'user',
traceId: traceId,
}
})
}, [aiMessages, extractTextContent, status])
const transformedMessages = transformMessages(messagesToTransform)
const allMessages = deduplicateMessages([
...initialMsgs,
...transformedMessages,
])
return allMessages
}, [aiMessages, options?.initialMessages, status])

// Get the current streaming response (last assistant message that's still being streamed)
const currentResponse = useMemo(() => {
Expand All @@ -99,7 +135,7 @@ export function useChatStream() {
return extractTextContent(lastMessage)
}
return ''
}, [aiMessages, status, extractTextContent])
}, [aiMessages, status])

// Wrapper for sendMessage that matches the expected signature
const sendMessageWrapper = useCallback(
Expand All @@ -112,16 +148,12 @@ export function useChatStream() {
[sendMessage],
)

const cancelStream = useCallback(() => {
stop()
}, [stop])

return {
messages,
currentResponse,
isStreaming: status === 'submitted' || status === 'streaming',
error: aiError?.message || null,
sendMessage: sendMessageWrapper,
cancelStream,
cancelStream: stop,
}
}
12 changes: 12 additions & 0 deletions packages/frontend/src/pages/AiBuilder/AiBuilderContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import PrimarySpinner from '@/components/PrimarySpinner'
import { LaunchDarklyContext } from '@/contexts/LaunchDarkly'
import { GET_APPS } from '@/graphql/queries/get-apps'
import { getStepGroupTypeAndCaption, getStepStructure } from '@/helpers/toolbox'
import { Message } from '@/hooks/useChatStream'

interface AIBuilderSharedProps {
flowName: string
Expand All @@ -18,6 +19,7 @@ interface AIBuilderSharedProps {
actions: string
}
chatInput: string
chatMessages: Message[]
isFormMode: boolean
output: {
trigger: IStep
Expand Down Expand Up @@ -62,12 +64,21 @@ export const useAiBuilderContext = () => {

interface AiBuilderContextProviderProps extends AIBuilderSharedProps {
children: React.ReactNode
flowName: string
formInput: {
trigger: string
actions: string
}
chatInput: string
chatMessages: Message[]
isFormMode: boolean
}

export const AiBuilderContextProvider = ({
children,
flowName = 'Name your Pipe', // default to Name your Pipe if no flow name is provided
chatInput,
chatMessages,
formInput,
isFormMode,
output,
Expand Down Expand Up @@ -124,6 +135,7 @@ export const AiBuilderContextProvider = ({
flowName,
formInput,
chatInput,
chatMessages,
isFormMode,
output,
isMobile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import { Box, Flex, Icon, Text, Textarea } from '@chakra-ui/react'

import pairLogo from '@/assets/pair-logo.svg'
import { ImageBox } from '@/components/FlowStepConfigurationModal/ChooseAndAddConnection/ConfigureExcelConnection'
import IdeaButtons from '@/pages/AiBuilder/components/IdeaButtons'
import { AI_CHAT_IDEAS, AiChatIdea, AiFormIdea } from '@/pages/Flows/constants'

import IdeaButtons from '../IdeaButtons'

interface PromptInputProps {
isStreaming: boolean
showIdeas?: boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export default function SideDrawer({ isOpen, onClose }: SideDrawerProps) {

{/* Content */}
<Box flex={1} overflowY="auto" pb={4}>
<StepsPreview />
{isOpen && <StepsPreview />}
</Box>
</Flex>
</Box>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import { useEffect, useRef, useState } from 'react'
import { IoChevronDown } from 'react-icons/io5'
import { useNavigate } from 'react-router-dom'
import { useLocation, useNavigate } from 'react-router-dom'
import { Box, Flex, IconButton, Text } from '@chakra-ui/react'
import { useIsMobile } from '@opengovsg/design-system-react'

import { parseWorkflow } from '@/components/AiBuilder/helpers/parseMarkdown'
import * as URLS from '@/config/urls'
import { useChatStream } from '@/hooks/useChatStream'
import { useAiBuilderContext } from '@/pages/AiBuilder/AiBuilderContext'
Expand All @@ -15,11 +14,12 @@ import SideDrawer from './SideDrawer'

export default function ChatInterface() {
const navigate = useNavigate()
const location = useLocation()
const isMobile = useIsMobile()
const { flowName, formInput } = useAiBuilderContext()
const { flowName, chatInput, chatMessages } = useAiBuilderContext()

const { messages, currentResponse, isStreaming, sendMessage, cancelStream } =
useChatStream()
useChatStream({ initialMessages: chatMessages })
const messagesEndRef = useRef<HTMLDivElement>(null)
const messagesContainerRef = useRef<HTMLDivElement>(null)
const [isDrawerOpen, setIsDrawerOpen] = useState(false)
Expand Down Expand Up @@ -67,22 +67,17 @@ export default function ChatInterface() {
const hasMessages = messages.length > 0 || isStreaming

const handleOpenPreview = () => {
const { trigger, actions } = parseWorkflow(
messages[messages.length - 1].text,
)

// NOTE: only need to update the location state if there has been changes
// if the user just closed and open the side drawer, we don't need to update
// as we don't want to generate the ai steps again
if (formInput?.trigger !== trigger || formInput?.actions !== actions) {
if (chatInput !== messages[messages.length - 1].text) {
navigate(`${URLS.EDITOR}/ai`, {
state: {
...location.state,
flowName,
isFormMode: false,
formInput: {
trigger,
actions,
},
chatInput: messages[messages.length - 1].text,
chatMessages: messages,
},
replace: true,
})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react'

import Loader from '@/pages/AiBuilder/components/ChatMessages/Loader'
import { ChakraStreamdown } from '@/theme/components/Streamdown'

import Loader from './Loader'
import PlumberAvatar from './PlumberAvatar'

interface StreamingMessageProps {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Modal, ModalContent, ModalOverlay } from '@chakra-ui/react'

import { AiFormData } from '../../schema'
import { AiFormData } from '@/pages/AiBuilder/schema'

import { AIFormModalContent } from '../AIFormModalContent'

const ModifyPromptModal = ({
Expand Down
33 changes: 33 additions & 0 deletions packages/frontend/src/pages/AiBuilder/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
import { CustomUIMessage, Message } from '@/hooks/useChatStream'

export const getPromptFromFormInput = (formInput: {
trigger: string
actions: string
}) => {
return `#### Start the workflow\n${formInput.trigger}\n\n#### Actions\n${formInput.actions}`
}

// deduplicate messages by id
// there may be duplicates when the messages are combined
export const deduplicateMessages = (messages: Message[]) => {
const seen = new Set()
return messages.filter((msg) => {
if (!msg.id || !seen.has(msg.id)) {
if (msg.id) {
seen.add(msg.id)
return true
}
}
})
}

// Helper function to extract text content from UIMessage
export const extractTextContent = (msg: CustomUIMessage): string => {
return msg.parts
.filter((part) => part.type === 'text')
.map((part) => (part as any).text)
.join('')
}

export const transformMessages = (messages: CustomUIMessage[]) => {
return messages.map((msg) => ({
id: msg.id,
text: extractTextContent(msg),
isUser: msg.role === 'user',
traceId: msg.metadata?.traceId,
}))
}
Loading