Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/frontend/src/assets/plumber-logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
127 changes: 127 additions & 0 deletions packages/frontend/src/hooks/useChatStream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import { useCallback, useMemo } from 'react'
import type { UIMessage } from '@ai-sdk/react'
import { useChat } from '@ai-sdk/react'
import { useToast } from '@opengovsg/design-system-react'
import { DefaultChatTransport } from 'ai'

export interface Message {
text: string
traceId?: string
generationId?: string
isUser: boolean
}

// Custom message type with metadata
type CustomUIMessage = UIMessage<{
traceId?: string
}>

export function useChatStream() {
const toast = useToast()

const {
messages: aiMessages,
sendMessage,
status,
error: aiError,
stop,
} = useChat<CustomUIMessage>({
transport: new DefaultChatTransport({
api: '/api/chat',
credentials: 'include',
prepareSendMessagesRequest: ({ messages }) => {
// Send all messages to maintain conversation context
const body = {
messages: messages,
sessionId: '',
}
return { body }
},
}),
onError: (error: Error) => {
toast({
title: 'Error: ' + error.message,
status: 'error',
duration: 3000,
isClosable: true,
position: 'top',
})
},
})

// Helper function to extract text content from UIMessage
const extractTextContent = useCallback((msg: CustomUIMessage): string => {
return msg.parts
.filter((part) => part.type === 'text')
.map((part) => (part as any).text)
.join('')
}, [])

// Transform AI SDK messages to our Message format
const messages = useMemo<Message[]>(() => {
const isActivelyStreaming = status === 'streaming' || status === 'submitted'

// Filter user and assistant messages
let messagesToTransform = aiMessages.filter(
(msg) => msg.role === 'user' || msg.role === 'assistant',
)

// If streaming, exclude the last assistant message (shown via currentResponse)
if (isActivelyStreaming && messagesToTransform.length > 0) {
const lastMsg = messagesToTransform[messagesToTransform.length - 1]
if (lastMsg.role === 'assistant') {
messagesToTransform = messagesToTransform.slice(0, -1)
}
}

return messagesToTransform.map((msg) => {
// Extract traceId from message metadata
const traceId = msg.metadata?.traceId

return {
text: extractTextContent(msg),
isUser: msg.role === 'user',
traceId: traceId,
}
})
}, [aiMessages, extractTextContent, status])

// Get the current streaming response (last assistant message that's still being streamed)
const currentResponse = useMemo(() => {
const isActivelyStreaming = status === 'streaming' || status === 'submitted'

if (!isActivelyStreaming) {
return ''
}

const lastMessage = aiMessages[aiMessages.length - 1]
if (lastMessage && lastMessage.role === 'assistant') {
return extractTextContent(lastMessage)
}
return ''
}, [aiMessages, status, extractTextContent])

// Wrapper for sendMessage that matches the expected signature
const sendMessageWrapper = useCallback(
(userPrompt: string) => {
sendMessage({
role: 'user',
parts: [{ type: 'text', text: userPrompt }],
})
},
[sendMessage],
)

const cancelStream = useCallback(() => {
stop()
}, [stop])

return {
messages,
currentResponse,
isStreaming: status === 'submitted' || status === 'streaming',
error: aiError?.message || null,
sendMessage: sendMessageWrapper,
cancelStream,
}
}
10 changes: 9 additions & 1 deletion packages/frontend/src/pages/AiBuilder/AiBuilderContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { datadogRum } from '@datadog/browser-rum'
import { useIsMobile } from '@opengovsg/design-system-react'

import PrimarySpinner from '@/components/PrimarySpinner'
import { LaunchDarklyContext } from '@/contexts/LaunchDarkly'
import { GET_APPS } from '@/graphql/queries/get-apps'
import { getStepGroupTypeAndCaption, getStepStructure } from '@/helpers/toolbox'

