From 51835df920ad7a3ffba1720ee2a74c92218e74be Mon Sep 17 00:00:00 2001 From: ragingwind Date: Thu, 24 Apr 2025 07:39:31 +0000 Subject: [PATCH] feat(FR-783): Refresh with baseURL and token (#3494) resolves 3466 (FR-783) Improved Custom Model Integration in Chat Component This PR enhances the custom model integration in the Chat component by: 1. Refactoring the model fetching logic to properly handle errors and display user-friendly messages 2. Adding a dedicated `ChatModelError` class to standardize error handling 3. Creating a new `EndpointTokenSelect` component to allow users to select from available endpoint tokens 4. Improving the `CustomModelForm` to better handle base URL and token inputs 5. Separating the model selection UI logic to only show when models are available 6. Implementing proper state management for base URL and token values --- react/src/components/Chat/ChatCard.tsx | 160 +++++++++++------- react/src/components/Chat/ChatHeader.tsx | 24 +-- react/src/components/Chat/ChatModel.tsx | 30 ++++ react/src/components/Chat/CustomModelForm.tsx | 130 +++++++------- .../components/Chat/EndpointTokenSelect.tsx | 90 ++++++++++ react/src/components/Chat/ModelSelect.tsx | 39 ++--- resources/i18n/de.json | 10 +- resources/i18n/el.json | 10 +- resources/i18n/en.json | 10 +- resources/i18n/es.json | 10 +- resources/i18n/fi.json | 10 +- resources/i18n/fr.json | 10 +- resources/i18n/id.json | 10 +- resources/i18n/it.json | 10 +- resources/i18n/ja.json | 10 +- resources/i18n/ko.json | 10 +- resources/i18n/mn.json | 10 +- resources/i18n/ms.json | 10 +- resources/i18n/pl.json | 10 +- resources/i18n/pt-BR.json | 10 +- resources/i18n/pt.json | 10 +- resources/i18n/ru.json | 10 +- resources/i18n/th.json | 10 +- resources/i18n/tr.json | 10 +- resources/i18n/vi.json | 10 +- resources/i18n/zh-CN.json | 10 +- resources/i18n/zh-TW.json | 10 +- 27 files changed, 495 insertions(+), 188 deletions(-) create mode 100644 react/src/components/Chat/EndpointTokenSelect.tsx diff --git a/react/src/components/Chat/ChatCard.tsx b/react/src/components/Chat/ChatCard.tsx index eef40cb757..969ae6574a 100644 --- a/react/src/components/Chat/ChatCard.tsx +++ b/react/src/components/Chat/ChatCard.tsx @@ -11,20 +11,23 @@ import { ChatType, Model, } from './ChatModel'; -import { CustomModelAlert, CustomModelForm } from './CustomModelForm'; -import { - ChatCard_endpoint$data, - ChatCard_endpoint$key, -} from './__generated__/ChatCard_endpoint.graphql'; +import { CustomModelForm } from './CustomModelForm'; +import { ChatCard_endpoint$key } from './__generated__/ChatCard_endpoint.graphql'; import { createOpenAI } from '@ai-sdk/openai'; import { useChat } from '@ai-sdk/react'; import { extractReasoningMiddleware, streamText, wrapLanguageModel } from 'ai'; -import { Alert, Card, CardProps, FormInstance } from 'antd'; +import { Alert, App, Card, CardProps } from 'antd'; import { createStyles } from 'antd-style'; import graphql from 'babel-plugin-relay/macro'; -import { isEmpty } from 'lodash'; import _ from 'lodash'; -import React, { useEffect, useMemo, useRef, useState } from 'react'; +import React, { + startTransition, + useEffect, + useMemo, + useRef, + useState, + useTransition, +} from 'react'; import { useFragment } from 'react-relay'; interface ChatCardProps extends CardProps, ChatLifecycleEventType { @@ -80,30 +83,68 @@ function useEndpoint(selectedEndpoint?: ChatCard_endpoint$key | null) { return { endpoint, setEndpoint } as const; } +function createModelsURL(baseURL: string) { + const { origin, port, pathname: path } = new URL(baseURL.trim()); + const host = port.length > 0 ? `${origin}:${port}` : origin; + const normalizedPath = path === '/' ? '/models' : `${path}/models`; + + return new URL(normalizedPath, host).toString(); +} + function useModels( provider: ChatProviderType, fetchKey: string, - endpoint?: ChatCard_endpoint$data | null, + baseURL?: string, + token?: string, ) { + const { t } = useTranslation(); + const getModelsErrorMessage = (status?: number) => { + switch (status) { + case 401: + return t('error.UnauthorizedToken'); + case 404: + return t('error.NotFoundBasePath'); + case 500: + return t('error.InternalServerError'); + case 503: + return t('error.ServiceUnavailable'); + default: + return t('error.UnknownError'); + } + }; + const { data: modelsResult } = useSuspenseTanQuery<{ data: Array; }>({ - queryKey: ['models', fetchKey, endpoint?.endpoint_id], - queryFn: () => { - return endpoint?.url - ? fetch( - new URL( - provider.basePath + '/models', - endpoint?.url ?? undefined, - ).toString(), - ) - .then((res) => res.json()) - .catch((e) => ({ data: [] })) - : Promise.resolve({ data: [] }); + queryKey: ['models', fetchKey, baseURL, token], + queryFn: async () => { + try { + if (baseURL) { + const url = createModelsURL(baseURL); + const authToken = token || provider.apiKey; + const res = await fetch(url, { + headers: { + Authorization: authToken ? `Bearer ${authToken}` : '', + }, + }); + + const url = createModelsURL(baseURL); + const authToken = provider.apiKey; + const res = await fetch(url, { + headers: { + Authorization: authToken ? `Bearer ${authToken}` : '', + }, + }); + + if (!res.ok) { + return { data: [], error: res.status }; + } + + return await res.json(); }, }); - const models = _.map(modelsResult?.data, (m) => ({ + const models = _.map(modelsResult?.data || [], (m) => ({ id: m.id, name: m.id, })) as BAIModel[]; @@ -111,7 +152,7 @@ function useModels( const selectedModelId = useMemo( () => provider.modelId && - _.includes(_.map(modelsResult?.data, 'id'), provider.modelId) + _.includes(_.map(modelsResult?.data || [], 'id'), provider.modelId) ? provider.modelId : (modelsResult?.data?.[0]?.id ?? 'custom'), [modelsResult?.data, provider.modelId], @@ -153,6 +194,10 @@ const ChatHeader = React.memo(PureChatHeader, (prev, next) => { const ChatInput = React.memo(PureChatInput); +function createBaseURL(basePath: string, endpointUrl?: string | null) { + return endpointUrl ? new URL(basePath, endpointUrl).toString() : undefined; +} + const ChatCard: React.FC = ({ chat, selectedEndpoint, @@ -164,37 +209,26 @@ const ChatCard: React.FC = ({ const { styles: { chatCard: chatCardStyle, alert: alertStyle, ...chatCardStyles }, } = useStyles(); - const formRef = useRef(null); + const [isPendingUpdate, startUpdateTransition] = useTransition(); + const dropContainerRef = useRef(null); const [fetchKey, updateFetchKey] = useUpdatableState('first'); const [startTime, setStartTime] = useState(null); const { endpoint, setEndpoint } = useEndpoint(selectedEndpoint); + const [baseURL, setBaseURL] = useState( + createBaseURL(chat.provider.basePath, endpoint?.url), + ); + const [token, setToken] = useState(); const { models, modelId, setModelId } = useModels( chat.provider, fetchKey, - endpoint, + baseURL, + token, ); const { agents, agent, setAgent } = useAgents(chat.provider); const [sync, setSync] = useState(chat.sync); - const baseURL = endpoint?.url - ? new URL(chat.provider.basePath, endpoint?.url ?? undefined).toString() - : undefined; - - const allowCustomModel = isEmpty(models); - const providerSettings = { - baseURL: allowCustomModel - ? formRef.current?.getFieldValue('baseURL') - : baseURL, - modelId: allowCustomModel - ? formRef.current?.getFieldValue('modelId') - : modelId, - apiKey: allowCustomModel - ? formRef.current?.getFieldValue('token') - : chat.provider.apiKey, - }; - const { error, messages, @@ -216,13 +250,13 @@ const ChatCard: React.FC = ({ if (fetchOnClient || modelId === 'custom') { const body = JSON.parse(init?.body as string); const provider = createOpenAI({ - baseURL: providerSettings.baseURL, - apiKey: providerSettings.apiKey || 'dummy', + baseURL: baseURL, + apiKey: token || chat.provider.apiKey || 'dummy', }); const result = streamText({ abortSignal: init?.signal || undefined, model: wrapLanguageModel({ - model: provider(providerSettings.modelId), + model: provider(modelId), middleware: extractReasoningMiddleware({ tagName: 'think' }), }), messages: body?.messages, @@ -234,13 +268,22 @@ const ChatCard: React.FC = ({ return result.toDataStreamResponse({ sendReasoning: true, }); - } else { - return fetch(input, init); } + + return fetch(input, init); }, }); + useEffect(() => { + startTransition(() => { + setBaseURL(createBaseURL(chat.provider.basePath, endpoint?.url)); + setToken(undefined); + updateFetchKey('first'); + }); + }, [endpoint?.url, chat.provider.basePath, updateFetchKey]); + const isStreaming = status === 'streaming' || status === 'submitted'; + return ( = ({ title={ = ({ } ref={dropContainerRef} > - {allowCustomModel ? ( + {_.isEmpty(models) && ( updateFetchKey(baseURL)} /> - ) - } + token={token} + endpointId={endpoint?.endpoint_id} + loading={isPendingUpdate} + onSubmit={(data) => { + startUpdateTransition(() => { + updateFetchKey(); + setBaseURL(data.baseURL); + setToken(data.token); + }); + }} /> - ) : null} + )} {!_.isEmpty(error?.message) ? ( = ({ sync, onClick }) => { interface ChatHeaderProps extends ChatLifecycleEventType { chat: ChatType; showCompareMenuItem?: boolean; - allowCustomModel?: boolean; closable?: boolean; models: BAIModel[]; modelId: string; @@ -59,7 +59,6 @@ interface ChatHeaderProps extends ChatLifecycleEventType { const ChatHeader: React.FC = ({ chat, showCompareMenuItem, - allowCustomModel, closable, models, modelId, @@ -169,16 +168,17 @@ const ChatHeader: React.FC = ({ value={endpoint?.endpoint_id} popupMatchSelectWidth={false} /> - { - startTransition(() => { - setModelId(modelId); - }); - }} - allowCustomModel={allowCustomModel} - /> + {!isEmpty(models) && ( + { + startTransition(() => { + setModelId(modelId); + }); + }} + /> + )} {closable && ( diff --git a/react/src/components/Chat/ChatModel.tsx b/react/src/components/Chat/ChatModel.tsx index 9120be5d1a..715d39148d 100644 --- a/react/src/components/Chat/ChatModel.tsx +++ b/react/src/components/Chat/ChatModel.tsx @@ -3,6 +3,7 @@ import { ChatCard_endpoint$data, ChatCard_endpoint$key, } from './__generated__/ChatCard_endpoint.graphql'; +import i18n from 'i18next'; export type ChatProviderType = { baseURL?: string; @@ -77,3 +78,32 @@ export type ChatOptions = { agents: AIAgent[]; agentId?: string; }; + +export class ChatModelError extends Error { + status: number; + + constructor(status: number) { + super(ChatModelError.errorMessage(status)); + this.name = 'ModelsError'; + this.status = status; + } + + static errorMessage(status: number) { + const messageKey = (() => { + switch (status) { + case 401: + return 'error.UnauthorizedToken'; + case 404: + return 'error.NotFoundBasePath'; + case 500: + return 'error.InternalServerError'; + case 503: + return 'error.ServiceUnavailable'; + default: + return 'error.UnknownError'; + } + })(); + + return i18n.t(messageKey, { status }); + } +} diff --git a/react/src/components/Chat/CustomModelForm.tsx b/react/src/components/Chat/CustomModelForm.tsx index 9494048296..c9e3a5f431 100644 --- a/react/src/components/Chat/CustomModelForm.tsx +++ b/react/src/components/Chat/CustomModelForm.tsx @@ -1,34 +1,44 @@ import Flex from '../Flex'; +import EndpointTokenSelect from './EndpointTokenSelect'; import { ReloadOutlined } from '@ant-design/icons'; -import { - Alert, - Button, - ButtonProps, - Form, - FormInstance, - Input, - theme, -} from 'antd'; +import { Alert, Button, Form, Input, theme } from 'antd'; +import type { FormInstance } from 'antd'; +import { useRef } from 'react'; import { useTranslation } from 'react-i18next'; +export type CustomModelFormValues = { + baseURL?: string; + token?: string; +}; + type CustomModelFormProps = { baseURL?: string; token?: string; - allowCustomModel?: boolean; - alert?: React.ReactNode; - modelId?: string; - formRef: React.RefObject | null>; + endpointId?: string | null; + loading: boolean; + onSubmit?: (formData: CustomModelFormValues) => void; }; +function parseBaseURL(baseURL?: string) { + const { origin, pathname } = new URL(baseURL || ''); + return { + origin: `${origin}/`, + pathname: pathname.replace(/^\//, ''), + }; +} + const CustomModelForm: React.FC = ({ baseURL, token, - allowCustomModel, - alert, - modelId, - formRef, + endpointId, + loading, + onSubmit, }) => { + const { t } = useTranslation(); const { token: themeToken } = theme.useToken(); + const formRef = useRef(null); + + const { origin, pathname: basePath } = parseBaseURL(baseURL); return ( = ({ paddingRight: themeToken.paddingContentHorizontalLG, paddingLeft: themeToken.paddingContentHorizontalLG, backgroundColor: themeToken.colorBgContainer, - // @FIXME: check the condition at the parent component - display: (allowCustomModel && modelId === 'custom' && 'flex') || 'none', }} >
= ({ style={{ flex: 1 }} key={baseURL} initialValues={{ - baseURL: baseURL, + basePath: basePath, token: token, }} > - {alert ? ( -
{alert}
- ) : null} - - +
+ +
+ + - - - - - + + +
); }; -type CustomModelAlertProp = { - onClick?: ButtonProps['onClick']; -}; - -const CustomModelAlert: React.FC = ({ onClick }) => { - const { t } = useTranslation(); - - return ( - } onClick={onClick}> - {t('button.Refresh')} - - } - /> - ); -}; - -export { CustomModelForm, CustomModelAlert }; +export { CustomModelForm }; diff --git a/react/src/components/Chat/EndpointTokenSelect.tsx b/react/src/components/Chat/EndpointTokenSelect.tsx new file mode 100644 index 0000000000..7e00397edb --- /dev/null +++ b/react/src/components/Chat/EndpointTokenSelect.tsx @@ -0,0 +1,90 @@ +import type { + EndpointTokenSelectQuery, + EndpointTokenSelectQuery$data, +} from './__generated__/EndpointTokenSelectQuery.graphql'; +import { useControllableValue } from 'ahooks'; +import { Input, Select } from 'antd'; +import type { SelectProps } from 'antd'; +import graphql from 'babel-plugin-relay/macro'; +import dayjs from 'dayjs'; +import { castArray, sortBy } from 'lodash'; +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useLazyLoadQuery } from 'react-relay'; + +function sortEndpointTokenList( + endpointTokenListData: EndpointTokenSelectQuery$data['endpoint_token_list'], +) { + if (!endpointTokenListData.ok) return []; + + const now = dayjs(); + return sortBy(endpointTokenListData.value?.items, 'created_at').map( + (item) => ({ + label: item?.token, + value: item?.token, + disabled: !dayjs(item?.valid_until).tz().isAfter(now), + }), + ); +} + +interface EndpointTokenSelectProps extends Omit { + endpointId?: string | null; +} + +const EndpointTokenSelect: React.FC = ({ + endpointId, + ...props +}) => { + const { t } = useTranslation(); + const [controllableValue, setControllableValue] = + useControllableValue(props); + + const { endpoint_token_list } = useLazyLoadQuery( + graphql` + query EndpointTokenSelectQuery( + $endpointId: UUID! + $isEmptyEndpointId: Boolean! + ) { + endpoint_token_list(offset: 0, limit: 100, endpoint_id: $endpointId) + @skipOnClient(if: $isEmptyEndpointId) + @catch { + items { + id + token + created_at + valid_until + } + } + } + `, + { + endpointId: endpointId || '', + isEmptyEndpointId: !endpointId, + }, + ); + + const selectOptions = useMemo( + () => sortEndpointTokenList(endpoint_token_list), + [endpoint_token_list], + ); + + return selectOptions.length <= 0 ? ( + setControllableValue(e.target.value)} /> + ) : ( +