Skip to content

Commit 6baf74a

Browse files
buua436JinHai-CN
andauthored
Refa: align chat and search restful APIs (infiniflow#14229)
### What problem does this PR solve? Refactor /api/v1/chats to be more RESTful. ### Type of change - [x] Refactoring --------- Co-authored-by: Jin Hai <haijin.chn@gmail.com>
1 parent bfac019 commit 6baf74a

14 files changed

Lines changed: 361 additions & 159 deletions

File tree

api/apps/restful_apis/chat_api.py

Lines changed: 98 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121
import tempfile
2222
from copy import deepcopy
23+
from types import SimpleNamespace
2324

2425
from quart import Response, request
2526

@@ -30,7 +31,7 @@
3031
)
3132
from api.db.services.chunk_feedback_service import ChunkFeedbackService
3233
from api.db.services.conversation_service import ConversationService, structure_answer
33-
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
34+
from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap
3435
from api.db.services.knowledgebase_service import KnowledgebaseService
3536
from api.db.services.llm_service import LLMBundle
3637
from api.db.services.search_service import SearchService
@@ -67,6 +68,15 @@
6768
"tts": False,
6869
"refine_multiturn": True,
6970
}
71+
_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = {
72+
"system": "",
73+
"prologue": "",
74+
"parameters": [],
75+
"empty_response": "",
76+
"quote": False,
77+
"tts": False,
78+
"refine_multiturn": True,
79+
}
7080
_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"}
7181
_READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"}
7282
_PERSISTED_FIELDS = set(DialogService.model._meta.fields)
@@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id):
124134
)
125135

126136

137+
def _build_default_completion_dialog():
138+
return SimpleNamespace(
139+
tenant_id=current_user.id,
140+
llm_id="",
141+
tenant_llm_id=None,
142+
llm_setting={},
143+
prompt_config=deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG),
144+
kb_ids=[],
145+
top_n=6,
146+
top_k=1024,
147+
rerank_id="",
148+
similarity_threshold=0.1,
149+
vector_similarity_weight=0.3,
150+
meta_data_filter=None,
151+
)
152+
153+
154+
def _create_session_for_completion(chat_id, dialog, user_id):
155+
conv = {
156+
"id": get_uuid(),
157+
"dialog_id": chat_id,
158+
"name": "New session",
159+
"message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}],
160+
"user_id": user_id,
161+
"reference": [],
162+
}
163+
ConversationService.save(**conv)
164+
ok, conv_obj = ConversationService.get_by_id(conv["id"])
165+
if not ok:
166+
raise LookupError("Fail to create a session!")
167+
return conv_obj
168+
169+
127170
def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
128171
if not llm_id:
129172
return None
@@ -671,7 +714,7 @@ async def get_session(chat_id, session_id):
671714
return server_error_response(ex)
672715

673716

674-
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
717+
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821
675718
@login_required
676719
async def update_session(chat_id, session_id):
677720
if not _ensure_owned_chat(chat_id):
@@ -829,7 +872,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
829872
return server_error_response(ex)
830873

831874

832-
@manager.route("/chats/tts", methods=["POST"]) # noqa: F821
875+
@manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821
833876
@login_required
834877
async def tts():
835878
req = await get_request_json()
@@ -857,9 +900,9 @@ def stream_audio():
857900
return resp
858901

859902

860-
@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821
903+
@manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821
861904
@login_required
862-
async def transcriptions():
905+
async def transcription():
863906
req = await request.form
864907
stream_mode = req.get("stream", "false").lower() == "true"
865908
files = await request.files
@@ -915,7 +958,7 @@ async def event_stream():
915958
return Response(event_stream(), content_type="text/event-stream")
916959

917960

918-
@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821
961+
@manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821
919962
@login_required
920963
@validate_request("question", "kb_ids")
921964
async def mindmap():
@@ -933,10 +976,10 @@ async def mindmap():
933976
return get_json_result(data=mind_map)
934977

935978

936-
@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821
979+
@manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821
937980
@login_required
938981
@validate_request("question")
939-
async def related_questions():
982+
async def recommendation():
940983
req = await get_request_json()
941984

942985
search_id = req.get("search_id", "")
@@ -971,10 +1014,10 @@ async def related_questions():
9711014
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
9721015

