Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backend/app/alembic/versions/04947f9684ab_public_chat_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""public_chat_engine

Revision ID: 04947f9684ab
Revises: 211f3c5aa125
Create Date: 2025-05-28 15:13:22.058160

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "04947f9684ab"
down_revision = "211f3c5aa125"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat_engines", sa.Column("is_public", sa.Boolean(), nullable=False))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_engines", "is_public")
# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from app.api.routes import (
chat_engine,
index,
chat,
user,
Expand Down Expand Up @@ -62,6 +63,7 @@
api_router.include_router(user.router, tags=["user"])
api_router.include_router(api_key.router, tags=["auth"])
api_router.include_router(document.router, tags=["documents"])
api_router.include_router(chat_engine.router, tags=["chat-engines"])
api_router.include_router(retrieve_routes.router, tags=["retrieve"])
api_router.include_router(admin_user_router)
api_router.include_router(admin_chat_engine.router, tags=["admin/chat-engines"])
Expand Down
38 changes: 38 additions & 0 deletions backend/app/api/routes/chat_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

from fastapi import APIRouter, Depends
from app.api.deps import SessionDep
from fastapi_pagination import Params, Page

from app.models.chat_engine import ChatEngine
from app.rag.chat.config import ChatEngineConfig
from app.repositories.chat_engine import chat_engine_repo

logger = logging.getLogger(__name__)

router = APIRouter()


@router.get("/chat-engines")
def list_chat_engines(
db_session: SessionDep,
params: Params = Depends(),
) -> Page[ChatEngine]:
page = chat_engine_repo.paginate(db_session, params, need_public=True)
for chat_engine in page.items:
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
chat_engine.engine_options = engine_config.screenshot()
return page


@router.get("/chat-engines/{chat_engine_id}")
def get_chat_engine(
db_session: SessionDep,
chat_engine_id: int,
) -> ChatEngine:
chat_engine = chat_engine_repo.must_get(
db_session, chat_engine_id, need_public=True
)
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
chat_engine.engine_options = engine_config.screenshot()
return chat_engine
2 changes: 2 additions & 0 deletions backend/app/models/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ChatEngine(UpdatableBaseModel, table=True):
},
)
is_default: bool = Field(default=False)
is_public: bool = Field(default=False)
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))

__tablename__ = "chat_engines"
Expand All @@ -48,3 +49,4 @@ class ChatEngineUpdate(BaseModel):
reranker_id: Optional[int] = None
engine_options: Optional[dict] = None
is_default: Optional[bool] = None
is_public: Optional[bool] = None
5 changes: 5 additions & 0 deletions backend/app/rag/chat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ def screenshot(self) -> dict:
"condense_question_prompt",
"text_qa_prompt",
"refine_prompt",
"intent_graph_knowledge",
"normal_graph_knowledge",
"generate_goal_prompt",
"further_questions_prompt",
"clarifying_question_prompt",
],
"post_verification_token": True,
}
Expand Down
22 changes: 16 additions & 6 deletions backend/app/repositories/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
class ChatEngineRepo(BaseRepo):
model_cls = ChatEngine

def get(self, session: Session, id: int) -> Optional[ChatEngine]:
return session.exec(
select(ChatEngine).where(ChatEngine.id == id, ChatEngine.deleted_at == None)
).first()
def get(
self, session: Session, id: int, need_public: bool = False
) -> Optional[ChatEngine]:
query = select(ChatEngine).where(
ChatEngine.id == id, ChatEngine.deleted_at == None
)
if need_public:
query = query.where(ChatEngine.is_public == True)
return session.exec(query).first()

def must_get(self, session: Session, id: int) -> ChatEngine:
chat_engine = self.get(session, id)
def must_get(
self, session: Session, id: int, need_public: bool = False
) -> ChatEngine:
chat_engine = self.get(session, id, need_public)
if chat_engine is None:
raise ChatEngineNotFound(id)
return chat_engine
Expand All @@ -30,8 +37,11 @@ def paginate(
self,
session: Session,
params: Params | None = Params(),
need_public: bool = False,
) -> Page[ChatEngine]:
query = select(ChatEngine).where(ChatEngine.deleted_at == None)
if need_public:
query = query.where(ChatEngine.is_public == True)
# Make sure the default engine is always on top
query = query.order_by(ChatEngine.is_default.desc(), ChatEngine.name)
return paginate(session, query, params)
Expand Down
16 changes: 16 additions & 0 deletions frontend/app/src/api/chat-engines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export interface ChatEngine {
fast_llm_id: number | null;
reranker_id: number | null;
is_default: boolean;
is_public: boolean;
}

export interface CreateChatEngineParams {
Expand Down Expand Up @@ -134,6 +135,7 @@ const chatEngineSchema = z.object({
fast_llm_id: z.number().nullable(),
reranker_id: z.number().nullable(),
is_default: z.boolean(),
is_public: z.boolean(),
}) satisfies ZodType<ChatEngine, any, any>;

export async function getDefaultChatEngineOptions (): Promise<ChatEngineOptions> {
Expand Down Expand Up @@ -190,3 +192,17 @@ export async function deleteChatEngine (id: number): Promise<void> {
})
.then(handleErrors);
}

export async function listPublicChatEngines ({ page = 1, size = 10 }: PageParams = {}): Promise<Page<ChatEngine>> {
return await fetch(requestUrl('/api/v1/chat-engines', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(chatEngineSchema)));
}

