Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7dc331f
Fix: include usage tokens in streaming LLM calls to avoid inaccurate …
Kingsuperyzy Mar 19, 2026
820eca4
Merge remote-tracking branch 'origin/main'
Mar 20, 2026
e3ba810
feat: per-conversation token usage tracking for model requests
Kingsuperyzy Mar 20, 2026
cfa8965
feat: per-conversation token usage tracking for model requests
Kingsuperyzy Mar 20, 2026
39997db
Merge branch 'main' into main
Kingsuperyzy Mar 23, 2026
39399ee
Fix: include usage tokens in streaming LLM calls to avoid inaccurate …
Kingsuperyzy Mar 24, 2026
554f92d
Merge branch 'main' into main
Kingsuperyzy Mar 25, 2026
7bf2a05
Merge branch 'main' into main
JinHai-CN Mar 25, 2026
065cb10
Merge branch 'main' into main
Kingsuperyzy Mar 25, 2026
ddd5898
Fix: include usage tokens in streaming LLM calls to avoid inaccurate …
Kingsuperyzy Mar 25, 2026
aafb99c
Merge branch 'infiniflow:main' into main
Kingsuperyzy Mar 30, 2026
e5d9840
Merge branch 'infiniflow:main' into main
Kingsuperyzy Mar 31, 2026
8a60314
Merge branch 'infiniflow:main' into main
Kingsuperyzy Mar 31, 2026
b416cd4
Merge branch 'infiniflow:main' into main
Kingsuperyzy Apr 2, 2026
936564c
Merge branch 'infiniflow:main' into main
Kingsuperyzy Apr 2, 2026
7f9e277
Passing `biz_type` and `biz_id` when instantiating `LLMBundle`
Kingsuperyzy Apr 2, 2026
610b8dc
Passing `biz_type` and `biz_id` when instantiating `LLMBundle`
Kingsuperyzy Apr 2, 2026
0720425
Passing `biz_type` and `biz_id` when instantiating `LLMBundle`
Kingsuperyzy Apr 2, 2026
f82dd96
Passing `biz_type` and `biz_id` when instantiating `LLMBundle`
Kingsuperyzy Apr 2, 2026
7f090dc
Merge branch 'main' into main
Kingsuperyzy Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,10 @@ def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custo
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
}
self.variables = {}
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
# Components are instantiated during super().__init__ -> load(),
# so identifiers used by component constructors must exist first.
self._id = canvas_id
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)

def load(self):
super().load()
Expand Down Expand Up @@ -323,6 +325,9 @@ def load(self):
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])

def get_history_id(self):
return self.task_id

def __str__(self):
self.dsl["history"] = self.history
self.dsl["retrieval"] = self.retrieval
Expand Down Expand Up @@ -518,7 +523,10 @@ def _node_finished(cpn_obj):
if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_model_config = get_tenant_default_model_by_type(self._tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(self._tenant_id, tts_model_config)
tts_mdl = LLMBundle(self._tenant_id, tts_model_config,
biz_type="agent",
biz_id=self._id,
session_id=self.get_history_id())
if isinstance(cpn_obj.output("content"), partial):
_m = ""
buff_m = ""
Expand Down
3 changes: 3 additions & 0 deletions agent/component/agent_with_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def __init__(self, canvas, id, param: LLMParam):
retry_interval=self._param.delay_after_error,
max_rounds=self._param.max_rounds,
verbose_tool_use=False,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id()
)
self.tool_meta = []
for indexed_name, tool_obj in self.tools.items():
Expand Down
5 changes: 4 additions & 1 deletion agent/component/categorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ async def _invoke_async(self, **kwargs):
self.set_input_value(query_key, msg[-1]["content"])
self._param.update_prompt()
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config)
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())

