Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions api/apps/restful_apis/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,18 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):

llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id)
model_type = (llm_setting or {}).get("model_type")
if model_type not in {"chat", "image2text"}:
model_type = "chat"
candidate_model_types = [model_type] if model_type in {"chat", "image2text"} else ["chat", "image2text"]

for current_model_type in candidate_model_types:
if TenantLLMService.query(
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
model_type=current_model_type,
):
return None

if not TenantLLMService.query(
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
model_type=model_type,
):
return f"`llm_id` {llm_id} doesn't exist"
return None
return f"`llm_id` {llm_id} doesn't exist"


def _validate_rerank_id(rerank_id, tenant_id):
Expand Down
107 changes: 82 additions & 25 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,74 @@ def get_null_tenant_rerank_id_row(cls):
return list(objs)


def _model_type_value(model_type):
    """Return the plain string for *model_type* (an LLMType enum member or a str/None)."""
    return model_type.value if hasattr(model_type, "value") else model_type


def _get_dialog_chat_model_config(dialog):
    """Resolve the chat-capable model configuration for *dialog*.

    Resolution order:
      1. ``dialog.llm_id`` — probe it as a ``chat`` model, then ``image2text``
         (order flipped when the id is already classified as image2text).
      2. ``dialog.tenant_llm_id`` — look the tenant-LLM record up by id.
      3. Fall back to the tenant's default chat model.

    Returns:
        tuple: ``(model_config, resolved_type)`` where ``resolved_type`` is the
        model-type string the configuration was resolved as.

    Raises:
        LookupError: if ``dialog.llm_id`` is set but matches no tenant model of
        any supported type (``chat`` or ``image2text``).
    """
    if dialog.llm_id:
        llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
        # Probe the most likely type first so the common case needs one lookup.
        candidate_model_types = [LLMType.CHAT, LLMType.IMAGE2TEXT]
        if llm_type == "image2text":
            candidate_model_types = [LLMType.IMAGE2TEXT, LLMType.CHAT]

        for candidate_model_type in candidate_model_types:
            try:
                model_config = get_model_config_by_type_and_name(
                    dialog.tenant_id, candidate_model_type, dialog.llm_id
                )
            except LookupError as exc:
                logging.debug(
                    "Dialog llm_id=%s not found as model_type=%s: %s",
                    dialog.llm_id,
                    _model_type_value(candidate_model_type),
                    exc,
                )
                continue
            resolved_type = _model_type_value(candidate_model_type)
            logging.info(
                "Resolved dialog llm_id=%s using model_type=%s",
                dialog.llm_id,
                resolved_type,
            )
            return model_config, resolved_type

        raise LookupError(
            f"Tenant Model with name {dialog.llm_id} not found for supported types: chat,image2text"
        )
    if dialog.tenant_llm_id:
        model_config = get_model_config_by_id(dialog.tenant_llm_id)
        # Stored model_type may be missing/empty; default to "chat" in that case.
        resolved_type = _model_type_value(model_config.get("model_type")) or LLMType.CHAT.value
        logging.info(
            "Resolved dialog tenant_llm_id=%s using model_type=%s",
            dialog.tenant_llm_id,
            resolved_type,
        )
        return model_config, resolved_type

    model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
    resolved_type = _model_type_value(model_config.get("model_type")) or LLMType.CHAT.value
    logging.info(
        "Resolved default tenant chat model for tenant_id=%s using model_type=%s",
        dialog.tenant_id,
        resolved_type,
    )
    return model_config, resolved_type


async def async_chat_solo(dialog, messages, stream=True):
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
model_config, llm_type = _get_dialog_chat_model_config(dialog)
attachments = ""
image_attachments = []
image_files = []
Expand All @@ -242,13 +308,6 @@ async def async_chat_solo(dialog, messages, stream=True):
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)

if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
model_config = get_model_config_by_id(dialog.tenant_llm_id)
else:
model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)

chat_mdl = LLMBundle(dialog.tenant_id, model_config)
factory = model_config.get("llm_factory", "") if model_config else ""

Expand Down Expand Up @@ -283,7 +342,7 @@ async def async_chat_solo(dialog, messages, stream=True):
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def get_models(dialog):
def get_models(dialog, llm_model_config=None, llm_type=None):
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
Expand All @@ -297,14 +356,10 @@ def get_models(dialog):
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])

if dialog.llm_id:
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
else:
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
if llm_model_config is None or llm_type is None:
llm_model_config, llm_type = _get_dialog_chat_model_config(dialog)

chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
chat_mdl = LLMBundle(dialog.tenant_id, llm_model_config)

if dialog.rerank_id:
rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
Expand Down Expand Up @@ -386,12 +441,16 @@ def convert_last_user_msg_to_multimodal(msg: list[dict], image_data_uris: list[s
text = _normalize_text_from_content(original_content)

if factory_norm == "gemini":
# LiteLLM validates OpenAI-compatible content blocks before provider-specific mapping.
# Keep Gemini inputs in OpenAI shape to avoid "invalid content type=None".
parts = []
if text:
parts.append({"text": text})
parts.append({"type": "text", "text": text})
for image in image_data_uris:
mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
parts.append({"inline_data": {"mime_type": mime, "data": b64}})
image_url = image if isinstance(image, str) else str(image)
if not image_url.startswith("data:"):
image_url = f"data:image/png;base64,{image_url}"
parts.append({"type": "image_url", "image_url": {"url": image_url}})
msg[idx]["content"] = parts
return

Expand Down Expand Up @@ -497,11 +556,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
return

chat_start_ts = timer()
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
if llm_type == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
llm_model_config, llm_type = _get_dialog_chat_model_config(dialog)

factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
max_tokens = llm_model_config.get("max_tokens", 8192)
Expand All @@ -523,7 +578,9 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
pass

check_langfuse_tracer_ts = timer()
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(
dialog, llm_model_config=llm_model_config, llm_type=llm_type
)
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
if toolcall_session and tools:
chat_mdl.bind_tools(toolcall_session, tools)
Expand Down