export async function getPublicChatEngine (id: number): Promise<ChatEngine> {
return await fetch(requestUrl(`/api/v1/chat-engines/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(chatEngineSchema));
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ const columns = [
header: 'IS DEFAULT',
cell: boolean
}),
helper.accessor('is_public', {
header: 'IS PUBLIC',
cell: boolean
}),
helper.display({
header: 'ACTIONS',
cell: actions((chatEngine) => [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { z } from 'zod';

const schema = z.object({
name: z.string().min(1),
is_public: z.boolean().optional(),
llm_id: z.number().optional(),
fast_llm_id: z.number().optional(),
reranker_id: z.number().optional(),
Expand Down Expand Up @@ -53,6 +54,9 @@ export function CreateChatEngineForm ({ defaultChatEngineOptions }: { defaultCha

const form = useForm({
onSubmit: onSubmitHelper(schema, async data => {
if (data.is_public == null) {
data.is_public = true;
}
const ce = await createChatEngine(data);
startTransition(() => {
router.push(`/chat-engines/${ce.id}`);
Expand Down Expand Up @@ -85,6 +89,9 @@ export function CreateChatEngineForm ({ defaultChatEngineOptions }: { defaultCha
<field.Basic required name="name" label="Name" defaultValue="" validators={{ onSubmit: nameSchema, onBlur: nameSchema }}>
<FormInput placeholder="Enter chat engine name" />
</field.Basic>
<field.Contained name='is_public' label="Is Public" defaultValue={true}>
<FormSwitch />
</field.Contained>
<SubSection title="Models">
<field.Basic name="llm_id" label="LLM">
<LLMSelect />
Expand Down Expand Up @@ -279,4 +286,4 @@ const llmPromptDescriptions: { [P in typeof llmPromptFields[number]]: string } =
'clarifying_question_prompt': 'Prompt template for generating clarifying questions when the user\'s input needs more context or specificity',
'generate_goal_prompt': 'Prompt template for generating conversation goals and objectives based on user input',
'further_questions_prompt': 'Prompt template for generating follow-up questions to continue the conversation',
};
};
6 changes: 3 additions & 3 deletions frontend/app/src/components/chat-engine/hooks.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { listChatEngines } from '@/api/chat-engines';
import { listChatEngines, listPublicChatEngines } from '@/api/chat-engines';
import { listAllHelper } from '@/lib/request';
import useSWR from 'swr';

export function useAllChatEngines () {
return useSWR('api.chat-engines.list-all', () => listAllHelper(listChatEngines, 'id'));
export function useAllChatEngines (onlyPublic: boolean = false) {
return useSWR(onlyPublic ? 'api.chat-engines.list-all-public' : 'api.chat-engines.list-all', () => listAllHelper(onlyPublic ? listPublicChatEngines : listChatEngines, 'id'));
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ export function UpdateChatEngineForm ({ chatEngine, defaultChatEngineOptions }:
<FormSwitch />
</field.Contained>
</GeneralSettingsField>
<GeneralSettingsField accessor={isPublicAccessor} schema={isPublicSchema}>
<field.Contained unimportant name="value" label="Is Public" fallbackValue={chatEngine.is_public}>
<FormSwitch />
</field.Contained>
</GeneralSettingsField>
<SubSection title="Models">
<GeneralSettingsField accessor={llmIdAccessor} schema={idSchema}>
<field.Basic name="value" label="LLM">
Expand Down Expand Up @@ -222,7 +227,7 @@ export function UpdateChatEngineForm ({ chatEngine, defaultChatEngineOptions }:
);
}

const updatableFields = ['name', 'llm_id', 'fast_llm_id', 'reranker_id', 'engine_options', 'is_default'] as const;
const updatableFields = ['name', 'llm_id', 'fast_llm_id', 'reranker_id', 'engine_options', 'is_default', 'is_public'] as const;

function optionAccessor<K extends keyof ChatEngineOptions> (key: K): GeneralSettingsFieldAccessor<ChatEngine, ChatEngineOptions[K]> {
return {
Expand Down Expand Up @@ -311,6 +316,9 @@ const clarifyAccessorSchema = z.boolean().nullable().optional();
const isDefaultAccessor = fieldAccessor<ChatEngine, 'is_default'>('is_default');
const isDefaultSchema = z.boolean();

const isPublicAccessor = fieldAccessor<ChatEngine, 'is_public'>('is_public');
const isPublicSchema = z.boolean();

const getIdAccessor = (id: KeyOfType<ChatEngine, number | null>) => fieldAccessor<ChatEngine, KeyOfType<ChatEngine, number | null>>(id);
const idSchema = z.number().nullable();
const llmIdAccessor = getIdAccessor('llm_id');
Expand Down
4 changes: 2 additions & 2 deletions frontend/app/src/components/chat/message-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ export function MessageInput ({
onChangeRef.current?.(ev);
}, []);

const showShowSelectChatEngine = !!auth.me?.is_superuser && !!onEngineChange;
const { data, isLoading } = useAllChatEngines();
const { data, isLoading } = useAllChatEngines(!auth.me?.is_superuser);
const showShowSelectChatEngine = !!data?.length && !!onEngineChange;

return (
<div className={cn('bg-background border p-2 rounded-lg', className)}>
Expand Down