2020import re
2121import tempfile
2222from copy import deepcopy
23+ from types import SimpleNamespace
2324
2425from quart import Response , request
2526
3031)
3132from api .db .services .chunk_feedback_service import ChunkFeedbackService
3233from api .db .services .conversation_service import ConversationService , structure_answer
33- from api .db .services .dialog_service import DialogService , async_ask , async_chat , gen_mindmap
34+ from api .db .services .dialog_service import DialogService , async_chat , gen_mindmap
3435from api .db .services .knowledgebase_service import KnowledgebaseService
3536from api .db .services .llm_service import LLMBundle
3637from api .db .services .search_service import SearchService
6768 "tts" : False ,
6869 "refine_multiturn" : True ,
6970}
71+ _DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = {
72+ "system" : "" ,
73+ "prologue" : "" ,
74+ "parameters" : [],
75+ "empty_response" : "" ,
76+ "quote" : False ,
77+ "tts" : False ,
78+ "refine_multiturn" : True ,
79+ }
7080_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3" , "maidalun1020/bce-reranker-base_v1" }
7181_READONLY_FIELDS = {"id" , "tenant_id" , "created_by" , "create_time" , "create_date" , "update_time" , "update_date" }
7282_PERSISTED_FIELDS = set (DialogService .model ._meta .fields )
@@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id):
124134 )
125135
126136
137+ def _build_default_completion_dialog ():
138+ return SimpleNamespace (
139+ tenant_id = current_user .id ,
140+ llm_id = "" ,
141+ tenant_llm_id = None ,
142+ llm_setting = {},
143+ prompt_config = deepcopy (_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG ),
144+ kb_ids = [],
145+ top_n = 6 ,
146+ top_k = 1024 ,
147+ rerank_id = "" ,
148+ similarity_threshold = 0.1 ,
149+ vector_similarity_weight = 0.3 ,
150+ meta_data_filter = None ,
151+ )
152+
153+
154+ def _create_session_for_completion (chat_id , dialog , user_id ):
155+ conv = {
156+ "id" : get_uuid (),
157+ "dialog_id" : chat_id ,
158+ "name" : "New session" ,
159+ "message" : [{"role" : "assistant" , "content" : dialog .prompt_config .get ("prologue" , "" )}],
160+ "user_id" : user_id ,
161+ "reference" : [],
162+ }
163+ ConversationService .save (** conv )
164+ ok , conv_obj = ConversationService .get_by_id (conv ["id" ])
165+ if not ok :
166+ raise LookupError ("Fail to create a session!" )
167+ return conv_obj
168+
169+
127170def _validate_llm_id (llm_id , tenant_id , llm_setting = None ):
128171 if not llm_id :
129172 return None
@@ -671,7 +714,7 @@ async def get_session(chat_id, session_id):
671714 return server_error_response (ex )
672715
673716
674- @manager .route ("/chats/<chat_id>/sessions/<session_id>" , methods = ["PUT " ]) # noqa: F821
717+ @manager .route ("/chats/<chat_id>/sessions/<session_id>" , methods = ["PATCH " ]) # noqa: F821
675718@login_required
676719async def update_session (chat_id , session_id ):
677720 if not _ensure_owned_chat (chat_id ):
@@ -829,7 +872,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
829872 return server_error_response (ex )
830873
831874
832- @manager .route ("/chats/tts " , methods = ["POST" ]) # noqa: F821
875+ @manager .route ("/chat/audio/speech " , methods = ["POST" ]) # noqa: F821
833876@login_required
834877async def tts ():
835878 req = await get_request_json ()
@@ -857,9 +900,9 @@ def stream_audio():
857900 return resp
858901
859902
860- @manager .route ("/chats/transcriptions " , methods = ["POST" ]) # noqa: F821
903+ @manager .route ("/chat/audio/transcription " , methods = ["POST" ]) # noqa: F821
861904@login_required
862- async def transcriptions ():
905+ async def transcription ():
863906 req = await request .form
864907 stream_mode = req .get ("stream" , "false" ).lower () == "true"
865908 files = await request .files
@@ -915,7 +958,7 @@ async def event_stream():
915958 return Response (event_stream (), content_type = "text/event-stream" )
916959
917960
918- @manager .route ("/chats /mindmap" , methods = ["POST" ]) # noqa: F821
961+ @manager .route ("/chat /mindmap" , methods = ["POST" ]) # noqa: F821
919962@login_required
920963@validate_request ("question" , "kb_ids" )
921964async def mindmap ():
@@ -933,10 +976,10 @@ async def mindmap():
933976 return get_json_result (data = mind_map )
934977
935978
936- @manager .route ("/chats/related_questions " , methods = ["POST" ]) # noqa: F821
979+ @manager .route ("/chat/recommendation " , methods = ["POST" ]) # noqa: F821
937980@login_required
938981@validate_request ("question" )
939- async def related_questions ():
982+ async def recommendation ():
940983 req = await get_request_json ()
941984
942985 search_id = req .get ("search_id" , "" )
@@ -971,10 +1014,10 @@ async def related_questions():
9711014 return get_json_result (data = [re .sub (r"^[0-9]\. " , "" , a ) for a in ans .split ("\n " ) if re .match (r"^[0-9]\. " , a )])
9721015
9731016
974- @manager .route ("/chats/<chat_id>/sessions/<session_id> /completions" , methods = ["POST" ]) # noqa: F821
1017+ @manager .route ("/chat /completions" , methods = ["POST" ]) # noqa: F821
9751018@login_required
9761019@validate_request ("messages" )
977- async def session_completion (chat_id , session_id ):
1020+ async def session_completion ():
9781021 req = await get_request_json ()
9791022 msg = []
9801023 for m in req ["messages" ]:
@@ -984,6 +1027,8 @@ async def session_completion(chat_id, session_id):
9841027 continue
9851028 msg .append (m )
9861029 message_id = msg [- 1 ].get ("id" ) if msg else None
1030+ chat_id = req .pop ("chat_id" , "" ) or ""
1031+ session_id = req .pop ("session_id" , "" ) or ""
9871032 chat_model_id = req .pop ("llm_id" , "" )
9881033
9891034 chat_model_config = {}
@@ -993,38 +1038,63 @@ async def session_completion(chat_id, session_id):
9931038 chat_model_config [model_config ] = config
9941039
9951040 try :
996- e , conv = ConversationService .get_by_id (session_id )
997- if not e :
998- return get_data_error_result (message = "Session not found!" )
999- if conv .dialog_id != chat_id :
1000- return get_data_error_result (message = "Session does not belong to this chat!" )
1001- conv .message = deepcopy (req ["messages" ])
1002- e , dia = DialogService .get_by_id (chat_id )
1003- if not e :
1004- return get_data_error_result (message = "Chat not found!" )
1041+ conv = None
1042+ if session_id and not chat_id :
1043+ return get_data_error_result (message = "`chat_id` is required when `session_id` is provided." )
1044+
1045+ if chat_id :
1046+ if not _ensure_owned_chat (chat_id ):
1047+ return get_json_result (
1048+ data = False ,
1049+ message = "No authorization." ,
1050+ code = RetCode .AUTHENTICATION_ERROR ,
1051+ )
1052+ e , dia = DialogService .get_by_id (chat_id )
1053+ if not e :
1054+ return get_data_error_result (message = "Chat not found!" )
1055+ if session_id :
1056+ e , conv = ConversationService .get_by_id (session_id )
1057+ if not e :
1058+ return get_data_error_result (message = "Session not found!" )
1059+ if conv .dialog_id != chat_id :
1060+ return get_data_error_result (message = "Session does not belong to this chat!" )
1061+ else :
1062+ conv = _create_session_for_completion (chat_id , dia , req .get ("user_id" , current_user .id ))
1063+ session_id = conv .id
1064+ conv .message = deepcopy (req ["messages" ])
1065+ else :
1066+ dia = _build_default_completion_dialog ()
1067+ dia .llm_setting = chat_model_config
1068+
10051069 del req ["messages" ]
10061070
1007- if not conv .reference :
1008- conv .reference = []
1009- conv .reference = [r for r in conv .reference if r ]
1010- conv .reference .append ({"chunks" : [], "doc_aggs" : []})
1071+ if conv is not None :
1072+ if not conv .reference :
1073+ conv .reference = []
1074+ conv .reference = [r for r in conv .reference if r ]
1075+ conv .reference .append ({"chunks" : [], "doc_aggs" : []})
10111076
10121077 if chat_model_id :
10131078 if not TenantLLMService .get_api_key (tenant_id = dia .tenant_id , model_name = chat_model_id ):
10141079 return get_data_error_result (message = f"Cannot use specified model { chat_model_id } ." )
10151080 dia .llm_id = chat_model_id
10161081 dia .llm_setting = chat_model_config
10171082
1018- is_embedded = bool (chat_model_id )
10191083 stream_mode = req .pop ("stream" , True )
10201084
1085+ def _format_answer (ans ):
1086+ formatted = structure_answer (conv , ans , message_id , session_id )
1087+ if chat_id :
1088+ formatted ["chat_id" ] = chat_id
1089+ return formatted
1090+
10211091 async def stream ():
10221092 nonlocal dia , msg , req , conv
10231093 try :
10241094 async for ans in async_chat (dia , msg , True , ** req ):
1025- ans = structure_answer ( conv , ans , message_id , conv . id )
1095+ ans = _format_answer ( ans )
10261096 yield "data:" + json .dumps ({"code" : 0 , "message" : "" , "data" : ans }, ensure_ascii = False ) + "\n \n "
1027- if not is_embedded :
1097+ if conv is not None :
10281098 ConversationService .update_by_id (conv .id , conv .to_dict ())
10291099 except Exception as ex :
10301100 logging .exception (ex )
@@ -1041,40 +1111,10 @@ async def stream():
10411111
10421112 answer = None
10431113 async for ans in async_chat (dia , msg , ** req ):
1044- answer = structure_answer ( conv , ans , message_id , conv . id )
1045- if not is_embedded :
1114+ answer = _format_answer ( ans )
1115+ if conv is not None :
10461116 ConversationService .update_by_id (conv .id , conv .to_dict ())
10471117 break
10481118 return get_json_result (data = answer )
10491119 except Exception as ex :
10501120 return server_error_response (ex )
1051-
1052-
1053- @manager .route ("/chats/ask" , methods = ["POST" ]) # noqa: F821
1054- @login_required
1055- @validate_request ("question" , "kb_ids" )
1056- async def ask ():
1057- req = await get_request_json ()
1058- uid = current_user .id
1059-
1060- search_id = req .get ("search_id" , "" )
1061- search_config = {}
1062- if search_id :
1063- if search_app := SearchService .get_detail (search_id ):
1064- search_config = search_app .get ("search_config" , {})
1065-
1066- async def stream ():
1067- nonlocal req , uid
1068- try :
1069- async for ans in async_ask (req ["question" ], req ["kb_ids" ], uid , search_config = search_config ):
1070- yield "data:" + json .dumps ({"code" : 0 , "message" : "" , "data" : ans }, ensure_ascii = False ) + "\n \n "
1071- except Exception as ex :
1072- yield "data:" + json .dumps ({"code" : 500 , "message" : str (ex ), "data" : {"answer" : "**ERROR**: " + str (ex ), "reference" : []}}, ensure_ascii = False ) + "\n \n "
1073- yield "data:" + json .dumps ({"code" : 0 , "message" : "" , "data" : True }, ensure_ascii = False ) + "\n \n "
1074-
1075- resp = Response (stream (), mimetype = "text/event-stream" )
1076- resp .headers .add_header ("Cache-control" , "no-cache" )
1077- resp .headers .add_header ("Connection" , "keep-alive" )
1078- resp .headers .add_header ("X-Accel-Buffering" , "no" )
1079- resp .headers .add_header ("Content-Type" , "text/event-stream; charset=utf-8" )
1080- return resp
0 commit comments