9731016

974-
@manager.route("/chats/<chat_id>/sessions/<session_id>/completions", methods=["POST"]) # noqa: F821
1017+
@manager.route("/chat/completions", methods=["POST"]) # noqa: F821
9751018
@login_required
9761019
@validate_request("messages")
977-
async def session_completion(chat_id, session_id):
1020+
async def session_completion():
9781021
req = await get_request_json()
9791022
msg = []
9801023
for m in req["messages"]:
@@ -984,6 +1027,8 @@ async def session_completion(chat_id, session_id):
9841027
continue
9851028
msg.append(m)
9861029
message_id = msg[-1].get("id") if msg else None
1030+
chat_id = req.pop("chat_id", "") or ""
1031+
session_id = req.pop("session_id", "") or ""
9871032
chat_model_id = req.pop("llm_id", "")
9881033

9891034
chat_model_config = {}
@@ -993,38 +1038,63 @@ async def session_completion(chat_id, session_id):
9931038
chat_model_config[model_config] = config
9941039

9951040
try:
996-
e, conv = ConversationService.get_by_id(session_id)
997-
if not e:
998-
return get_data_error_result(message="Session not found!")
999-
if conv.dialog_id != chat_id:
1000-
return get_data_error_result(message="Session does not belong to this chat!")
1001-
conv.message = deepcopy(req["messages"])
1002-
e, dia = DialogService.get_by_id(chat_id)
1003-
if not e:
1004-
return get_data_error_result(message="Chat not found!")
1041+
conv = None
1042+
if session_id and not chat_id:
1043+
return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")
1044+
1045+
if chat_id:
1046+
if not _ensure_owned_chat(chat_id):
1047+
return get_json_result(
1048+
data=False,
1049+
message="No authorization.",
1050+
code=RetCode.AUTHENTICATION_ERROR,
1051+
)
1052+
e, dia = DialogService.get_by_id(chat_id)
1053+
if not e:
1054+
return get_data_error_result(message="Chat not found!")
1055+
if session_id:
1056+
e, conv = ConversationService.get_by_id(session_id)
1057+
if not e:
1058+
return get_data_error_result(message="Session not found!")
1059+
if conv.dialog_id != chat_id:
1060+
return get_data_error_result(message="Session does not belong to this chat!")
1061+
else:
1062+
conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
1063+
session_id = conv.id
1064+
conv.message = deepcopy(req["messages"])
1065+
else:
1066+
dia = _build_default_completion_dialog()
1067+
dia.llm_setting = chat_model_config
1068+
10051069
del req["messages"]
10061070

1007-
if not conv.reference:
1008-
conv.reference = []
1009-
conv.reference = [r for r in conv.reference if r]
1010-
conv.reference.append({"chunks": [], "doc_aggs": []})
1071+
if conv is not None:
1072+
if not conv.reference:
1073+
conv.reference = []
1074+
conv.reference = [r for r in conv.reference if r]
1075+
conv.reference.append({"chunks": [], "doc_aggs": []})
10111076

10121077
if chat_model_id:
10131078
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
10141079
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
10151080
dia.llm_id = chat_model_id
10161081
dia.llm_setting = chat_model_config
10171082

1018-
is_embedded = bool(chat_model_id)
10191083
stream_mode = req.pop("stream", True)
10201084

1085+
def _format_answer(ans):
1086+
formatted = structure_answer(conv, ans, message_id, session_id)
1087+
if chat_id:
1088+
formatted["chat_id"] = chat_id
1089+
return formatted
1090+
10211091
async def stream():
10221092
nonlocal dia, msg, req, conv
10231093
try:
10241094
async for ans in async_chat(dia, msg, True, **req):
1025-
ans = structure_answer(conv, ans, message_id, conv.id)
1095+
ans = _format_answer(ans)
10261096
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
1027-
if not is_embedded:
1097+
if conv is not None:
10281098
ConversationService.update_by_id(conv.id, conv.to_dict())
10291099
except Exception as ex:
10301100
logging.exception(ex)
@@ -1041,40 +1111,10 @@ async def stream():
10411111

