diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 672adde6ea0..8f92661e700 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -18,7 +18,7 @@ import time from quart import request -from common.constants import RetCode +from common.constants import LLMType, RetCode from common.exceptions import ArgumentException, NotFoundException from api.apps import login_required, current_user from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result @@ -33,9 +33,13 @@ async def create_memory(): timing_enabled = os.getenv("RAGFLOW_API_TIMING") t_start = time.perf_counter() if timing_enabled else None req = await get_request_json() - req = ensure_tenant_model_id_for_params(current_user.id, req) t_parsed = time.perf_counter() if timing_enabled else None try: + req = ensure_tenant_model_id_for_params(current_user.id, req) + if not req.get("tenant_llm_id"): + raise ArgumentException( + f"Tenant Model with name {req['llm_id']} and type {LLMType.CHAT.value} not found" + ) memory_info = { "name": req["name"], "memory_type": req["memory_type"], diff --git a/api/utils/tenant_utils.py b/api/utils/tenant_utils.py index 83da91f1c4a..80f75b6fd6e 100644 --- a/api/utils/tenant_utils.py +++ b/api/utils/tenant_utils.py @@ -14,6 +14,7 @@ # limitations under the License. # from common.constants import LLMType +from common.exceptions import ArgumentException from api.db.services.tenant_llm_service import TenantLLMService _KEY_TO_MODEL_TYPE = { @@ -25,13 +26,20 @@ "tts_id": LLMType.TTS, } -def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict) -> dict: +def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict, *, strict: bool = False) -> dict: for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"]: if param_dict.get(key) and not param_dict.get(f"tenant_{key}"): model_type = _KEY_TO_MODEL_TYPE.get(key) tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key], model_type) + if not tenant_model and model_type == LLMType.CHAT: + tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key]) if tenant_model: param_dict.update({f"tenant_{key}": tenant_model.id}) else: + if strict: + model_type_val = model_type.value if hasattr(model_type, "value") else model_type + raise ArgumentException( + f"Tenant Model with name {param_dict[key]} and type {model_type_val} not found" + ) param_dict.update({f"tenant_{key}": 0}) return param_dict