Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions api/apps/restful_apis/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,18 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):

llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id)
model_type = (llm_setting or {}).get("model_type")
if model_type not in {"chat", "image2text"}:
model_type = "chat"
candidate_model_types = [model_type] if model_type in {"chat", "image2text"} else ["chat", "image2text"]

for current_model_type in candidate_model_types:
if TenantLLMService.query(
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
model_type=current_model_type,
):
return None

if not TenantLLMService.query(
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
model_type=model_type,
):
return f"`llm_id` {llm_id} doesn't exist"
return None
return f"`llm_id` {llm_id} doesn't exist"


def _validate_rerank_id(rerank_id, tenant_id):
Expand Down
35 changes: 23 additions & 12 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ def get_null_tenant_rerank_id_row(cls):
return list(objs)


def _get_dialog_chat_model_config(dialog):
    """Resolve the chat-model configuration for *dialog*.

    Resolution order:
      1. ``dialog.llm_id`` — looked up under both supported model types;
         the type the id maps to is tried first, the other as fallback.
      2. ``dialog.tenant_llm_id`` — direct lookup by stored id.
      3. The tenant's default chat model.

    Raises:
        LookupError: when ``dialog.llm_id`` is set but no tenant model
            matches it under either supported type.
    """
    if not dialog.llm_id:
        # No explicit model name: fall back to the stored id, then the tenant default.
        if dialog.tenant_llm_id:
            return get_model_config_by_id(dialog.tenant_llm_id)
        return get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)

    # Order the candidate types so the one the id maps to is tried first.
    mapped_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
    type_order = (
        (LLMType.IMAGE2TEXT, LLMType.CHAT)
        if mapped_type == "image2text"
        else (LLMType.CHAT, LLMType.IMAGE2TEXT)
    )

    for model_type in type_order:
        try:
            return get_model_config_by_type_and_name(dialog.tenant_id, model_type, dialog.llm_id)
        except LookupError:
            # Not registered under this type — try the next candidate.
            continue

    raise LookupError(
        f"Tenant Model with name {dialog.llm_id} not found for supported types: chat,image2text"
    )
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated


async def async_chat_solo(dialog, messages, stream=True):
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
attachments = ""
Expand All @@ -242,12 +263,7 @@ async def async_chat_solo(dialog, messages, stream=True):
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)

if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
model_config = get_model_config_by_id(dialog.tenant_llm_id)
else:
model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
model_config = _get_dialog_chat_model_config(dialog)

chat_mdl = LLMBundle(dialog.tenant_id, model_config)
factory = model_config.get("llm_factory", "") if model_config else ""
Expand Down Expand Up @@ -297,12 +313,7 @@ def get_models(dialog):
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])

if dialog.llm_id:
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
else:
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
chat_model_config = _get_dialog_chat_model_config(dialog)

chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)

Expand Down