diff --git a/components/messages.tsx b/components/messages.tsx index dbcb546907..875ddc5dc3 100644 --- a/components/messages.tsx +++ b/components/messages.tsx @@ -2,7 +2,7 @@ import type { UIMessage } from 'ai'; import { PreviewMessage, ThinkingMessage } from './message'; import { useScrollToBottom } from './use-scroll-to-bottom'; import { Greeting } from './greeting'; -import { memo } from 'react'; +import { memo, useEffect } from 'react'; import type { Vote } from '@/lib/db/schema'; import equal from 'fast-deep-equal'; import type { UseChatHelpers } from '@ai-sdk/react'; @@ -27,12 +27,19 @@ function PureMessages({ reload, isReadonly, }: MessagesProps) { - const [messagesContainerRef, messagesEndRef] = + const { containerRef, endRef, scrollToBottom } = useScrollToBottom(); + useEffect(() => { + if (messages.length <= 0) return; + if (messages[messages.length - 1].role === 'user') { + scrollToBottom(); + } + }, [messages, scrollToBottom]); + return (
{messages.length === 0 && } @@ -58,10 +65,7 @@ function PureMessages({ messages.length > 0 && messages[messages.length - 1].role === 'user' && } -
+
); } diff --git a/components/use-scroll-to-bottom.ts b/components/use-scroll-to-bottom.ts index e45a8a5c73..1b6ced03a2 100644 --- a/components/use-scroll-to-bottom.ts +++ b/components/use-scroll-to-bottom.ts @@ -1,31 +1,58 @@ -import { useEffect, useRef, type RefObject } from 'react'; +import { useEffect, useRef } from 'react'; -export function useScrollToBottom(): [ - RefObject, - RefObject, -] { +export function useScrollToBottom() { const containerRef = useRef(null); const endRef = useRef(null); + const shouldScrollRef = useRef(true); + + const scrollToBottom = () => { + if (endRef.current) { + endRef.current.scrollIntoView({ + behavior: 'instant', + block: 'end', + }); + } + }; useEffect(() => { const container = containerRef.current; const end = endRef.current; if (container && end) { - const observer = new MutationObserver(() => { - end.scrollIntoView({ behavior: 'instant', block: 'end' }); + const intersectionObserver = new IntersectionObserver( + ([entry]) => { + shouldScrollRef.current = entry.isIntersecting; + }, + { threshold: 0 }, + ); + + intersectionObserver.observe(end); + + const mutationObserver = new MutationObserver(() => { + if (shouldScrollRef.current) { + scrollToBottom(); + } }); - observer.observe(container, { + mutationObserver.observe(container, { childList: true, subtree: true, attributes: true, characterData: true, }); - return () => observer.disconnect(); + return () => { + intersectionObserver.disconnect(); + mutationObserver.disconnect(); + }; } }, []); - return [containerRef, endRef]; + return { + containerRef, + endRef, + scrollToBottom, + }; } + +export type UseScrollToBottomReturn = ReturnType;