Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 0 additions & 57 deletions api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,6 @@ async def create():
return server_error_response(e)


@manager.route("/infos", methods=["POST"])  # noqa: F821
@login_required
@validate_request("doc_ids")
async def doc_infos():
    """Return detail dicts (including user metadata) for the requested documents.

    Body: ``{"doc_ids": [...]}``.  Every id must be accessible by the current
    user; otherwise an AUTHENTICATION_ERROR result is returned.  The
    ``@validate_request`` decorator guarantees ``doc_ids`` is present, matching
    the sibling routes in this module and avoiding a raw KeyError/500.
    """
    req = await get_request_json()
    doc_ids = req["doc_ids"]
    # Authorize every document before fetching anything.
    for doc_id in doc_ids:
        if not DocumentService.accessible(doc_id, current_user.id):
            return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
    docs = DocumentService.get_by_ids(doc_ids)
    docs_list = list(docs.dicts())
    # Attach user-defined metadata fields to each document row.
    for doc in docs_list:
        doc["meta_fields"] = DocMetadataService.get_document_metadata(doc["id"])
    return get_json_result(data=docs_list)


@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids")
Expand Down Expand Up @@ -226,26 +210,6 @@ async def metadata_update():
return get_json_result(data={"updated": updated, "matched_docs": len(document_ids)})


@manager.route("/update_metadata_setting", methods=["POST"])  # noqa: F821
@login_required
@validate_request("doc_id", "metadata")
async def update_metadata_setting():
    """Store *metadata* into the document's parser config and return the refreshed record."""
    req = await get_request_json()
    doc_id = req["doc_id"]

    if not DocumentService.accessible(doc_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

    found, doc = DocumentService.get_by_id(doc_id)
    if not found:
        return get_data_error_result(message="Document not found!")

    DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})

    # Re-read so the response reflects the persisted state after the update.
    found, doc = DocumentService.get_by_id(doc.id)
    if not found:
        return get_data_error_result(message="Document not found!")

    return get_json_result(data=doc.to_dict())


@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
# @login_required
def thumbnails():
Expand Down Expand Up @@ -335,27 +299,6 @@ async def change_status():
return get_json_result(data=result)


@manager.route("/rm", methods=["POST"])  # noqa: F821
@login_required
@validate_request("doc_id")
async def rm():
    """Delete one or many documents; ``doc_id`` may be a single id or a list of ids."""
    req = await get_request_json()
    target = req["doc_id"]
    doc_ids = [target] if isinstance(target, str) else target

    # Deletion uses its own, stricter accessibility check per document.
    if any(not DocumentService.accessible4deletion(doc_id, current_user.id) for doc_id in doc_ids):
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

    # Actual deletion is blocking work, so it runs on the thread pool.
    errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
    if errors:
        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)

    return get_json_result(data=True)


@manager.route("/run", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "run")
Expand Down
156 changes: 98 additions & 58 deletions api/apps/restful_apis/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import re
import tempfile
from copy import deepcopy
from types import SimpleNamespace

from quart import Response, request

Expand All @@ -30,7 +31,7 @@
)
from api.db.services.chunk_feedback_service import ChunkFeedbackService
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
Expand Down Expand Up @@ -67,6 +68,15 @@
"tts": False,
"refine_multiturn": True,
}
# Prompt configuration used for "direct" completions, i.e. requests that arrive
# without a chat_id and therefore have no persisted dialog row to load.
# Quoting/TTS are disabled since there is no knowledge-base retrieval context.
_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = {
    "system": "",
    "prologue": "",
    "parameters": [],
    "empty_response": "",
    "quote": False,
    "tts": False,
    "refine_multiturn": True,
}
_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"}
_READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"}
_PERSISTED_FIELDS = set(DialogService.model._meta.fields)
Expand Down Expand Up @@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id):
)


def _build_default_completion_dialog():
    """Build an in-memory stand-in for a Dialog record used by direct completions.

    Used when a completion request carries no chat_id: retrieval is disabled
    (empty ``kb_ids``) and ranking parameters take neutral defaults.  Note the
    LLM id starts empty — presumably the caller fills it in; verify upstream.
    """
    defaults = {
        "tenant_id": current_user.id,
        "llm_id": "",
        "tenant_llm_id": None,
        "llm_setting": {},
        "prompt_config": deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG),
        "kb_ids": [],
        "top_n": 6,
        "top_k": 1024,
        "rerank_id": "",
        "similarity_threshold": 0.1,
        "vector_similarity_weight": 0.3,
        "meta_data_filter": None,
    }
    return SimpleNamespace(**defaults)


