Skip to content

Commit 564f907

Browse files
634750802Mini256
authored andcommitted
feat: support public chat engine (pingcap#692)
Co-authored-by: Mini256 <minianter@foxmail.com>
1 parent 3988175 commit 564f907

File tree

12 files changed

+122
-18
lines changed

12 files changed

+122
-18
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""public_chat_engine
2+
3+
Revision ID: 04947f9684ab
4+
Revises: 211f3c5aa125
5+
Create Date: 2025-05-28 15:13:22.058160
6+
7+
"""
8+
9+
from alembic import op
10+
import sqlalchemy as sa
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "04947f9684ab"
14+
down_revision = "211f3c5aa125"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.add_column("chat_engines", sa.Column("is_public", sa.Boolean(), nullable=False))
22+
# ### end Alembic commands ###
23+
24+
25+
def downgrade():
26+
# ### commands auto generated by Alembic - please adjust! ###
27+
op.drop_column("chat_engines", "is_public")
28+
# ### end Alembic commands ###

backend/app/api/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import APIRouter
22
from app.api.routes import (
3+
chat_engine,
34
index,
45
chat,
56
chat_engine,
@@ -60,11 +61,11 @@
6061
api_router = APIRouter()
6162
api_router.include_router(index.router, tags=["index"])
6263
api_router.include_router(chat.router, tags=["chat"])
63-
api_router.include_router(chat_engine.router, tags=["chat_engine"])
6464
api_router.include_router(feedback.router, tags=["chat"])
6565
api_router.include_router(user.router, tags=["user"])
6666
api_router.include_router(api_key.router, tags=["auth"])
6767
api_router.include_router(document.router, tags=["documents"])
68+
api_router.include_router(chat_engine.router, tags=["chat-engines"])
6869
api_router.include_router(retrieve_routes.router, tags=["retrieve"])
6970
api_router.include_router(tos.router, tags=["tos"])
7071
api_router.include_router(admin_user_router)
Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import logging
2+
13
from fastapi import APIRouter, Depends
4+
from app.api.deps import SessionDep, CurrentUserDep
25
from fastapi_pagination import Params, Page
36

4-
from app.api.deps import SessionDep, CurrentUserDep
5-
from app.repositories import chat_engine_repo
6-
from app.models import ChatEngine
7+
from app.models.chat_engine import ChatEngine
8+
from app.rag.chat.config import ChatEngineConfig
9+
from app.repositories.chat_engine import chat_engine_repo
10+
11+
logger = logging.getLogger(__name__)
712

813
router = APIRouter()
914

@@ -14,4 +19,22 @@ def list_chat_engines(
1419
user: CurrentUserDep,
1520
params: Params = Depends(),
1621
) -> Page[ChatEngine]:
17-
return chat_engine_repo.paginate(db_session, params)
22+
page = chat_engine_repo.paginate(db_session, params, need_public=True)
23+
for chat_engine in page.items:
24+
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
25+
chat_engine.engine_options = engine_config.screenshot()
26+
return page
27+
28+
29+
@router.get("/chat-engines/{chat_engine_id}")
30+
def get_chat_engine(
31+
db_session: SessionDep,
32+
user: CurrentUserDep,
33+
chat_engine_id: int,
34+
) -> ChatEngine:
35+
chat_engine = chat_engine_repo.must_get(
36+
db_session, chat_engine_id, need_public=True
37+
)
38+
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
39+
chat_engine.engine_options = engine_config.screenshot()
40+
return chat_engine

backend/app/models/chat_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class ChatEngine(UpdatableBaseModel, table=True):
3636
},
3737
)
3838
is_default: bool = Field(default=False)
39+
is_public: bool = Field(default=False)
3940
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))
4041

4142
__tablename__ = "chat_engines"
@@ -48,3 +49,4 @@ class ChatEngineUpdate(BaseModel):
4849
reranker_id: Optional[int] = None
4950
engine_options: Optional[dict] = None
5051
is_default: Optional[bool] = None
52+
is_public: Optional[bool] = None

backend/app/rag/chat/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ def screenshot(self) -> dict:
200200
"condense_question_prompt",
201201
"text_qa_prompt",
202202
"refine_prompt",
203+
"intent_graph_knowledge",
204+
"normal_graph_knowledge",
205+
"generate_goal_prompt",
206+
"further_questions_prompt",
207+
"clarifying_question_prompt",
203208
],
204209
"post_verification_token": True,
205210
}

backend/app/repositories/chat_engine.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
class ChatEngineRepo(BaseRepo):
1616
model_cls = ChatEngine
1717

18-
def get(self, session: Session, id: int) -> Optional[ChatEngine]:
19-
return session.exec(
20-
select(ChatEngine).where(ChatEngine.id == id, ChatEngine.deleted_at == None)
21-
).first()
18+
def get(
19+
self, session: Session, id: int, need_public: bool = False
20+
) -> Optional[ChatEngine]:
21+
query = select(ChatEngine).where(
22+
ChatEngine.id == id, ChatEngine.deleted_at == None
23+
)
24+
if need_public:
25+
query = query.where(ChatEngine.is_public == True)
26+
return session.exec(query).first()
2227

23-
def must_get(self, session: Session, id: int) -> ChatEngine:
24-
chat_engine = self.get(session, id)
28+
def must_get(
29+
self, session: Session, id: int, need_public: bool = False
30+
) -> ChatEngine:
31+
chat_engine = self.get(session, id, need_public)
2532
if chat_engine is None:
2633
raise ChatEngineNotFound(id)
2734
return chat_engine
@@ -30,8 +37,11 @@ def paginate(
3037
self,
3138
session: Session,
3239
params: Params | None = Params(),
40+
need_public: bool = False,
3341
) -> Page[ChatEngine]:
3442
query = select(ChatEngine).where(ChatEngine.deleted_at == None)
43+
if need_public:
44+
query = query.where(ChatEngine.is_public == True)
3545
# Make sure the default engine is always on top
3646
query = query.order_by(ChatEngine.is_default.desc(), ChatEngine.name)
3747
return paginate(session, query, params)

frontend/app/src/api/chat-engines.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export interface ChatEngine {
1313
fast_llm_id: number | null;
1414
reranker_id: number | null;
1515
is_default: boolean;
16+
is_public: boolean;
1617
}
1718

1819
export interface CreateChatEngineParams {
@@ -134,6 +135,7 @@ const chatEngineSchema = z.object({
134135
fast_llm_id: z.number().nullable(),
135136
reranker_id: z.number().nullable(),
136137
is_default: z.boolean(),
138+
is_public: z.boolean(),
137139
}) satisfies ZodType<ChatEngine, any, any>;
138140

139141
export async function getDefaultChatEngineOptions (): Promise<ChatEngineOptions> {
@@ -190,3 +192,17 @@ export async function deleteChatEngine (id: number): Promise<void> {
190192
})
191193
.then(handleErrors);
192194
}
195+
196+
export async function listPublicChatEngines ({ page = 1, size = 10 }: PageParams = {}): Promise<Page<ChatEngine>> {
197+
return await fetch(requestUrl('/api/v1/chat-engines', { page, size }), {
198+
headers: await authenticationHeaders(),
199+
})
200+
.then(handleResponse(zodPage(chatEngineSchema)));
201+
}
202+
203+
export async function getPublicChatEngine (id: number): Promise<ChatEngine> {
204+
return await fetch(requestUrl(`/api/v1/chat-engines/${id}`), {
205+
headers: await authenticationHeaders(),
206+
})
207+
.then(handleResponse(chatEngineSchema));
208+
}

frontend/app/src/components/chat-engine/chat-engines-table.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ const columns = [
3737
header: 'IS DEFAULT',
3838
cell: boolean
3939
}),
40+
helper.accessor('is_public', {
41+
header: 'IS PUBLIC',
42+
cell: boolean
43+
}),
4044
helper.display({
4145
header: 'ACTIONS',
4246
cell: actions((chatEngine) => [

frontend/app/src/components/chat-engine/create-chat-engine-form.tsx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { z } from 'zod';
2323

2424
const schema = z.object({
2525
name: z.string().min(1),
26+
is_public: z.boolean().optional(),
2627
llm_id: z.number().optional(),
2728
fast_llm_id: z.number().optional(),
2829
reranker_id: z.number().optional(),
@@ -53,6 +54,9 @@ export function CreateChatEngineForm ({ defaultChatEngineOptions }: { defaultCha
5354

5455
const form = useForm({
5556
onSubmit: onSubmitHelper(schema, async data => {
57+
if (data.is_public == null) {
58+
data.is_public = true;
59+
}
5660
const ce = await createChatEngine(data);
5761
startTransition(() => {
5862
router.push(`/chat-engines/${ce.id}`);
@@ -85,6 +89,9 @@ export function CreateChatEngineForm ({ defaultChatEngineOptions }: { defaultCha
8589
<field.Basic required name="name" label="Name" defaultValue="" validators={{ onSubmit: nameSchema, onBlur: nameSchema }}>
8690
<FormInput placeholder="Enter chat engine name" />
8791
</field.Basic>
92+
<field.Contained name='is_public' label="Is Public" defaultValue={true}>
93+
<FormSwitch />
94+
</field.Contained>
8895
<SubSection title="Models">
8996
<field.Basic name="llm_id" label="LLM">
9097
<LLMSelect />
@@ -279,4 +286,4 @@ const llmPromptDescriptions: { [P in typeof llmPromptFields[number]]: string } =
279286
'clarifying_question_prompt': 'Prompt template for generating clarifying questions when the user\'s input needs more context or specificity',
280287
'generate_goal_prompt': 'Prompt template for generating conversation goals and objectives based on user input',
281288
'further_questions_prompt': 'Prompt template for generating follow-up questions to continue the conversation',
282-
};
289+
};
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import { listChatEngines } from '@/api/chat-engines';
1+
import { listChatEngines, listPublicChatEngines } from '@/api/chat-engines';
22
import { listAllHelper } from '@/lib/request';
33
import useSWR from 'swr';
44

5-
export function useAllChatEngines () {
6-
return useSWR('api.chat-engines.list-all', () => listAllHelper(listChatEngines, 'id'));
5+
export function useAllChatEngines (onlyPublic: boolean = false) {
6+
return useSWR(onlyPublic ? 'api.chat-engines.list-all-public' : 'api.chat-engines.list-all', () => listAllHelper(onlyPublic ? listPublicChatEngines : listChatEngines, 'id'));
77
}

0 commit comments

Comments
 (0)