Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/apps/restful_apis/memory_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def create_memory():
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json()
req = ensure_tenant_model_id_for_params(current_user.id, req)
req = ensure_tenant_model_id_for_params(current_user.id, req, strict=True)
t_parsed = time.perf_counter() if timing_enabled else None
Comment thread
coderabbitai[bot] marked this conversation as resolved.
try:
memory_info = {
Expand Down
2 changes: 1 addition & 1 deletion api/db/joint_services/memory_message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
if tenant_llm_id:
llm_config = get_model_config_by_id(tenant_llm_id)
llm_config = get_model_config_by_id(tenant_llm_id, LLMType.CHAT)
else:
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
llm = LLMBundle(tenant_id, llm_config)
Expand Down
39 changes: 31 additions & 8 deletions api/db/joint_services/tenant_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,39 @@
from api.db.services.tenant_llm_service import TenantLLMService, TenantService


def get_model_config_by_id(tenant_model_id: int) -> dict:
def _normalize_model_type(model_type):
return model_type.value if hasattr(model_type, "value") else model_type


def _coerce_model_config_type(config_dict: dict, model_type=None) -> dict:
    """Ensure *config_dict* advertises *model_type*, validating compatibility.

    When no type is requested, or the stored ``model_type`` already matches,
    the original dict is returned untouched.  Otherwise, if the model catalog
    confirms the model can serve the requested type, a shallow copy with the
    coerced ``model_type`` is returned.

    Raises:
        LookupError: the model cannot serve the requested type.
    """
    if model_type is None:
        return config_dict

    wanted = _normalize_model_type(model_type)
    current = _normalize_model_type(config_dict.get("model_type"))
    if current == wanted:
        return config_dict

    supported = TenantLLMService.model_supports_type(
        config_dict["llm_name"],
        wanted,
        config_dict.get("llm_factory"),
    )
    if not supported:
        raise LookupError(
            f"Tenant Model with name {config_dict['llm_name']} has type {current}, expected {wanted}"
        )

    # Return a copy so the caller's original record is never mutated.
    return {**config_dict, "model_type": wanted}


def get_model_config_by_id(tenant_model_id: int, model_type: str | enum.Enum | None = None) -> dict:
found, model_config = TenantLLMService.get_by_id(tenant_model_id)
if not found:
raise LookupError(f"Tenant Model with id {tenant_model_id} not found")
config_dict = model_config.to_dict()
config_dict = _coerce_model_config_type(config_dict, model_type)
llm = LLMService.query(llm_name=config_dict["llm_name"])
if llm:
config_dict["is_tools"] = llm[0].is_tools
Expand All @@ -35,7 +63,7 @@ def get_model_config_by_id(tenant_model_id: int) -> dict:
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
model_type_val = _normalize_model_type(model_type)
model_config = TenantLLMService.get_api_key(tenant_id, model_name, model_type_val)
if not model_config:
# model_name in format 'name@factory', split model_name and try again
Expand Down Expand Up @@ -65,12 +93,7 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
else:
# model_name without @factory
config_dict = model_config.to_dict()
config_model_type = config_dict.get("model_type")
config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type
if config_model_type != model_type_val:
raise LookupError(
f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}"
)
config_dict = _coerce_model_config_type(config_dict, model_type_val)
llm = LLMService.query(llm_name=config_dict["llm_name"])
if llm:
config_dict["is_tools"] = llm[0].is_tools
Expand Down
4 changes: 2 additions & 2 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def async_chat_solo(dialog, messages, stream=True):
if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
model_config = get_model_config_by_id(dialog.tenant_llm_id)
model_config = get_model_config_by_id(dialog.tenant_llm_id, LLMType.CHAT)
else:
model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)

Expand Down Expand Up @@ -300,7 +300,7 @@ def get_models(dialog):
if dialog.llm_id:
chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
chat_model_config = get_model_config_by_id(dialog.tenant_llm_id, LLMType.CHAT)
else:
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)

