diff --git a/agent/canvas.py b/agent/canvas.py index 65303ca9e9e..e3549c374b4 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -294,8 +294,10 @@ def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custo "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") } self.variables = {} - super().__init__(dsl, tenant_id, task_id, custom_header=custom_header) + # Components are instantiated during super().__init__ -> load(), + # so identifiers used by component constructors must exist first. self._id = canvas_id + super().__init__(dsl, tenant_id, task_id, custom_header=custom_header) def load(self): super().load() @@ -323,6 +325,9 @@ def load(self): self.retrieval = self.dsl["retrieval"] self.memory = self.dsl.get("memory", []) + def get_history_id(self): + return self.task_id + def __str__(self): self.dsl["history"] = self.history self.dsl["retrieval"] = self.retrieval @@ -518,7 +523,10 @@ def _node_finished(cpn_obj): if cpn_obj.component_name.lower() == "message": if cpn_obj.get_param("auto_play"): tts_model_config = get_tenant_default_model_by_type(self._tenant_id, LLMType.TTS) - tts_mdl = LLMBundle(self._tenant_id, tts_model_config) + tts_mdl = LLMBundle(self._tenant_id, tts_model_config, + biz_type="agent", + biz_id=self._id, + session_id=self.get_history_id()) if isinstance(cpn_obj.output("content"), partial): _m = "" buff_m = "" diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 3938dc03e86..55bc031acdc 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -89,6 +89,9 @@ def __init__(self, canvas, id, param: LLMParam): retry_interval=self._param.delay_after_error, max_rounds=self._param.max_rounds, verbose_tool_use=False, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id() ) self.tool_meta = [] for indexed_name, tool_obj in self.tools.items(): diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 708ce142fe5..11dd4b5f5eb 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -124,7 +124,10 @@ async def _invoke_async(self, **kwargs): self.set_input_value(query_key, msg[-1]["content"]) self._param.update_prompt() chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id) - chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config) + chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) user_prompt = """ ---- Real Data ---- diff --git a/agent/component/llm.py b/agent/component/llm.py index 24254ce20cf..68593821405 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -88,7 +88,10 @@ def __init__(self, canvas, component_id, param: ComponentParamBase): chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id) self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config, max_retries=self._param.max_retries, - retry_interval=self._param.delay_after_error) + retry_interval=self._param.delay_after_error, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) self.imgs = [] def get_input_form(self) -> dict[str, dict]: @@ -249,7 +252,10 @@ def _prepare_prompt_variables(self): if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value: self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value, self._param.llm_id, max_retries=self._param.max_retries, - retry_interval=self._param.delay_after_error + retry_interval=self._param.delay_after_error, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id() ) msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 6c7ca8695d2..2ef8de09757 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -122,12 +122,18 @@ async def _retrieve_kb(self, query_text: str): if embd_nms: tenant_id = self._canvas.get_tenant_id() embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_nms[0]) - embd_mdl = LLMBundle(tenant_id, embd_model_config) + embd_mdl = LLMBundle(tenant_id, embd_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) rerank_mdl = None if self._param.rerank_id: rerank_model_config = get_model_config_by_type_and_name(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) - rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) vars = self.get_input_elements_from_text(query_text) vars = {k: o["value"] for k, o in vars.items()} @@ -170,7 +176,10 @@ def _resolve_manual_filter(flt: dict) -> dict: if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]: tenant_id = self._canvas.get_tenant_id() chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) doc_ids = await apply_meta_data_filter( self._param.meta_data_filter, @@ -182,7 +191,8 @@ def _resolve_manual_filter(flt: dict) -> dict: ) if self._param.cross_languages: - query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) + query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages, + biz_type="agent", biz_id=self._canvas._id, session_id=self._canvas.get_history_id()) if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) @@ -206,7 +216,10 @@ def _resolve_manual_filter(flt: dict) -> dict: if self._param.toc_enhance: tenant_id = self._canvas._tenant_id chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) if self.check_if_canceled("Retrieval processing"): @@ -222,7 +235,10 @@ def _resolve_manual_filter(flt: dict) -> dict: [kb.tenant_id for kb in kbs], kb_ids, embd_mdl, - LLMBundle(tenant_id, chat_model_config)) + LLMBundle(tenant_id, chat_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id())) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -233,7 +249,10 @@ def _resolve_manual_filter(flt: dict) -> dict: if self._param.use_kg and kbs: chat_model_config = get_tenant_default_model_by_type(kbs[0].tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, - LLMBundle(kbs[0].tenant_id, chat_model_config)) + LLMBundle(kbs[0].tenant_id, chat_model_config, + biz_type="agent", + biz_id=self._canvas._id, + session_id=self._canvas.get_history_id())) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 5e06d872a69..77851456b78 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -184,7 +184,7 @@ def _set_sync(): embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) else: embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(tenant_id, embd_model_config) + embd_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="document", biz_id=req["doc_id"]) _d = d if doc.parser_id == ParserType.QA: @@ -375,7 +375,7 @@ def _create_sync(): embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) else: embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(tenant_id, embd_model_config) + embd_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="document", biz_id=req["doc_id"]) if image_base64: d["img_id"] = "{}-{}".format(doc.kb_id, chunck_id) @@ -426,6 +426,13 @@ async def _retrieval(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] + if req.get("search_id", ""): + biz_id = req.get("search_id") + biz_type = "search" + else: + biz_id = kb_ids[0] + biz_type = "kb_retrieval" + meta_data_filter = {} chat_mdl = None if req.get("search_id", ""): @@ -437,12 +444,12 @@ async def _retrieval(): chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"]) else: chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT) - chat_mdl = LLMBundle(user_id, chat_model_config) + chat_mdl = LLMBundle(user_id, chat_model_config, biz_type=biz_type, biz_id=biz_id) else: meta_data_filter = req.get("meta_data_filter") or {} if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT) - chat_mdl = LLMBundle(user_id, chat_model_config) + chat_mdl = LLMBundle(user_id, chat_model_config, biz_type=biz_type, biz_id=biz_id) if meta_data_filter: metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) @@ -473,19 +480,19 @@ async def _retrieval(): embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) else: embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, biz_type=biz_type, biz_id=biz_id) rerank_mdl = None if req.get("tenant_rerank_id"): rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id) if req.get("keyword", False): default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) + chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config, biz_type=biz_type, biz_id=biz_id) _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) @@ -510,7 +517,7 @@ async def _retrieval(): tenant_ids, kb_ids, embd_mdl, - LLMBundle(kb.tenant_id, default_chat_model_config)) + LLMBundle(kb.tenant_id, default_chat_model_config, biz_type=biz_type, biz_id=biz_id)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index f817de6330d..d6009d8d200 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -956,7 +956,7 @@ def _clean(s: str) -> str: embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) else: return get_error_data_result("`tenant_embd_id` or `embd_id` is required.") - emb_mdl = LLMBundle(tenant_id, embd_model_config) + emb_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="kb_check", biz_id=kb_id) samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) results, eff_sims = [], [] diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index 7c311ae4bf2..bdc10f461a3 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -843,7 +843,7 @@ async def tts(): except Exception as e: return get_data_error_result(message=str(e)) - tts_mdl = LLMBundle(current_user.id, default_tts_model_config) + tts_mdl = LLMBundle(current_user.id, default_tts_model_config, biz_type="tts") def stream_audio(): try: @@ -893,7 +893,7 @@ async def transcriptions(): except Exception as e: return get_data_error_result(message=str(e)) - asr_mdl = LLMBundle(current_user.id, default_asr_model_config) + asr_mdl = LLMBundle(current_user.id, default_asr_model_config, biz_type="speech2text") if not stream_mode: text = asr_mdl.transcription(temp_audio_path) try: @@ -955,7 +955,7 @@ async def related_questions(): chat_model_config = get_model_config_by_type_and_name(current_user.id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT) - chat_mdl = LLMBundle(current_user.id, chat_model_config) + chat_mdl = LLMBundle(current_user.id, chat_model_config, biz_type="search", biz_id=search_id) gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) if "parameter" in gen_conf: diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index e6dd61d035e..817b4802c38 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -135,7 +135,7 @@ async def retrieval(tenant_id): model_config = get_model_config_by_id(kb.tenant_embd_id) else: model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, model_config) + embd_mdl = LLMBundle(kb.tenant_id, model_config, biz_type="kb_retrieval", biz_id=kb_id) if metadata_condition: doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) if not doc_ids and metadata_condition: @@ -161,7 +161,7 @@ async def retrieval(tenant_id): [tenant_id], [kb_id], embd_mdl, - LLMBundle(kb.tenant_id, model_config)) + LLMBundle(kb.tenant_id, model_config, biz_type="kb_retrieval", biz_id=kb_id)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 72964ae35a7..1a3be501e91 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1587,22 +1587,22 @@ async def retrieval_test(tenant_id): embd_model_config = get_model_config_by_id(kb.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]) rerank_mdl = None if req.get("tenant_rerank_id"): rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]) if langs: question = await cross_languages(kb.tenant_id, None, question, langs) if req.get("keyword", False): chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) + chat_mdl = LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]) question += await keyword_extraction(chat_mdl, question) ranks = await settings.retriever.retrieval( @@ -1622,14 +1622,14 @@ async def retrieval_test(tenant_id): ) if toc_enhance: chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) + chat_mdl = LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]) cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) if cks: ranks["chunks"] = cks ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) if use_kg: chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config)) + ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 82e048ff17b..2526c35c8bc 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -720,7 +720,7 @@ async def ask_about(tenant_id): async def stream(): nonlocal req, uid try: - async for ans in async_ask(req["question"], req["kb_ids"], uid): + async for ans in async_ask(req["question"], req["kb_ids"], uid, biz_type="dialog"): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -745,7 +745,7 @@ async def related_questions(tenant_id): question = req["question"] industry = req.get("industry", "") chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type="dialog") prompt = """ Objective: To generate search terms related to the user's search keywords, helping users find more valuable information. Instructions: @@ -929,7 +929,8 @@ async def ask_about_embedded(): async def stream(): nonlocal req, uid try: - async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config, + biz_type="search" if search_id else "dialog", biz_id=search_id or None): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -995,7 +996,8 @@ async def _retrieval(): chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, + biz_type="search", biz_id=req.get("search_id") or None) # Apply search_config settings if not explicitly provided in request if not req.get("similarity_threshold"): similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold)) @@ -1009,7 +1011,7 @@ async def _retrieval(): meta_data_filter = req.get("meta_data_filter") or {} if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type="search") if meta_data_filter: metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) @@ -1035,19 +1037,23 @@ async def _retrieval(): embd_model_config = get_model_config_by_id(kb.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, + biz_type="search", biz_id=req.get("search_id") or None) rerank_mdl = None if tenant_rerank_id: rerank_model_config = get_model_config_by_id(tenant_rerank_id) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, + biz_type="search", biz_id=req.get("search_id") or None) elif rerank_id: rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, + biz_type="search", biz_id=req.get("search_id") or None) if req.get("keyword", False): default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, default_chat_model) + chat_mdl = LLMBundle(kb.tenant_id, default_chat_model, + biz_type="search", biz_id=req.get("search_id") or None) _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) @@ -1058,7 +1064,8 @@ async def _retrieval(): if use_kg: default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, - LLMBundle(kb.tenant_id, default_chat_model)) + LLMBundle(kb.tenant_id, default_chat_model, + biz_type="search", biz_id=req.get("search_id") or None)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -1106,7 +1113,8 @@ async def related_questions_embedded(): chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, + biz_type="search" if search_id else "dialog", biz_id=search_id or None) gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) prompt = load_prompt("related_question") @@ -1174,7 +1182,8 @@ async def mindmap(): search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} - mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) + mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}), + biz_type="search" if search_id else "dialog", biz_id=search_id or None) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) @@ -1211,7 +1220,7 @@ async def sequence2txt(tenant_id): default_asr_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.SPEECH2TEXT) except Exception as e: return get_error_data_result(message=str(e)) - asr_mdl=LLMBundle(tenant_id, default_asr_model_config) + asr_mdl=LLMBundle(tenant_id, default_asr_model_config, biz_type="speech2text") if not stream_mode: text = asr_mdl.transcription(temp_audio_path) try: @@ -1244,7 +1253,7 @@ async def tts(tenant_id): default_tts_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.TTS) except Exception as e: return get_error_data_result(message=str(e)) - tts_mdl = LLMBundle(tenant_id, default_tts_model_config) + tts_mdl = LLMBundle(tenant_id, default_tts_model_config, biz_type="tts") def stream_audio(): try: diff --git a/api/db/db_models.py b/api/db/db_models.py index 97a05c6cde4..8cdfa087e33 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -836,6 +836,32 @@ class Meta: ) +class LLMUsageLog(DataBaseModel): + """Detailed LLM usage log for recording token consumption, billing, and related business context for each call.""" + id = CharField(max_length=32, primary_key=True) + tenant_id = CharField(max_length=32, null=False, index=True) + user_id = CharField(max_length=32, null=True, index=True) + biz_type = CharField(max_length=32, null=False, index=True, + help_text="Business type: dialog/agent/document_parse/graphrag/raptor/other") + biz_id = CharField(max_length=32, null=True, index=True, + help_text="Business object ID, such as dialog.id or canvas.id") + session_id = CharField(max_length=32, null=True, index=True, + help_text="Session ID, such as Conversation.id or API4Conversation.id") + tenant_llm_id = IntegerField(null=False, index=True, + help_text="Associated TenantLLM.id") + model_type = CharField(max_length=32, null=False, index=True, + help_text="chat/embedding/rerank/image2text/speech2text") + prompt_tokens = IntegerField(default=0) + completion_tokens = IntegerField(default=0) + total_tokens = IntegerField(default=0) + cost = FloatField(default=0.0, help_text="USD") + created_at = BigIntegerField(null=False, index=True, + help_text="Unix timestamp in milliseconds") + + class Meta: + db_table = "llm_usage_log" + + class TenantLangfuse(DataBaseModel): tenant_id = CharField(max_length=32, null=False, primary_key=True) secret_key = CharField(max_length=2048, null=False, help_text="SECRET KEY", index=True) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 4765b2bdbb6..f6ecdb328b2 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -157,7 +157,7 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, llm_config = get_model_config_by_id(tenant_llm_id) else: llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) - llm = LLMBundle(tenant_id, llm_config) + llm = LLMBundle(tenant_id, llm_config, biz_type="memory", biz_id=task_id) if task_id: TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."}) res = await llm.async_chat(system_prompt, user_prompts, extract_conf) @@ -177,7 +177,7 @@ async def embed_and_save(memory, message_list: list[dict], task_id: str=None): embd_model_config = get_model_config_by_id(memory.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) - embedding_model = LLMBundle(memory.tenant_id, embd_model_config) + embedding_model = LLMBundle(memory.tenant_id, embd_model_config, biz_type="memory", biz_id=memory.id) if task_id: TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."}) vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list]) @@ -251,7 +251,7 @@ def query_message(filter_dict: dict, params: dict): embd_model_config = get_model_config_by_id(memory.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) - embd_model = LLMBundle(memory.tenant_id, embd_model_config) + embd_model = LLMBundle(memory.tenant_id, embd_model_config, biz_type="memory", biz_id=memory.id) match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"]) match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"]) keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7) diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 5a205b14219..728fab11e68 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -183,7 +183,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses if stream: try: - async for ans in async_chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, conversation_id=conv.id, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -195,7 +195,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses else: answer = None - async for ans in async_chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, conversation_id=conv.id, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) ConversationService.update_by_id(conv.id, conv.to_dict()) break diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 83f79c285a1..b3b2e0fba46 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -15,6 +15,7 @@ # import asyncio import binascii +import inspect import logging import re import time @@ -209,7 +210,11 @@ def get_null_tenant_rerank_id_row(cls): return list(objs) -async def async_chat_solo(dialog, messages, stream=True): +async def async_chat_solo(dialog, messages, stream=True, **kwargs): + conversation_id = kwargs.get("conversation_id") + biz_type = "dialog" + biz_id = getattr(dialog, "id", None) + session_id = conversation_id llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id) attachments = "" image_attachments = [] @@ -221,14 +226,14 @@ async def async_chat_solo(dialog, messages, stream=True): text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) attachments = "\n\n".join(text_attachments) model_config = get_model_config_by_id(dialog.tenant_llm_id) - chat_mdl = LLMBundle(dialog.tenant_id, model_config) + chat_mdl = LLMBundle(dialog.tenant_id, model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) factory = model_config.get("llm_factory", "") if model_config else "" prompt_config = dialog.prompt_config tts_mdl = None if prompt_config.get("tts"): default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS) - tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model) + tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model, biz_type=biz_type, biz_id=biz_id, session_id=session_id) msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments @@ -255,7 +260,9 @@ async def async_chat_solo(dialog, messages, stream=True): yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} -def get_models(dialog): +def get_models(dialog, biz_type="dialog", biz_id=None, session_id=None): + if biz_id is None: + biz_id = getattr(dialog, "id", None) embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -265,7 +272,7 @@ def get_models(dialog): if embedding_list: embd_owner_tenant_id = kbs[0].tenant_id embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) - embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) + embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if not embd_mdl: raise LookupError("Embedding model(%s) not found" % embedding_list[0]) @@ -276,18 +283,30 @@ def get_models(dialog): else: chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config) + chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if dialog.rerank_id: rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) - rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if dialog.prompt_config.get("tts"): default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS) - tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config) + tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl +def _get_models_with_runtime_context(dialog, conversation_id=None): + call_kwargs = {} + params = inspect.signature(get_models).parameters + if "biz_type" in params: + call_kwargs["biz_type"] = "dialog" + if "biz_id" in params: + call_kwargs["biz_id"] = getattr(dialog, "id", None) + if "session_id" in params: + call_kwargs["session_id"] = conversation_id + return get_models(dialog, **call_kwargs) + + def split_file_attachments(files: list[dict] | None, raw: bool = False) -> tuple[list[str], list[str] | list[dict]]: if not files: return [], [] @@ -460,9 +479,10 @@ def find_and_replace(pattern, group_index=1, repl=lambda digits: f"ID:{digits}") async def async_chat(dialog, messages, stream=True, **kwargs): logging.debug("Begin async_chat") + conversation_id = kwargs.get("conversation_id") assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): - async for ans in async_chat_solo(dialog, messages, stream): + async for ans in async_chat_solo(dialog, messages, stream, **kwargs): yield ans return @@ -493,7 +513,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): pass check_langfuse_tracer_ts = timer() - kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog) + dialog_biz_id = getattr(dialog, "id", None) + kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = _get_models_with_runtime_context(dialog, conversation_id) toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") if toolcall_session and tools: chat_mdl.bind_tools(toolcall_session, tools) @@ -542,12 +563,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): - questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)] + questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages, + biz_type="dialog", biz_id=getattr(dialog, "id", None), session_id=conversation_id)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): - questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] + questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"], + biz_type="dialog", biz_id=getattr(dialog, "id", None), session_id=conversation_id)] if dialog.meta_data_filter: metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids) @@ -637,7 +660,7 @@ async def callback(msg:str): if prompt_config.get("use_kg"): default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, - LLMBundle(dialog.tenant_id, default_chat_model)) + LLMBundle(dialog.tenant_id, default_chat_model, biz_type="dialog", biz_id=dialog_biz_id, session_id=conversation_id)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -1361,7 +1384,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if state.endswith_think: yield ("marker", "", state) -async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): +async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}, biz_type="dialog", biz_id=None, session_id=None): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) @@ -1376,12 +1399,12 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_owner_tenant_id = kbs[0].tenant_id embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) - embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) + embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if rerank_id: rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) - rerank_mdl = LLMBundle(tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) @@ -1445,7 +1468,7 @@ def decorate_answer(answer): yield final -async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): +async def gen_mindmap(question, kb_ids, tenant_id, search_config={}, biz_type="dialog", biz_id=None, session_id=None): meta_data_filter = search_config.get("meta_data_filter", {}) doc_ids = search_config.get("doc_ids", []) rerank_id = search_config.get("rerank_id", "") @@ -1461,16 +1484,16 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): else: embd_owner_tenant_id = kbs[0].tenant_id embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, kbs[0].embd_id) - embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) + embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) chat_id = search_config.get("chat_id", "") if chat_id: chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) else: chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if rerank_id: rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) - rerank_mdl = LLMBundle(tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) if meta_data_filter: metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index c31d415189b..48118fcac42 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -1055,7 +1055,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): embd_model_config = get_model_config_by_id(kb.tenant_embd_id) else: embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language) + # The embd_mdl instantiation was moved into the embedding function to allow passing the per-document doc_id err, files = FileService.upload_document(kb, file_objs, user_id) assert not err, "\n".join(err) @@ -1105,10 +1105,11 @@ def dummy(prog=None, msg=""): es_bulk_size = 64 def embedding(doc_id, cnts, batch_size=16): - nonlocal embd_mdl, chunk_counts, token_counts + nonlocal chunk_counts, token_counts, embd_model_config, kb + doc_embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language, biz_type="document", biz_id=doc_id) vectors = [] for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i : i + batch_size]) + vts, c = doc_embd_mdl.encode(cnts[i : i + batch_size]) vectors.extend(vts.tolist()) chunk_counts[doc_id] += len(cnts[i : i + batch_size]) token_counts[doc_id] += c @@ -1119,8 +1120,9 @@ def embedding(doc_id, cnts, batch_size=16): _, tenant = TenantService.get_by_id(kb.tenant_id) tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config) + for doc_id in docids: + llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config, biz_type="document", biz_id=doc_id) cks = [c for c in docs if c["doc_id"] == doc_id] if parser_ids[doc_id] != ParserType.PICTURE.value: diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6058c6b69f7..fe65b2cfc94 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -24,9 +24,10 @@ from api.db.db_models import LLM from api.db.services.common_service import CommonService +from api.db.services.llm_usage_log_service import LLMUsageLogService from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService from common.constants import LLMType -from common.token_utils import num_tokens_from_string +from common.token_utils import LLMUsage, num_tokens_from_string class LLMService(CommonService): @@ -83,8 +84,33 @@ def get_init_tenant_llm(user_id): class LLMBundle(LLM4Tenant): - def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): + def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", + user_id: str = None, biz_type: str = None, biz_id: str = None, + session_id: str = None, **kwargs): super().__init__(tenant_id, model_config, lang, **kwargs) + self.user_id = user_id + self.biz_type = biz_type or "other" + self.biz_id = biz_id + self.session_id = session_id + + def _log_usage(self, model_type: str, usage: LLMUsage): + """Write token usage for this LLM call to the detail log table; failures are logged without affecting the main flow.""" + tenant_llm_id = self.model_config.get("id") + if not tenant_llm_id: + return + LLMUsageLogService.create( + tenant_id=self.tenant_id, + tenant_llm_id=tenant_llm_id, + model_type=model_type, + user_id=self.user_id, + biz_type=self.biz_type, + biz_id=self.biz_id, + session_id=self.session_id, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + cost=usage.cost, + ) def bind_tools(self, toolcall_session, tools): if not self.is_tools: @@ -115,6 +141,7 @@ def encode(self, texts: list): generation.update(usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("embedding", LLMUsage(total_tokens=used_tokens)) return embeddings, used_tokens def encode_queries(self, query: str): @@ -131,6 +158,7 @@ def encode_queries(self, query: str): generation.update(usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("embedding", LLMUsage(total_tokens=used_tokens)) return emd, used_tokens def similarity(self, query: str, texts: list): @@ -145,6 +173,7 @@ def similarity(self, query: str, texts: list): generation.update(usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("rerank", LLMUsage(total_tokens=used_tokens)) return sim, used_tokens def describe(self, image, max_tokens=300): @@ -159,6 +188,7 @@ def describe(self, image, max_tokens=300): generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("image2text", LLMUsage(total_tokens=used_tokens)) return txt def describe_with_prompt(self, image, prompt): @@ -173,6 +203,7 @@ def describe_with_prompt(self, image, prompt): generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("image2text", LLMUsage(total_tokens=used_tokens)) return txt def transcription(self, audio): @@ -187,6 +218,7 @@ def transcription(self, audio): generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) generation.end() + self._log_usage("speech2text", LLMUsage(total_tokens=used_tokens)) return txt def stream_transcription(self, audio): @@ -225,6 +257,8 @@ def stream_transcription(self, audio): ) generation.end() + self._log_usage("speech2text", LLMUsage(total_tokens=used_tokens)) + return if self.langfuse: @@ -245,6 +279,7 @@ def stream_transcription(self, audio): ) generation.end() + self._log_usage("speech2text", LLMUsage(total_tokens=used_tokens)) yield { "event": "final", "text": full_text, @@ -382,7 +417,7 @@ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kw use_kwargs = self._clean_param(chat_partial, **kwargs) try: - txt, used_tokens = await chat_partial(**use_kwargs) + txt, usage = await chat_partial(**use_kwargs) except Exception as e: if generation: generation.update(output={"error": str(e)}) @@ -393,17 +428,22 @@ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kw if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - if used_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): - logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], used_tokens)) + if usage.total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], usage.total_tokens): + logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], usage.total_tokens)) if generation: - generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) + generation.update(output={"output": txt}, usage_details={ + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + }) generation.end() + self._log_usage("chat", usage) return txt async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - total_tokens = 0 + usage = LLMUsage() ans = "" _bundle_is_tools = self.is_tools _mdl_is_tools = getattr(self.mdl, "is_tools", False) @@ -424,8 +464,8 @@ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = use_kwargs = self._clean_param(chat_partial, **kwargs) try: async for txt in chat_partial(**use_kwargs): - if isinstance(txt, int): - total_tokens = txt + if isinstance(txt, LLMUsage): + usage = txt break if txt.endswith("") and ans.endswith(""): @@ -441,15 +481,20 @@ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = generation.update(output={"error": str(e)}) generation.end() raise - if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens): - logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens)) + if usage.total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], usage.total_tokens): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], usage.total_tokens)) if generation: - generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.update(output={"output": ans}, usage_details={ + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + }) generation.end() + self._log_usage("chat", usage) return async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - total_tokens = 0 + usage = LLMUsage() ans = "" if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"): stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) @@ -467,8 +512,8 @@ async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: use_kwargs = self._clean_param(chat_partial, **kwargs) try: async for txt in chat_partial(**use_kwargs): - if isinstance(txt, int): - total_tokens = txt + if isinstance(txt, LLMUsage): + usage = txt break if txt.endswith("") and ans.endswith(""): @@ -484,9 +529,14 @@ async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: generation.update(output={"error": str(e)}) generation.end() raise - if total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], total_tokens): - logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], total_tokens)) + if usage.total_tokens and not TenantLLMService.increase_usage_by_id(self.model_config["id"], usage.total_tokens): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.model_config["llm_name"], usage.total_tokens)) if generation: - generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.update(output={"output": ans}, usage_details={ + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + }) generation.end() + self._log_usage("chat", usage) return diff --git a/api/db/services/llm_usage_log_service.py b/api/db/services/llm_usage_log_service.py new file mode 100644 index 00000000000..9bbeb7f1677 --- /dev/null +++ b/api/db/services/llm_usage_log_service.py @@ -0,0 +1,63 @@ +import logging +import time +from uuid import uuid4 + +from api.db.db_models import DB, LLMUsageLog +from api.db.services.common_service import CommonService + + +class LLMUsageLogService(CommonService): + model = LLMUsageLog + + @classmethod + @DB.connection_context() + def create( + cls, + tenant_id: str, + tenant_llm_id: int, + model_type: str, + total_tokens: int, + user_id: str = None, + biz_type: str = "other", + biz_id: str = None, + session_id: str = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + cost: float = 0.0, + ): + """Create one detailed LLM usage record. + + Args: + tenant_id: Tenant ID. + tenant_llm_id: TenantLLM.id linked to the concrete model configuration. + model_type: Model type, such as "chat", "embedding", or "rerank". + total_tokens: Total tokens consumed by this call. + user_id: User ID that initiated the call, if any. + biz_type: Business type, such as "dialog", "agent", or "document_parse". + biz_id: Business object ID, such as dialog.id, canvas.id, or document_id. + session_id: Session ID, such as Conversation.id or API4Conversation.id. + prompt_tokens: Input token count. Non-chat modes usually use 0. + completion_tokens: Output token count. Non-chat modes usually use 0. + cost: Call cost in USD. Populated in LiteLLM mode; otherwise currently 0. + """ + try: + cls.model.create( + id=uuid4().hex, + tenant_id=tenant_id, + user_id=user_id, + biz_type=biz_type, + biz_id=biz_id, + session_id=session_id, + tenant_llm_id=tenant_llm_id, + model_type=model_type, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=cost, + created_at=int(time.time() * 1000), + ) + except Exception: + logging.exception( + "LLMUsageLogService.create failed for tenant_id=%s, biz_type=%s, biz_id=%s, session_id=%s", + tenant_id, biz_type, biz_id, session_id, + ) diff --git a/common/token_utils.py b/common/token_utils.py index 981e98a1b5c..3b6b409e7f4 100644 --- a/common/token_utils.py +++ b/common/token_utils.py @@ -16,6 +16,8 @@ import os +from dataclasses import dataclass + import tiktoken from common.file_utils import get_project_base_directory @@ -26,6 +28,22 @@ encoder = tiktoken.get_encoding("cl100k_base") +@dataclass +class LLMUsage: + """Token usage and billing information for LLM calls. + + Replaces bare int total_tokens across the codebase, carrying prompt/completion token + breakdown and cost in a unified structure. + - Chat mode: prompt_tokens / completion_tokens / cost are all populated. + - Embedding / Rerank mode: completion_tokens=0, cost defaults to 0. + - Native SDK mode (Mistral, Baidu, etc.): only total_tokens is set, others default to 0. + """ + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cost: float = 0.0 + + def num_tokens_from_string(string: str) -> int: """Returns the number of tokens in a text string.""" try: diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py index e062f462538..80ec8527f69 100644 --- a/deepdoc/parser/figure_parser.py +++ b/deepdoc/parser/figure_parser.py @@ -49,7 +49,8 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs): return tbls try: vision_model_config = get_tenant_default_model_by_type(kwargs["tenant_id"], LLMType.IMAGE2TEXT) - vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config) + vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config, + biz_type="document", biz_id=kwargs.get("doc_id", "")) callback(0.7, "Visual model detected. Attempting to enhance figure extraction...") except Exception: vision_model = None @@ -69,7 +70,8 @@ def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs): return [] try: vision_model_config = get_tenant_default_model_by_type(kwargs["tenant_id"], LLMType.IMAGE2TEXT) - vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config) + vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config, + biz_type="document", biz_id=kwargs.get("doc_id", "")) callback(0.2, "Visual model detected. Attempting to enhance Excel image extraction...") except Exception: vision_model = None @@ -98,7 +100,8 @@ def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs): context_size = max(0, int(parser_config.get("image_context_size", 0) or 0)) try: vision_model_config = get_tenant_default_model_by_type(kwargs["tenant_id"], LLMType.IMAGE2TEXT) - vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config) + vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config, + biz_type="document", biz_id=kwargs.get("doc_id", "")) callback(0.7, "Visual model detected. Attempting to enhance figure extraction...") except Exception: vision_model = None @@ -137,7 +140,8 @@ def vision_figure_parser_docx_wrapper_naive(chunks, idx_lst, callback=None, **kw return [] try: vision_model_config = get_tenant_default_model_by_type(kwargs["tenant_id"], LLMType.IMAGE2TEXT) - vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config) + vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config, + biz_type="document", biz_id=kwargs.get("doc_id", "")) callback(0.7, "Visual model detected. Attempting to enhance figure extraction...") except Exception: vision_model = None diff --git a/rag/app/audio.py b/rag/app/audio.py index 29ef625fad4..dd7b8f5d444 100644 --- a/rag/app/audio.py +++ b/rag/app/audio.py @@ -47,7 +47,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): callback(0.1, "USE Sequence2Txt LLM to transcription the audio") seq2txt_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.SPEECH2TEXT) - seq2txt_mdl = LLMBundle(tenant_id, seq2txt_model_config, lang=lang) + seq2txt_mdl = LLMBundle(tenant_id, seq2txt_model_config, lang=lang, biz_type="document", biz_id=kwargs.get("doc_id", "")) ans = seq2txt_mdl.transcription(tmp_path) callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32]) diff --git a/rag/app/naive.py b/rag/app/naive.py index 25b715b6edf..ca8b4523e20 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -129,7 +129,7 @@ def by_mineru( if mineru_llm_name: try: ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, mineru_llm_name) - ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) + ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang, biz_type="document", biz_id=kwargs.get("doc_id", "")) pdf_parser = ocr_model.mdl sections, tables = pdf_parser.parse_pdf( filepath=filename, @@ -211,7 +211,7 @@ def by_paddleocr( if paddleocr_llm_name: try: ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, paddleocr_llm_name) - ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) + ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang, biz_type="document", biz_id=kwargs.get("doc_id", "")) pdf_parser = ocr_model.mdl sections, tables = pdf_parser.parse_pdf( filepath=filename, @@ -244,6 +244,8 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No tenant_id, model_config=vision_model_config, lang=kwargs.get("lang", "Chinese"), + biz_type="document", + biz_id=kwargs.get("doc_id", ""), ) pdf_parser = VisionParser(vision_model=vision_model, **kwargs) @@ -912,7 +914,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca try: vision_model_config = get_tenant_default_model_by_type(kwargs["tenant_id"], LLMType.IMAGE2TEXT) - vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config) + vision_model = LLMBundle(kwargs["tenant_id"], vision_model_config, biz_type="document", biz_id=kwargs.get("doc_id", "")) callback(0.2, "Visual model detected. Attempting to enhance figure extraction...") except Exception as e: logging.warning(f"Failed to detect figure extraction: {e}") diff --git a/rag/app/picture.py b/rag/app/picture.py index d58f923eb80..838b71e36e3 100644 --- a/rag/app/picture.py +++ b/rag/app/picture.py @@ -52,7 +52,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): } ) cv_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.IMAGE2TEXT) - cv_mdl = LLMBundle(tenant_id, model_config=cv_model_config, lang=lang) + cv_mdl = LLMBundle(tenant_id, model_config=cv_model_config, lang=lang, biz_type="document", biz_id=kwargs.get("doc_id", "")) video_prompt = str(parser_config.get("video_prompt", "") or "") ans = asyncio.run( cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename, video_prompt=video_prompt)) @@ -81,7 +81,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): try: callback(0.4, "Use CV LLM to describe the picture.") cv_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.IMAGE2TEXT) - cv_mdl = LLMBundle(tenant_id, model_config=cv_model_config, lang=lang) + cv_mdl = LLMBundle(tenant_id, model_config=cv_model_config, lang=lang, biz_type="document", biz_id=kwargs.get("doc_id", "")) with io.BytesIO() as img_binary: img.save(img_binary, format="JPEG") img_binary.seek(0) diff --git a/rag/app/resume.py b/rag/app/resume.py index 80fc322bd09..c08a26744a2 100644 --- a/rag/app/resume.py +++ b/rag/app/resume.py @@ -1048,7 +1048,7 @@ def _parse_json_with_repair(text: str) -> dict: raise json.JSONDecodeError("All JSON repair strategies failed", text, 0) -def _call_llm(prompt: str, tenant_id , lang: str) -> Optional[dict]: +def _call_llm(prompt: str, tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """ Call LLM and parse JSON response (ref SmartResume's retry + fault-tolerance strategy). @@ -1068,7 +1068,7 @@ def _call_llm(prompt: str, tenant_id , lang: str) -> Optional[dict]: from api.db.services.llm_service import LLMBundle from common.constants import LLMType - llm = LLMBundle(tenant_id, LLMType.CHAT, lang=lang) + llm = LLMBundle(tenant_id, LLMType.CHAT, lang=lang, biz_type="document", biz_id=doc_id) for attempt in range(_LLM_MAX_RETRIES + 1): try: @@ -1265,44 +1265,44 @@ def _extract_description_from_range( return "\n".join(line.strip() for line in extracted_lines if line.strip()) -def _extract_basic_info(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_basic_info(indexed_text: str, tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """Extract basic info (subtask 1). Basic info is usually at the beginning of the resume, first 8000 chars suffice. """ prompt = get_basic_info_prompt(lang).format(indexed_text=indexed_text[:8000]) - return _call_llm(prompt,tenant_id, lang) + return _call_llm(prompt, tenant_id, lang, doc_id) -def _extract_work_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_work_experience(indexed_text: str, tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """Extract work experience (subtask 2, using index pointers). Work experience may span the middle-to-end of the resume, use full text to avoid truncation. """ prompt = get_work_exp_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt, tenant_id , lang) + return _call_llm(prompt, tenant_id, lang, doc_id) -def _extract_education(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_education(indexed_text: str, tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """Extract education background (subtask 3). Education is usually at the end of the resume, must use full text to avoid truncation. Resume text is generally under 30K chars, within LLM context window. """ prompt = get_education_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt,tenant_id, lang) + return _call_llm(prompt, tenant_id, lang, doc_id) -def _extract_project_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_project_experience(indexed_text: str, tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """Extract project experience (subtask 4, using index pointers). Project experience may span the middle-to-end of the resume, use full text to avoid truncation. """ prompt = get_project_exp_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt, tenant_id , lang) + return _call_llm(prompt, tenant_id, lang, doc_id) -def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) -> Optional[dict]: +def parse_with_llm(indexed_text: str, lines: list[str], tenant_id, lang: str, doc_id: str = "") -> Optional[dict]: """ Extract resume info using parallel task decomposition strategy (ref SmartResume Section 3.2). @@ -1322,10 +1322,10 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - try: # Execute four subtasks in parallel with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - future_basic = executor.submit(_extract_basic_info, indexed_text, tenant_id , lang) - future_work = executor.submit(_extract_work_experience, indexed_text, tenant_id , lang) - future_edu = executor.submit(_extract_education, indexed_text, tenant_id, lang) - future_project = executor.submit(_extract_project_experience, indexed_text, tenant_id , lang) + future_basic = executor.submit(_extract_basic_info, indexed_text, tenant_id, lang, doc_id) + future_work = executor.submit(_extract_work_experience, indexed_text, tenant_id, lang, doc_id) + future_edu = executor.submit(_extract_education, indexed_text, tenant_id, lang, doc_id) + future_project = executor.submit(_extract_project_experience, indexed_text, tenant_id, lang, doc_id) basic_info = future_basic.result(timeout=60) work_exp = future_work.result(timeout=60) @@ -2053,7 +2053,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - # ==================== Pipeline Orchestration & Chunk Construction ==================== -def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese") -> tuple[dict, list[str], list[dict]]: +def parse_resume(filename: str, binary: bytes, tenant_id, lang: str = "Chinese", doc_id: str = "") -> tuple[dict, list[str], list[dict]]: """ Resume parsing pipeline orchestration function @@ -2081,7 +2081,7 @@ def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese" return {"name_kwd": default_name}, [], [] # Phase 2: Parallel LLM structured extraction - resume = parse_with_llm(indexed_text, lines, tenant_id , lang) + resume = parse_with_llm(indexed_text, lines, tenant_id, lang, doc_id) # Phase 3: Fallback to regex parsing when LLM fails if not resume: @@ -2489,7 +2489,13 @@ def callback(prog, msg): return None callback(0.1, "Starting resume parsing...") # Parse resume - resume, lines, line_positions = parse_resume(filename, binary, tenant_id , lang) + resume, lines, line_positions = parse_resume( + filename, + binary, + tenant_id, + lang, + kwargs.get("doc_id", ""), + ) callback(0.6, "Resume structured extraction complete") # Build document chunks (with coordinate info) diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index cf756649b76..5806acc49ea 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -372,7 +372,7 @@ def resolve_mineru_llm_name(): tenant_id = self._canvas._tenant_id ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) - ocr_model = LLMBundle(tenant_id, ocr_model_config, lang=conf.get("lang", "Chinese")) + ocr_model = LLMBundle(tenant_id, ocr_model_config, lang=conf.get("lang", "Chinese"), biz_type="document", biz_id=self._canvas._doc_id or "") pdf_parser = ocr_model.mdl lines, _ = pdf_parser.parse_pdf( @@ -490,7 +490,7 @@ def resolve_paddleocr_llm_name(): tenant_id = self._canvas._tenant_id ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, parser_model_name) - ocr_model = LLMBundle(tenant_id, ocr_model_config) + ocr_model = LLMBundle(tenant_id, ocr_model_config, biz_type="document", biz_id=self._canvas._doc_id or "") pdf_parser = ocr_model.mdl lines, _ = pdf_parser.parse_pdf( @@ -523,7 +523,8 @@ def resolve_paddleocr_llm_name(): vision_model_config = get_model_config_by_type_and_name(self._canvas._tenant_id, LLMType.IMAGE2TEXT, conf["parse_method"]) else: vision_model_config = get_tenant_default_model_by_type(self._canvas._tenant_id, LLMType.IMAGE2TEXT) - vision_model = LLMBundle(self._canvas._tenant_id, vision_model_config, lang=self._param.setups["pdf"].get("lang")) + + vision_model = LLMBundle(self._canvas._tenant_id, vision_model_config, lang=self._param.setups["pdf"].get("lang"), biz_type="document", biz_id=self._canvas._doc_id or "") pdf_parser = VisionParser(vision_model=vision_model) lines, _ = pdf_parser(blob, callback=self.callback) bboxes = [] @@ -1029,7 +1030,7 @@ def _image(self, name, blob, **kwargs): lang = conf["lang"] # use VLM to describe the picture cv_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, conf["parse_method"]) - cv_model = LLMBundle(self._canvas.get_tenant_id(), cv_model_config, lang=lang) + cv_model = LLMBundle(self._canvas.get_tenant_id(), cv_model_config, lang=lang, biz_type="document", biz_id=self._canvas._doc_id or "") img_binary = io.BytesIO() img.save(img_binary, format="JPEG") img_binary.seek(0) @@ -1064,7 +1065,7 @@ def _audio(self, name, blob, **kwargs): tmpf.flush() tmp_path = os.path.abspath(tmpf.name) seq2txt_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, conf["llm_id"]) - seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), seq2txt_model_config) + seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), seq2txt_model_config, biz_type="document", biz_id=self._canvas._doc_id or "") txt = seq2txt_mdl.transcription(tmp_path) self.set_output("text", txt) @@ -1076,7 +1077,7 @@ def _video(self, name, blob, **kwargs): conf = self._param.setups["video"] self.set_output("output_format", conf["output_format"]) cv_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, conf["llm_id"]) - cv_mdl = LLMBundle(self._canvas.get_tenant_id(), cv_model_config) + cv_mdl = LLMBundle(self._canvas.get_tenant_id(), cv_model_config, biz_type="document", biz_id=self._canvas._doc_id or "") video_prompt = str(conf.get("prompt", "") or "") txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name, video_prompt=video_prompt)) diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index ea2e59aec4d..02400b5374f 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -62,7 +62,9 @@ async def _embedding(self, name, chunks): embd_model_config = get_model_config_by_type_and_name(self._canvas._tenant_id, LLMType.EMBEDDING, kb.embd_id) else: embd_model_config = get_tenant_default_model_by_type(self._canvas._tenant_id, LLMType.EMBEDDING) - embedding_model = LLMBundle(self._canvas._tenant_id, embd_model_config) + embedding_model = LLMBundle(self._canvas._tenant_id, embd_model_config, + biz_type="agent", biz_id=self._canvas._id, + session_id=self._canvas.get_history_id()) texts = [] for c in chunks: txt = "" diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index fb1353706de..1eb80b3a677 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -30,7 +30,7 @@ from openai import AsyncOpenAI, OpenAI from strenum import StrEnum -from common.token_utils import num_tokens_from_string, total_token_count_from_response +from common.token_utils import num_tokens_from_string, total_token_count_from_response, LLMUsage from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.nlp import is_chinese, is_english @@ -187,14 +187,20 @@ async def _async_chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False - request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf} + request_kwargs = {"model": self.model_name, "messages": history, "stream": True, "stream_options": {"include_usage": True}, **gen_conf} stop = kwargs.get("stop") if stop: request_kwargs["stop"] = stop response = await self.async_client.chat.completions.create(**request_kwargs) + prompt_tokens = 0 + completion_tokens = 0 + total_tokens_est = 0 async for resp in response: if not resp.choices: + if hasattr(resp, "usage") and resp.usage: + prompt_tokens = getattr(resp.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(resp.usage, "completion_tokens", 0) or 0 continue if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" @@ -208,9 +214,8 @@ async def _async_chat_streamly(self, history, gen_conf, **kwargs): else: reasoning_start = False ans = resp.choices[0].delta.content - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(resp.choices[0].delta.content) + tol = num_tokens_from_string(resp.choices[0].delta.content) + total_tokens_est += tol finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" if finish_reason == "length": @@ -219,28 +224,33 @@ async def _async_chat_streamly(self, history, gen_conf, **kwargs): else: ans += LENGTH_NOTIFICATION_EN yield ans, tol + total = prompt_tokens + completion_tokens if (prompt_tokens + completion_tokens) else total_tokens_est + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total) async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) ans = "" - total_tokens = 0 + usage = LLMUsage() for attempt in range(self.max_retries + 1): try: - async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs): + async for item in self._async_chat_streamly(history, gen_conf, **kwargs): + if isinstance(item, LLMUsage): + usage = item + break + delta_ans, _ = item ans = delta_ans - total_tokens += tol yield ans - yield total_tokens + yield usage return except Exception as e: e = await self._exceptions_async(e, attempt) if e: yield e - yield total_tokens + yield usage return def _length_stop(self, ans): @@ -359,6 +369,8 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict history.insert(0, {"role": "system", "content": system}) ans = "" + prompt_tokens = 0 + completion_tokens = 0 tk_count = 0 hist = deepcopy(history) for attempt in range(self.max_retries + 1): @@ -367,7 +379,12 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) - tk_count += total_token_count_from_response(response) + if response.usage: + prompt_tokens += getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(response.usage, "completion_tokens", 0) or 0 + tk_count += response.usage.total_tokens or 0 + else: + tk_count += total_token_count_from_response(response) if any([not response.choices, not response.choices[0].message]): raise Exception(f"500 response structure error. Response: {response}") @@ -380,7 +397,7 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) - return ans, tk_count + return ans, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) async def _exec_tool(tc): name = tc.function.name @@ -405,12 +422,14 @@ async def _exec_tool(tc): history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) response, token_count = await self._async_chat(history, gen_conf) ans += response - tk_count += token_count - return ans, tk_count + prompt_tokens += token_count.prompt_tokens + completion_tokens += token_count.completion_tokens + tk_count += token_count.total_tokens + return ans, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) except Exception as e: e = await self._exceptions_async(e, attempt) if e: - return e, tk_count + return e, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) assert False, "Shouldn't be here." @@ -421,6 +440,8 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c history.insert(0, {"role": "system", "content": system}) total_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 hist = deepcopy(history) for attempt in range(self.max_retries + 1): @@ -430,13 +451,17 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c reasoning_start = False logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}") - response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, stream_options={"include_usage": True}, tools=tools, tool_choice="auto", **gen_conf) final_tool_calls = {} answer = "" async for resp in response: if not hasattr(resp, "choices") or not resp.choices: + if hasattr(resp, "usage") and resp.usage: + prompt_tokens += getattr(resp.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(resp.usage, "completion_tokens", 0) or 0 + total_tokens += resp.usage.total_tokens or 0 continue delta = resp.choices[0].delta @@ -480,7 +505,7 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c if answer and not final_tool_calls: logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") - yield total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens or prompt_tokens + completion_tokens) return async def _exec_tool(tc): @@ -512,22 +537,28 @@ async def _exec_tool(tc): logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, stream_options={"include_usage": True}, tools=tools, tool_choice="auto", **gen_conf) + tokens_from_api = 0 async for resp in response: if not hasattr(resp, "choices") or not resp.choices: + if hasattr(resp, "usage") and resp.usage: + prompt_tokens += getattr(resp.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(resp.usage, "completion_tokens", 0) or 0 + tokens_from_api += resp.usage.total_tokens or 0 continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: + # fallback: 部分模型将 usage 嵌在 delta chunk 而非独立 chunk + tol = total_token_count_from_response(resp) + if tol: + tokens_from_api = tol continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + total_tokens += num_tokens_from_string(delta.content) yield delta.content - yield total_tokens + final_total = tokens_from_api or total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=final_total or prompt_tokens + completion_tokens) return except Exception as e: @@ -535,7 +566,7 @@ async def _exec_tool(tc): if e: logging.error(f"async_chat_streamly failed: {e}") yield e - yield total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens or prompt_tokens + completion_tokens) return assert False, "Shouldn't be here." @@ -546,17 +577,20 @@ async def _async_chat(self, history, gen_conf, **kwargs): logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly") final_ans = "" - tol_token = 0 - async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): + final_usage = LLMUsage() + async for item in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): + if isinstance(item, LLMUsage): + final_usage = item + break + delta, _ = item if delta.startswith("") or delta.endswith(""): continue final_ans += delta - tol_token = tol if len(final_ans.strip()) == 0: final_ans = "**ERROR**: Empty response from reasoning model" - return final_ans.strip(), tol_token + return final_ans.strip(), final_usage _, kwargs = _apply_model_family_policies( self.model_name, @@ -567,11 +601,16 @@ async def _async_chat(self, history, gen_conf, **kwargs): response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) if not response.choices or not response.choices[0].message or not response.choices[0].message.content: - return "", 0 + return "", LLMUsage() ans = response.choices[0].message.content.strip() if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) - return ans, total_token_count_from_response(response) + usage = response.usage + return ans, LLMUsage( + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + ) async def async_chat(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": @@ -584,7 +623,7 @@ async def async_chat(self, system, history, gen_conf={}, **kwargs): except Exception as e: e = await self._exceptions_async(e, attempt) if e: - return e, 0 + return e, LLMUsage() assert False, "Shouldn't be here." @@ -786,6 +825,8 @@ def _clean_conf(self, gen_conf): return gen_conf def _chat(self, history, gen_conf={}, **kwargs): + # TODO(billing): Mistral 原生 SDK 只返回 total_tokens,无 prompt/completion 明细和 cost。 + # 待 mistralai SDK 支持后补充,或改用 LiteLLM 路由统一处理。 gen_conf = self._clean_conf(gen_conf) response = self.client.chat(model=self.model_name, messages=history, **gen_conf) ans = response.choices[0].message.content @@ -865,6 +906,7 @@ def __init__(self, key, model_name, base_url=None, **kwargs): self.client = Client(api_token=key) def _chat(self, history, gen_conf={}, **kwargs): + # TODO(billing): Replicate 原生 SDK 只能用 tiktoken 估算 total_tokens,无精确明细和 cost。 system = history[0]["content"] if history and history[0]["role"] == "system" else "" prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) response = self.client.run( @@ -938,6 +980,8 @@ def _clean_conf(self, gen_conf): return gen_conf def _chat(self, history, gen_conf): + # TODO(billing): 百度千帆(qianfan SDK)只返回 total_tokens,无 prompt/completion 明细和 cost。 + # 可通过 response body 中 usage 字段进一步解析,待确认 SDK 版本支持后完善。 system = history[0]["content"] if history and history[0]["role"] == "system" else "" response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body ans = response["result"] @@ -1018,6 +1062,8 @@ def _clean_conf(self, gen_conf): return gen_conf def _chat(self, history, gen_conf={}, **kwargs): + # TODO(billing): Google Cloud 原生 SDK(google.genai / AnthropicVertex)只返回 total_tokens, + # prompt/completion 明细和 cost 待按各子模型 SDK 的 usage_metadata 字段完善。 system = history[0]["content"] if history and history[0]["role"] == "system" else "" if "claude" in self.model_name: @@ -1329,16 +1375,26 @@ async def async_chat(self, system, history, gen_conf, **kwargs): ) if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): - return "", 0 + return "", LLMUsage() ans = response.choices[0].message.content.strip() if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) - return ans, total_token_count_from_response(response) + usage = response.usage + try: + cost = litellm.completion_cost(completion_response=response) + except Exception: + cost = 0.0 + return ans, LLMUsage( + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + cost=cost, + ) except Exception as e: e = await self._exceptions_async(e, attempt) if e: - return e, 0 + return e, LLMUsage() assert False, "Shouldn't be here." @@ -1349,6 +1405,7 @@ async def async_chat_streamly(self, system, history, gen_conf, **kwargs): gen_conf = self._clean_conf(gen_conf) reasoning_start = False total_tokens = 0 + chunks = [] # 收集所有 chunk,流结束后用于计算 cost completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) stop = kwargs.get("stop") @@ -1361,12 +1418,14 @@ async def async_chat_streamly(self, system, history, gen_conf, **kwargs): **completion_args, drop_params=True, timeout=self.timeout, + stream_options={"include_usage": True}, ) async for resp in stream: if not hasattr(resp, "choices") or not resp.choices: continue + chunks.append(resp) delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: delta.content = "" @@ -1395,13 +1454,33 @@ async def async_chat_streamly(self, system, history, gen_conf, **kwargs): ans += LENGTH_NOTIFICATION_EN yield ans - yield total_tokens + + # 流结束:用 stream_chunk_builder 重建完整 response 以获取精确 token 明细和 cost + prompt_tokens = 0 + completion_tokens = 0 + cost = 0.0 + try: + full_response = litellm.stream_chunk_builder(chunks, messages=history) + if full_response and full_response.usage: + prompt_tokens = full_response.usage.prompt_tokens or 0 + completion_tokens = full_response.usage.completion_tokens or 0 + total_tokens = full_response.usage.total_tokens or total_tokens + cost = litellm.completion_cost(completion_response=full_response) + except Exception: + pass # 降级:保留已累计的 total_tokens,cost=0 + + yield LLMUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=cost, + ) return except Exception as e: e = await self._exceptions_async(e, attempt) if e: yield e - yield total_tokens + yield LLMUsage(total_tokens=total_tokens) return def _length_stop(self, ans): @@ -1502,6 +1581,8 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict history.insert(0, {"role": "system", "content": system}) ans = "" + prompt_tokens = 0 + completion_tokens = 0 tk_count = 0 hist = deepcopy(history) for attempt in range(self.max_retries + 1): @@ -1517,7 +1598,12 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict timeout=self.timeout, ) - tk_count += total_token_count_from_response(response) + if response.usage: + prompt_tokens += getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(response.usage, "completion_tokens", 0) or 0 + tk_count += response.usage.total_tokens or 0 + else: + tk_count += total_token_count_from_response(response) if not hasattr(response, "choices") or not response.choices or not response.choices[0].message: raise Exception(f"500 response structure error. Response: {response}") @@ -1531,7 +1617,7 @@ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict ans += message.content or "" if response.choices[0].finish_reason == "length": ans = self._length_stop(ans) - return ans, tk_count + return ans, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) async def _exec_tool(tc): name = tc.function.name @@ -1557,13 +1643,15 @@ async def _exec_tool(tc): response, token_count = await self.async_chat("", history, gen_conf) ans += response - tk_count += token_count - return ans, tk_count + prompt_tokens += token_count.prompt_tokens + completion_tokens += token_count.completion_tokens + tk_count += token_count.total_tokens + return ans, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) except Exception as e: e = await self._exceptions_async(e, attempt) if e: - return e, tk_count + return e, LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=tk_count or prompt_tokens + completion_tokens) assert False, "Shouldn't be here." @@ -1574,6 +1662,8 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c history.insert(0, {"role": "system", "content": system}) total_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 hist = deepcopy(history) for attempt in range(self.max_retries + 1): @@ -1588,6 +1678,7 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c **completion_args, drop_params=True, timeout=self.timeout, + stream_options={"include_usage": True}, ) final_tool_calls = {} @@ -1595,6 +1686,10 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c async for resp in response: if not hasattr(resp, "choices") or not resp.choices: + if hasattr(resp, "usage") and resp.usage: + prompt_tokens += getattr(resp.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(resp.usage, "completion_tokens", 0) or 0 + total_tokens += resp.usage.total_tokens or 0 continue delta = resp.choices[0].delta @@ -1638,7 +1733,7 @@ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_c if answer and not final_tool_calls: logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") - yield total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens or prompt_tokens + completion_tokens) return async def _exec_tool(tc): @@ -1675,29 +1770,36 @@ async def _exec_tool(tc): **completion_args, drop_params=True, timeout=self.timeout, + stream_options={"include_usage": True}, ) + tokens_from_api = 0 async for resp in response: if not hasattr(resp, "choices") or not resp.choices: + if hasattr(resp, "usage") and resp.usage: + prompt_tokens += getattr(resp.usage, "prompt_tokens", 0) or 0 + completion_tokens += getattr(resp.usage, "completion_tokens", 0) or 0 + tokens_from_api += resp.usage.total_tokens or 0 continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: + # fallback: 部分模型将 usage 嵌在 delta chunk 而非独立 chunk + tol = total_token_count_from_response(resp) + if tol: + tokens_from_api = tol continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + total_tokens += num_tokens_from_string(delta.content) yield delta.content - yield total_tokens + final_total = tokens_from_api or total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=final_total or prompt_tokens + completion_tokens) return except Exception as e: e = await self._exceptions_async(e, attempt) if e: yield e - yield total_tokens + yield LLMUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens or prompt_tokens + completion_tokens) return assert False, "Shouldn't be here." diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index e363fe180c4..a500def829c 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -226,7 +226,8 @@ async def question_proposal(chat_mdl, content, topn=3): return kwd -async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): +async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None, + biz_type=None, biz_id=None, session_id=None): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -237,7 +238,7 @@ async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) conv = [] for m in messages: if m["role"] not in ["user", "assistant"]: @@ -262,7 +263,7 @@ async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] -async def cross_languages(tenant_id, llm_id, query, languages=[]): +async def cross_languages(tenant_id, llm_id, query, languages=[], biz_type=None, biz_id=None, session_id=None): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -275,7 +276,7 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]): chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) else: chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) - chat_mdl = LLMBundle(tenant_id, chat_model_config) + chat_mdl = LLMBundle(tenant_id, chat_model_config, biz_type=biz_type, biz_id=biz_id, session_id=session_id) rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 2909181c8f6..fbac93ce3f3 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -281,6 +281,7 @@ async def build_chunks(task, progress_callback): kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"], + doc_id=task["doc_id"], ) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) except TaskCanceledException: @@ -344,7 +345,7 @@ async def upload_to_minio(document, chunk): st = timer() progress_callback(msg="Start to generate keywords for every chunk ...") chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) - chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) + chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"], biz_type="document", biz_id=task["doc_id"]) async def doc_keyword_extraction(chat_mdl, d, topn): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) @@ -378,7 +379,7 @@ async def doc_keyword_extraction(chat_mdl, d, topn): st = timer() progress_callback(msg="Start to generate questions for every chunk ...") chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) - chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) + chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"], biz_type="document", biz_id=task["doc_id"]) async def doc_question_proposal(chat_mdl, d, topn): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) @@ -411,7 +412,7 @@ async def doc_question_proposal(chat_mdl, d, topn): st = timer() progress_callback(msg="Start to generate meta-data for every chunk ...") chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) - chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) + chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"], biz_type="document", biz_id=task["doc_id"]) async def gen_metadata_task(chat_mdl, d): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", @@ -466,7 +467,7 @@ async def gen_metadata_task(chat_mdl, d): else: all_tags = json.loads(all_tags) chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, task["llm_id"]) - chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) + chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"], biz_type="document", biz_id=task["doc_id"]) docs_to_tag = [] for d in docs: @@ -522,7 +523,7 @@ async def doc_content_tagging(chat_mdl, d, topn_tags): def build_TOC(task, docs, progress_callback): progress_callback(msg="Start to generate table of content ...") chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"]) - chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) + chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"], biz_type="document", biz_id=task.get("doc_id", "")) docs = sorted(docs, key=lambda d: ( d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) @@ -674,7 +675,7 @@ async def run_dataflow(task: dict): e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) embedding_id = kb.embd_id embd_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.EMBEDDING, embedding_id) - embedding_model = LLMBundle(task["tenant_id"], embd_model_config) + embedding_model = LLMBundle(task["tenant_id"], embd_model_config, biz_type="document", biz_id=doc_id) @timeout(60) def batch_encode(txts): @@ -1004,7 +1005,7 @@ async def do_handle_task(task): embd_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.EMBEDDING, task_embedding_id) else: embd_model_config = get_tenant_default_model_by_type(task_tenant_id, LLMType.EMBEDDING) - embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language) + embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language, biz_type="document", biz_id=task_doc_id) vts, _ = embedding_model.encode(["ok"]) vector_size = len(vts[0]) except Exception as e: @@ -1057,7 +1058,7 @@ async def do_handle_task(task): # bind LLM for raptor chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) - chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) + chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language, biz_type="document", biz_id=task_doc_id) # run RAPTOR async with kg_limiter: chunks, token_count = await run_raptor_for_kb( @@ -1102,7 +1103,7 @@ async def do_handle_task(task): graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) - chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) + chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language, biz_type="document", biz_id=task_doc_id) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) async with kg_limiter: diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index dcbe105e37f..9a47fc80ec0 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -1048,7 +1048,7 @@ def test_sessions_ask_route_validation_and_stream_unit(monkeypatch): monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(chunk_num=1)]) captured = {} - async def _streaming_async_ask(question, kb_ids, uid): + async def _streaming_async_ask(question, kb_ids, uid, **kwargs): captured["question"] = question captured["kb_ids"] = kb_ids captured["uid"] = uid @@ -1234,7 +1234,7 @@ def test_searchbots_ask_embedded_auth_and_stream_unit(monkeypatch): monkeypatch.setattr(module.SearchService, "get_detail", lambda _search_id: {"search_config": {"mode": "test"}}) captured = {} - async def _embedded_async_ask(question, kb_ids, uid, search_config=None): + async def _embedded_async_ask(question, kb_ids, uid, search_config=None, **kwargs): captured["question"] = question captured["kb_ids"] = kb_ids captured["uid"] = uid @@ -1575,7 +1575,7 @@ def test_searchbots_mindmap_embedded_matrix_unit(monkeypatch): captured = {} - async def _gen_ok(question, kb_ids, tenant_id, search_config): + async def _gen_ok(question, kb_ids, tenant_id, search_config, **kwargs): captured["params"] = (question, kb_ids, tenant_id, search_config) return {"nodes": [question]}