def _create_session_for_completion(chat_id, dialog, user_id):
    """Create and persist a new conversation for a completion request.

    The session opens with the dialog's prologue as the first assistant
    message.  Raises LookupError when the row cannot be saved or re-read,
    which the route's outer exception handler turns into an error response.
    """
    conv = {
        "id": get_uuid(),
        "dialog_id": chat_id,
        "name": "New session",
        "message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}],
        "user_id": user_id,
        "reference": [],
    }
    # Check the save result explicitly: a falsy return means the insert
    # failed, and continuing to get_by_id would mask the real error.
    if not ConversationService.save(**conv):
        raise LookupError("Fail to save a session!")
    ok, conv_obj = ConversationService.get_by_id(conv["id"])
    if not ok:
        raise LookupError("Fail to create a session!")
    logging.info("Created completion session %s for chat %s", conv["id"], chat_id)
    return conv_obj
Comment on lines +154 to +167
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Silent failure if ConversationService.save returns falsy.

_create_session_for_completion calls ConversationService.save(**conv) but never inspects the return value — if save fails but get_by_id still resolves some stale row (or partially), the caller only notices via the subsequent lookup. Mirror create_session at Line 657-660:

🛠️ Suggested
-    ConversationService.save(**conv)
-    ok, conv_obj = ConversationService.get_by_id(conv["id"])
-    if not ok:
-        raise LookupError("Fail to create a session!")
+    if not ConversationService.save(**conv):
+        raise LookupError("Fail to save a session!")
+    ok, conv_obj = ConversationService.get_by_id(conv["id"])
+    if not ok:
+        raise LookupError("Fail to create a session!")
+    logging.info("Created completion session %s for chat %s", conv["id"], chat_id)
     return conv_obj

Also note LookupError raised here propagates to the outer except Exception at Line 1119 and becomes a generic 500 — acceptable, but consider a dedicated error code/message since the caller reached this point with a valid, owned chat.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/chat_api.py` around lines 154 - 167, The helper
_create_session_for_completion currently ignores the return value of
ConversationService.save and can silently continue on failure; update
_create_session_for_completion to capture and check the result of
ConversationService.save(**conv) and if it indicates failure, raise a clear
error (or return an error tuple) instead of proceeding to get_by_id; mirror the
behavior used in create_session (use the same success/failure check and error
handling pattern) and use a more specific exception/message (e.g.,
SessionCreationError or a descriptive LookupError message) so callers can
distinguish a save failure from a missing lookup.



def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
if not llm_id:
return None
Expand Down Expand Up @@ -671,7 +714,7 @@ async def get_session(chat_id, session_id):
return server_error_response(ex)


@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821
@login_required
async def update_session(chat_id, session_id):
if not _ensure_owned_chat(chat_id):
Expand Down Expand Up @@ -829,7 +872,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
return server_error_response(ex)


@manager.route("/chats/tts", methods=["POST"]) # noqa: F821
@manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821
@login_required
async def tts():
req = await get_request_json()
Expand Down Expand Up @@ -857,9 +900,9 @@ def stream_audio():
return resp


@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821
@manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821
@login_required
async def transcriptions():
async def transcription():
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
Expand Down Expand Up @@ -915,7 +958,7 @@ async def event_stream():
return Response(event_stream(), content_type="text/event-stream")


@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821
@manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
Expand All @@ -933,10 +976,10 @@ async def mindmap():
return get_json_result(data=mind_map)


@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821
@manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
async def related_questions():
async def recommendation():
req = await get_request_json()

search_id = req.get("search_id", "")
Expand Down Expand Up @@ -971,10 +1014,10 @@ async def related_questions():
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


@manager.route("/chats/<chat_id>/sessions/<session_id>/completions", methods=["POST"]) # noqa: F821
@manager.route("/chat/completions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("messages")
async def session_completion(chat_id, session_id):
async def session_completion():
req = await get_request_json()
msg = []
for m in req["messages"]:
Expand All @@ -984,6 +1027,8 @@ async def session_completion(chat_id, session_id):
continue
msg.append(m)
message_id = msg[-1].get("id") if msg else None
chat_id = req.pop("chat_id", "") or ""
session_id = req.pop("session_id", "") or ""
chat_model_id = req.pop("llm_id", "")

chat_model_config = {}
Expand All @@ -993,38 +1038,63 @@ async def session_completion(chat_id, session_id):
chat_model_config[model_config] = config

try:
e, conv = ConversationService.get_by_id(session_id)
if not e:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
conv.message = deepcopy(req["messages"])
e, dia = DialogService.get_by_id(chat_id)
if not e:
return get_data_error_result(message="Chat not found!")
conv = None
if session_id and not chat_id:
return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")

if chat_id:
if not _ensure_owned_chat(chat_id):
return get_json_result(
data=False,
message="No authorization.",
code=RetCode.AUTHENTICATION_ERROR,
)
e, dia = DialogService.get_by_id(chat_id)
if not e:
return get_data_error_result(message="Chat not found!")
if session_id:
e, conv = ConversationService.get_by_id(session_id)
if not e:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
else:
conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
session_id = conv.id
conv.message = deepcopy(req["messages"])
else:
dia = _build_default_completion_dialog()
dia.llm_setting = chat_model_config

Comment on lines 1040 to +1068
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

No LLM fallback when chat_id is absent — async_chat will be invoked with llm_id="".

In the new "direct completions" branch (chat_id missing), dia = _build_default_completion_dialog() sets llm_id="" and tenant_llm_id=None. If the caller also omits llm_id from the body, chat_model_id at Line 1032 is "", the if chat_model_id: block at Line 1077 is skipped, and async_chat(dia, ...) will run with an empty model id. That will either 500 deep inside LLMBundle or silently pick the wrong model.

Please either require llm_id in the body when chat_id is not provided, or fall back to the tenant default:

🛠️ Suggested fallback
         else:
             dia = _build_default_completion_dialog()
             dia.llm_setting = chat_model_config
+            if not chat_model_id:
+                default_chat = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT)
+                dia.llm_id = getattr(default_chat, "llm_name", "") or default_chat.get("llm_name", "")

Per coding guidelines ("Add logging for new flows"), please also add a logging.info(...) around this new direct-completion path and inside _create_session_for_completion so the two newly introduced flows are observable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/chat_api.py` around lines 1040 - 1068, When chat_id is
absent the dialog created by _build_default_completion_dialog has llm_id="" and
chat_model_id can be empty, causing async_chat(dia, ...) to be called with an
invalid model; change the branch where chat_id is falsy to resolve a valid LLM
id: fetch tenant default LLM (or require req["llm_id"] if preferred) and set
dia.llm_setting/llm_id accordingly before calling async_chat, and add
logging.info calls to (1) the "direct completion" path where dia is created and
the chosen llm_id/model is recorded, and (2) inside
_create_session_for_completion to log when a session is auto-created and which
model/tenant_llm_id was used; update references in this branch to use
chat_model_id or the tenant default consistently so async_chat always receives a
non-empty model id.