Expand Down
106 changes: 92 additions & 14 deletions api/db/services/tenant_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,97 @@ class LLMFactoriesService(CommonService):
class TenantLLMService(CommonService):
model = TenantLLM

# Tag aliases accepted for each canonical model-type value.  Keys are
# LLMType values; values are the upper-cased catalog "tags" tokens that
# advertise the capability.  model_supports_type() consults this table when
# a model's declared type differs from the requested one.
_MODEL_TYPE_TAGS = {
LLMType.CHAT.value: {"CHAT"},
LLMType.EMBEDDING.value: {"EMBEDDING", "TEXT EMBEDDING"},
LLMType.SPEECH2TEXT.value: {"SPEECH2TEXT", "ASR"},
LLMType.IMAGE2TEXT.value: {"IMAGE2TEXT"},
LLMType.RERANK.value: {"RERANK"},
LLMType.TTS.value: {"TTS"},
LLMType.OCR.value: {"OCR"},
}

@classmethod
def _normalize_model_type(cls, model_type):
return model_type.value if hasattr(model_type, "value") else model_type

@classmethod
def _query_model_records(cls, tenant_id, llm_name, llm_factory=None, model_type=None):
    """Fetch the tenant's LLM rows matching *llm_name* with optional filters.

    ``model_type`` (enum or raw value) and ``llm_factory`` narrow the query
    only when provided; the result is whatever ``cls.query`` returns.
    """
    filters = {"tenant_id": tenant_id, "llm_name": llm_name}
    normalized_type = cls._normalize_model_type(model_type)
    if normalized_type is not None:
        filters["model_type"] = normalized_type
    if llm_factory:
        filters["llm_factory"] = llm_factory
    return cls.query(**filters)

@classmethod
def _iter_catalog_models(cls, model_name: str, llm_factory: str | None = None):
    """Yield ``(llm_entry, factory_name)`` pairs from the static catalog.

    Scans ``settings.FACTORY_LLM_INFOS`` for entries whose ``llm_name``
    equals *model_name*; when *llm_factory* is given, only that factory's
    entries are considered.
    """
    for factory_info in settings.FACTORY_LLM_INFOS:
        factory_name = factory_info["name"]
        if llm_factory and factory_name != llm_factory:
            continue
        hits = (m for m in factory_info.get("llm", []) if m.get("llm_name") == model_name)
        for entry in hits:
            yield entry, factory_name

@classmethod
def model_supports_type(cls, model_name, model_type, llm_factory=None) -> bool:
    """Return True when *model_name* is known to be able to serve *model_type*.

    Consults the static factory catalog first, then the LLM table.  A
    candidate matches when its declared ``model_type`` equals the requested
    one, or when its comma-separated ``tags`` intersect the expected tag
    aliases for that type.  ``model_type=None`` means "any type" and always
    succeeds.
    """
    from api.db.services.llm_service import LLMService

    wanted = cls._normalize_model_type(model_type)
    if wanted is None:
        return True

    expected_tags = cls._MODEL_TYPE_TAGS.get(wanted, {str(wanted).upper()})

    def _matches(declared_type, raw_tags):
        # Direct type match, or a tag-alias match (tags are comma-separated).
        if cls._normalize_model_type(declared_type) == wanted:
            return True
        tag_set = {part.strip().upper() for part in str(raw_tags).split(",") if part.strip()}
        return bool(tag_set & expected_tags)

    for entry, _factory_name in cls._iter_catalog_models(model_name, llm_factory):
        if _matches(entry.get("model_type"), entry.get("tags", "")):
            return True

    db_filters = {"llm_name": model_name}
    if llm_factory:
        db_filters["fid"] = llm_factory
    for record in LLMService.query(**db_filters):
        if _matches(record.model_type, getattr(record, "tags", "")):
            return True

    return False