Expand All @@ -30,16 +31,18 @@ interface AiBuilderStep extends IStep {

interface AIBuilderContextValue extends AIBuilderSharedProps {
allApps: IApp[]
isMobile: boolean
triggerStep: IStep | null
steps: AiBuilderStep[]
isMobile: boolean
actionSteps: IStep[]
stepsBeforeGroup: IStep[]
groupedSteps: IStep[][]
stepGroupType: string | null
stepGroupCaption: string | null
// DataDog RUM Session ID so we can associate the trace with the RUM
ddSessionId: string
// TODO(kevinkim-ogp): remove this once A/B test is complete
aiBuilderType: string
}

const AiBuilderContext = createContext<AIBuilderContextValue | undefined>(
Expand Down Expand Up @@ -69,6 +72,9 @@ export const AiBuilderContextProvider = ({
}: AiBuilderContextProviderProps) => {
const isMobile = useIsMobile()
const ddSessionId = datadogRum.getInternalContext()?.session_id ?? ''
// TODO(kevinkim-ogp): remove this once A/B test is complete
const { getFlagValue } = useContext(LaunchDarklyContext)
const aiBuilderType = getFlagValue('ai-builder-type', 'none')

const { data: getAppsData, loading: isLoadingAllApps } = useQuery(GET_APPS)

Expand Down Expand Up @@ -126,6 +132,8 @@ export const AiBuilderContextProvider = ({
stepGroupType,
stepGroupCaption,
ddSessionId,
// TODO(kevinkim-ogp): remove this once A/B test is complete
aiBuilderType,
}}
>
{children}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import {
type FormEvent,
type KeyboardEvent,
type SyntheticEvent,
useRef,
useState,
} from 'react'
import { FaArrowCircleRight } from 'react-icons/fa'
import { FaCircleStop } from 'react-icons/fa6'
import { Box, Flex, Icon, Text, Textarea } from '@chakra-ui/react'

import pairLogo from '@/assets/pair-logo.svg'
import { ImageBox } from '@/components/FlowStepConfigurationModal/ChooseAndAddConnection/ConfigureExcelConnection'
import { AI_CHAT_IDEAS, AiChatIdea, AiFormIdea } from '@/pages/Flows/constants'

import IdeaButtons from '../IdeaButtons'

interface PromptInputProps {
isStreaming: boolean
showIdeas?: boolean
placeholder?: string
sendMessage: (message: string) => void
cancelStream: () => void
}

export default function PromptInput({
isStreaming,
showIdeas = false,
placeholder = 'Send a message',
sendMessage,
cancelStream,
}: PromptInputProps) {
const [input, setInput] = useState<string>('')
const textareaRef = useRef<HTMLTextAreaElement>(null)

const handleSubmit = async (e: SyntheticEvent) => {
e.preventDefault()
if (input?.trim()) {
sendMessage(input)
setInput('')
if (textareaRef.current) {
textareaRef.current.style.height = 'auto'
}
}
}

const handleKeyPress = (e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault()
handleSubmit(e)
}
}

const handleResize = (e?: FormEvent<HTMLTextAreaElement>) => {
const target = e?.currentTarget || textareaRef.current
if (!target) {
return
}
const maxHeight = window.innerHeight * 0.4 - 100 // 40vh minus padding/margins
target.style.height = 'auto'
target.style.height = Math.min(target.scrollHeight, maxHeight) + 'px'
}

return (
<Box w="full" maxW="4xl">
<Flex
direction="row"
align="stretch"
bg="white"
border="1px"
borderColor="gray.200"
borderRadius="16px"
boxShadow="0 2px 4px rgba(0, 0, 0, 0.1)"
p={2}
w="full"
minH={showIdeas ? '120px' : '50px'}
height="auto"
mb={6}
>
<Textarea
ref={textareaRef}
disabled={isStreaming}
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={handleKeyPress}
placeholder={placeholder}
w="full"
resize="none"
border="none"
bg="transparent"
p={3}
color="gray.900"
_placeholder={{ color: 'gray.500' }}
_focus={{ outline: 'none', boxShadow: 'none' }}
_disabled={{ opacity: 1, bg: 'transparent', color: 'gray.900' }}
fontSize="base"
lineHeight="6"
maxH="calc(40vh - 100px)"
rows={1}
overflowY="auto"
sx={{
'&::-webkit-scrollbar': {
width: '8px',
},
'&::-webkit-scrollbar-track': {
background: 'transparent',
},
'&::-webkit-scrollbar-thumb': {
backgroundColor: 'rgba(0, 0, 0, 0.2)',
borderRadius: '4px',
},
'&::-webkit-scrollbar-thumb:hover': {
backgroundColor: 'rgba(0, 0, 0, 0.3)',
},
}}
onInput={handleResize}
onFocus={(e) => {
// prevent iOS Safari from zooming in on input focus
e.currentTarget.style.fontSize = '16px'
}}
/>

<Flex justify="end" align="flex-end" p={3}>
{isStreaming ? (
<Icon
as={FaCircleStop}
fontSize="24px"
color="red.500"
cursor="pointer"
onClick={cancelStream}
_hover={{ color: 'red.600' }}
/>
) : (
<Icon
as={FaArrowCircleRight}
fontSize="24px"
color={
input?.trim()
? 'primary.500'
: 'interaction.support.disabled-content'
}
onClick={handleSubmit}
cursor={input?.trim() ? 'pointer' : 'default'}
/>
)}
</Flex>
</Flex>

{showIdeas && (
<IdeaButtons
ideas={AI_CHAT_IDEAS}
onClick={(idea: AiChatIdea | AiFormIdea) => {
setInput((idea as AiChatIdea).input)
// trigger resize after state update
setTimeout(() => handleResize(), 0)
}}
/>
)}

<Flex gap={1} alignItems="center" justify="center" mt={3}>
<Text fontSize="xs" color="gray.500">
Powered by{' '}
</Text>
<ImageBox imageUrl={pairLogo} boxSize={6} />
</Flex>
</Box>
)
}
Loading