10421112
answer = None
10431113
async for ans in async_chat(dia, msg, **req):
1044-
answer = structure_answer(conv, ans, message_id, conv.id)
1045-
if not is_embedded:
1114+
answer = _format_answer(ans)
1115+
if conv is not None:
10461116
ConversationService.update_by_id(conv.id, conv.to_dict())
10471117
break
10481118
return get_json_result(data=answer)
10491119
except Exception as ex:
10501120
return server_error_response(ex)
1051-
1052-
1053-
@manager.route("/chats/ask", methods=["POST"]) # noqa: F821
1054-
@login_required
1055-
@validate_request("question", "kb_ids")
1056-
async def ask():
1057-
req = await get_request_json()
1058-
uid = current_user.id
1059-
1060-
search_id = req.get("search_id", "")
1061-
search_config = {}
1062-
if search_id:
1063-
if search_app := SearchService.get_detail(search_id):
1064-
search_config = search_app.get("search_config", {})
1065-
1066-
async def stream():
1067-
nonlocal req, uid
1068-
try:
1069-
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
1070-
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
1071-
except Exception as ex:
1072-
yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
1073-
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
1074-
1075-
resp = Response(stream(), mimetype="text/event-stream")
1076-
resp.headers.add_header("Cache-control", "no-cache")
1077-
resp.headers.add_header("Connection", "keep-alive")
1078-
resp.headers.add_header("X-Accel-Buffering", "no")
1079-
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
1080-
return resp

api/apps/restful_apis/search_api.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
# limitations under the License.
1515
#
1616

17-
from quart import request
17+
import json
18+
19+
from quart import Response, request
20+
from api.db.services.dialog_service import async_ask
1821
from api.apps import current_user, login_required
1922

2023
from api.constants import DATASET_NAME_LIMIT
@@ -168,3 +171,45 @@ def delete_search(search_id):
168171
return get_json_result(data=True)
169172
except Exception as e:
170173
return server_error_response(e)
174+
175+
176+
@manager.route("/searches/<search_id>/completion", methods=["POST"]) # noqa: F821
177+
@login_required
178+
@validate_request("question")
179+
async def completion(search_id):
180+
if not SearchService.accessible4deletion(search_id, current_user.id):
181+
return get_json_result(
182+
data=False,
183+
message="No authorization.",
184+
code=RetCode.AUTHENTICATION_ERROR,
185+
)
186+
187+
req = await get_request_json()
188+
uid = current_user.id
189+
search_app = SearchService.get_detail(search_id)
190+
if not search_app:
191+
return get_data_error_result(message=f"Cannot find search {search_id}")
192+
193+
search_config = search_app.get("search_config", {})
194+
kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or []
195+
if not kb_ids:
196+
return get_data_error_result(message="`kb_ids` is required.")
197+
198+
async def stream():
199+
nonlocal req, uid, kb_ids, search_config
200+
try:
201+
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config):
202+
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
203+
except Exception as ex:
204+
yield "data:" + json.dumps(
205+
{"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}},
206+
ensure_ascii=False,
207+
) + "\n\n"
208+
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
209+
210+
resp = Response(stream(), mimetype="text/event-stream")
211+
resp.headers.add_header("Cache-control", "no-cache")
212+
resp.headers.add_header("Connection", "keep-alive")
213+
resp.headers.add_header("X-Accel-Buffering", "no")
214+
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
215+
return resp

docs/guides/chat/set_chat_variables.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,19 @@ See [Converse with chat assistant](../../references/http_api_reference.md#conver
7272

7373
```json {9}
7474
curl --request POST \
75-
--url http://{address}/api/v1/chats/{chat_id}/completions \
75+
--url http://{address}/api/v1/chat/completions \
7676
--header 'Content-Type: application/json' \
7777
--header 'Authorization: Bearer <YOUR_API_KEY>' \
7878
--data-binary '
7979
{
80-
"question": "xxxxxxxxx",
80+
"chat_id": "{chat_id}",
8181
"stream": true,
82+
"messages": [
83+
{
84+
"role": "user",
85+
"content": "xxxxxxxxx"
86+
}
87+
],
8288
"style":"hilarious"
8389
}'
8490
```
@@ -109,4 +115,3 @@ while True:
109115
print(ans.content[len(cont):], end='', flush=True)
110116
cont = ans.content
111117
```
112-

0 commit comments

Comments
 (0)