@classmethod
def _get_api_key_for_name(cls, tenant_id, model_name, llm_factory=None, model_type=None):
    """Resolve the tenant's configured record for *model_name*, or ``None``.

    Strategy: try an exact (name, type) match first; failing that, fall back
    to a name-only match and accept it only when the catalog says the model
    supports the requested type.  With ``model_type=None`` there is no
    fallback step — an exact name match is all that is attempted.
    """
    exact = cls._query_model_records(tenant_id, model_name, llm_factory=llm_factory, model_type=model_type)
    if exact:
        return exact[0]

    wanted = cls._normalize_model_type(model_type)
    if wanted is None:
        return None

    by_name = cls._query_model_records(tenant_id, model_name, llm_factory=llm_factory)
    if not by_name:
        return None

    fallback = by_name[0]
    if not cls.model_supports_type(fallback.llm_name, wanted, fallback.llm_factory):
        return None
    return fallback

@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name, model_type=None):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
query_kwargs = {"tenant_id": tenant_id, "llm_name": mdlnm}
if model_type_val is not None:
query_kwargs["model_type"] = model_type_val
if not fid:
objs = cls.query(**query_kwargs)
else:
objs = cls.query(**query_kwargs, llm_factory=fid)
model_type_val = cls._normalize_model_type(model_type)
obj = cls._get_api_key_for_name(tenant_id, mdlnm, llm_factory=fid, model_type=model_type_val)
if obj:
return obj

if (not objs) and fid:
if fid:
if fid == "LocalAI":
mdlnm += "___LocalAI"
elif fid == "HuggingFace":
Expand All @@ -56,11 +133,8 @@ def get_api_key(cls, tenant_id, model_name, model_type=None):
mdlnm += "___OpenAI-API"
elif fid == "VLLM":
mdlnm += "___VLLM"
query_kwargs["llm_name"] = mdlnm
objs = cls.query(**query_kwargs, llm_factory=fid)
if not objs:
return None
return objs[0]
return cls._get_api_key_for_name(tenant_id, mdlnm, llm_factory=fid, model_type=model_type_val)
return None

@classmethod
@DB.connection_context()
Expand Down Expand Up @@ -123,6 +197,10 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None):
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
if model_config:
model_config = model_config.to_dict()
resolved_type = cls._normalize_model_type(model_config.get("model_type"))
requested_type = cls._normalize_model_type(llm_type)
if resolved_type != requested_type and cls.model_supports_type(model_config["llm_name"], requested_type, model_config.get("llm_factory")):
model_config["model_type"] = requested_type
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
embedding_cfg = settings.EMBEDDING_CFG
model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
Expand Down
8 changes: 7 additions & 1 deletion api/utils/tenant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
from common.constants import LLMType
from common.exceptions import ArgumentException
from api.db.services.tenant_llm_service import TenantLLMService

_KEY_TO_MODEL_TYPE = {
Expand All @@ -25,13 +26,18 @@
"tts_id": LLMType.TTS,
}

def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict) -> dict:
def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict, *, strict: bool = False) -> dict:
    """Back-fill ``tenant_<key>`` record ids for the model-id fields in *param_dict*.

    For each supported model-id key that is set but whose ``tenant_<key>``
    counterpart is missing/falsy, look up the tenant's configured model and
    store its id.  When the model cannot be resolved, either raise
    ``ArgumentException`` (if *strict*) or record a sentinel id of 0.
    *param_dict* is mutated in place and also returned.
    """
    model_keys = ("llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id")
    for key in model_keys:
        tenant_key = f"tenant_{key}"
        if not param_dict.get(key) or param_dict.get(tenant_key):
            continue
        model_type = _KEY_TO_MODEL_TYPE.get(key)
        record = TenantLLMService.get_api_key(tenant_id, param_dict[key], model_type)
        if record:
            param_dict[tenant_key] = record.id
            continue
        if strict:
            type_value = getattr(model_type, "value", model_type)
            raise ArgumentException(
                f"Tenant Model with name {param_dict[key]} and type {type_value} not found"
            )
        param_dict[tenant_key] = 0
    return param_dict