Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/apps/restful_apis/memory_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ async def create_memory():
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json()
req = ensure_tenant_model_id_for_params(current_user.id, req)
t_parsed = time.perf_counter() if timing_enabled else None
Comment thread
coderabbitai[bot] marked this conversation as resolved.
try:
req = ensure_tenant_model_id_for_params(current_user.id, req, strict=True)
memory_info = {
"name": req["name"],
"memory_type": req["memory_type"],
Expand Down
10 changes: 9 additions & 1 deletion api/utils/tenant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
from common.constants import LLMType
from common.exceptions import ArgumentException
from api.db.services.tenant_llm_service import TenantLLMService

_KEY_TO_MODEL_TYPE = {
Expand All @@ -25,13 +26,20 @@
"tts_id": LLMType.TTS,
}

def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict) -> dict:
def ensure_tenant_model_id_for_params(tenant_id: str, param_dict: dict, *, strict: bool = False) -> dict:
for key in ["llm_id", "embd_id", "asr_id", "img2txt_id", "rerank_id", "tts_id"]:
if param_dict.get(key) and not param_dict.get(f"tenant_{key}"):
model_type = _KEY_TO_MODEL_TYPE.get(key)
tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key], model_type)
if not tenant_model and model_type == LLMType.CHAT:
tenant_model = TenantLLMService.get_api_key(tenant_id, param_dict[key])
if tenant_model:
param_dict.update({f"tenant_{key}": tenant_model.id})
else:
if strict:
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
raise ArgumentException(
f"Tenant Model with name {param_dict[key]} and type {model_type_val} not found"
)
param_dict.update({f"tenant_{key}": 0})
return param_dict
Loading