user_prompt = """
---- Real Data ----
Expand Down
10 changes: 8 additions & 2 deletions agent/component/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def __init__(self, canvas, component_id, param: ComponentParamBase):
chat_model_config = get_model_config_by_type_and_name(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id)
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), chat_model_config,
max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error)
retry_interval=self._param.delay_after_error,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())
self.imgs = []

def get_input_form(self) -> dict[str, dict]:
Expand Down Expand Up @@ -249,7 +252,10 @@ def _prepare_prompt_variables(self):
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
self._param.llm_id, max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error
retry_interval=self._param.delay_after_error,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id()
)

msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
Expand Down
33 changes: 26 additions & 7 deletions agent/tools/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,18 @@ async def _retrieve_kb(self, query_text: str):
if embd_nms:
tenant_id = self._canvas.get_tenant_id()
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_nms[0])
embd_mdl = LLMBundle(tenant_id, embd_model_config)
embd_mdl = LLMBundle(tenant_id, embd_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())

rerank_mdl = None
if self._param.rerank_id:
rerank_model_config = get_model_config_by_type_and_name(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())

vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
Expand Down Expand Up @@ -164,7 +170,10 @@ def _resolve_manual_filter(flt: dict) -> dict:
if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
tenant_id = self._canvas.get_tenant_id()
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
chat_mdl = LLMBundle(tenant_id, chat_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())

doc_ids = await apply_meta_data_filter(
self._param.meta_data_filter,
Expand All @@ -176,7 +185,8 @@ def _resolve_manual_filter(flt: dict) -> dict:
)

if self._param.cross_languages:
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages,
biz_type="agent", biz_id=self._canvas._id, session_id=self._canvas.get_history_id())

if kbs:
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
Expand All @@ -200,7 +210,10 @@ def _resolve_manual_filter(flt: dict) -> dict:
if self._param.toc_enhance:
tenant_id = self._canvas._tenant_id
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
chat_mdl = LLMBundle(tenant_id, chat_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id())
cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
Expand All @@ -216,7 +229,10 @@ def _resolve_manual_filter(flt: dict) -> dict:
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(tenant_id, chat_model_config))
LLMBundle(tenant_id, chat_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id()))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
Expand All @@ -227,7 +243,10 @@ def _resolve_manual_filter(flt: dict) -> dict:
if self._param.use_kg and kbs:
chat_model_config = get_tenant_default_model_by_type(kbs[0].tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, chat_model_config))
LLMBundle(kbs[0].tenant_id, chat_model_config,
biz_type="agent",
biz_id=self._canvas._id,
session_id=self._canvas.get_history_id()))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
Expand Down
25 changes: 16 additions & 9 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _set_sync():
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
embd_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="document", biz_id=req["doc_id"])

_d = d
if doc.parser_id == ParserType.QA:
Expand Down Expand Up @@ -375,7 +375,7 @@ def _create_sync():
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(tenant_id, embd_model_config)
embd_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="document", biz_id=req["doc_id"])

if image_base64:
d["img_id"] = "{}-{}".format(doc.kb_id, chunck_id)
Expand Down Expand Up @@ -426,6 +426,13 @@ async def _retrieval():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []

if req.get("search_id", ""):
biz_id = req.get("search_id")
biz_type = "search"
else:
biz_id = kb_ids[0]
biz_type = "kb_retrieval"

meta_data_filter = {}
chat_mdl = None
if req.get("search_id", ""):
Expand All @@ -437,12 +444,12 @@ async def _retrieval():
chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"])
else:
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
chat_mdl = LLMBundle(user_id, chat_model_config, biz_type=biz_type, biz_id=biz_id)
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
chat_mdl = LLMBundle(user_id, chat_model_config, biz_type=biz_type, biz_id=biz_id)

if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
Expand Down Expand Up @@ -473,19 +480,19 @@ async def _retrieval():
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, biz_type=biz_type, biz_id=biz_id)

rerank_mdl = None
if req.get("tenant_rerank_id"):
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type=biz_type, biz_id=biz_id)

if req.get("keyword", False):
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config, biz_type=biz_type, biz_id=biz_id)
_question += await keyword_extraction(chat_mdl, _question)

labels = label_question(_question, [kb])
Expand All @@ -510,7 +517,7 @@ async def _retrieval():
tenant_ids,
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, default_chat_model_config))
LLMBundle(kb.tenant_id, default_chat_model_config, biz_type=biz_type, biz_id=biz_id))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
Expand Down
2 changes: 1 addition & 1 deletion api/apps/kb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def _clean(s: str) -> str:
embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
else:
return get_error_data_result("`tenant_embd_id` or `embd_id` is required.")
emb_mdl = LLMBundle(tenant_id, embd_model_config)
emb_mdl = LLMBundle(tenant_id, embd_model_config, biz_type="kb_check", biz_id=kb_id)
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)

results, eff_sims = [], []
Expand Down
6 changes: 3 additions & 3 deletions api/apps/restful_apis/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ async def tts():
except Exception as e:
return get_data_error_result(message=str(e))

tts_mdl = LLMBundle(current_user.id, default_tts_model_config)
tts_mdl = LLMBundle(current_user.id, default_tts_model_config, biz_type="tts")

def stream_audio():
try:
Expand Down Expand Up @@ -856,7 +856,7 @@ async def transcriptions():
except Exception as e:
return get_data_error_result(message=str(e))

asr_mdl = LLMBundle(current_user.id, default_asr_model_config)
asr_mdl = LLMBundle(current_user.id, default_asr_model_config, biz_type="speech2text")
if not stream_mode:
text = asr_mdl.transcription(temp_audio_path)
try:
Expand Down Expand Up @@ -918,7 +918,7 @@ async def related_questions():
chat_model_config = get_model_config_by_type_and_name(current_user.id, LLMType.CHAT, chat_id)
else:
chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT)
chat_mdl = LLMBundle(current_user.id, chat_model_config)
chat_mdl = LLMBundle(current_user.id, chat_model_config, biz_type="search", biz_id=search_id)

gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
if "parameter" in gen_conf:
Expand Down
4 changes: 2 additions & 2 deletions api/apps/sdk/dify_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def retrieval(tenant_id):
model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, model_config)
embd_mdl = LLMBundle(kb.tenant_id, model_config, biz_type="kb_retrieval", biz_id=kb_id)
if metadata_condition:
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if not doc_ids and metadata_condition:
Expand All @@ -161,7 +161,7 @@ async def retrieval(tenant_id):
[tenant_id],
[kb_id],
embd_mdl,
LLMBundle(kb.tenant_id, model_config))
LLMBundle(kb.tenant_id, model_config, biz_type="kb_retrieval", biz_id=kb_id))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

Expand Down
12 changes: 6 additions & 6 deletions api/apps/sdk/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,22 +1771,22 @@ async def retrieval_test(tenant_id):
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])

rerank_mdl = None
if req.get("tenant_rerank_id"):
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])

if langs:
question = await cross_languages(kb.tenant_id, None, question, langs)

if req.get("keyword", False):
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])
question += await keyword_extraction(chat_mdl, question)

ranks = await settings.retriever.retrieval(
Expand All @@ -1806,14 +1806,14 @@ async def retrieval_test(tenant_id):
)
if toc_enhance:
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config)
chat_mdl = LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0])
cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
if cks:
ranks["chunks"] = cks
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
if use_kg:
chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config))
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config, biz_type="kb_retrieval", biz_id=kb_ids[0]))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

Expand Down
Loading
Loading