del req["messages"]

if not conv.reference:
conv.reference = []
conv.reference = [r for r in conv.reference if r]
conv.reference.append({"chunks": [], "doc_aggs": []})
if conv is not None:
if not conv.reference:
conv.reference = []
conv.reference = [r for r in conv.reference if r]
conv.reference.append({"chunks": [], "doc_aggs": []})

if chat_model_id:
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
dia.llm_id = chat_model_id
dia.llm_setting = chat_model_config

is_embedded = bool(chat_model_id)
stream_mode = req.pop("stream", True)

def _format_answer(ans):
formatted = structure_answer(conv, ans, message_id, session_id)
if chat_id:
formatted["chat_id"] = chat_id
return formatted

async def stream():
nonlocal dia, msg, req, conv
try:
async for ans in async_chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
ans = _format_answer(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if not is_embedded:
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as ex:
logging.exception(ex)
Expand All @@ -1041,40 +1111,10 @@ async def stream():

answer = None
async for ans in async_chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
if not is_embedded:
answer = _format_answer(ans)
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
except Exception as ex:
return server_error_response(ex)


@manager.route("/chats/ask", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def ask():
    """Stream an answer for *question* over the given knowledge bases as SSE."""
    req = await get_request_json()
    uid = current_user.id

    # Pull the search-app configuration when a search_id is supplied.
    search_config = {}
    search_id = req.get("search_id", "")
    if search_id:
        search_app = SearchService.get_detail(search_id)
        if search_app:
            search_config = search_app.get("search_config", {})

    async def sse_events():
        try:
            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as ex:
            yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
        # Terminal event marking the end of the stream.
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(sse_events(), mimetype="text/event-stream")
    for header, value in (
        ("Cache-control", "no-cache"),
        ("Connection", "keep-alive"),
        ("X-Accel-Buffering", "no"),
        ("Content-Type", "text/event-stream; charset=utf-8"),
    ):
        resp.headers.add_header(header, value)
    return resp
Loading