From 844a355186fd080056f791287196d62e5c6d6886 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Mon, 16 Feb 2026 13:45:19 +0100 Subject: [PATCH 01/13] Refacto use-case avec tests integration --- .gitignore | 1 - api/clients/model/_basemodelprovider.py | 2 +- api/dependencies.py | 30 +- api/domain/model/__init__.py | 4 + api/domain/model/entities.py | 34 ++ api/domain/model/errors.py | 15 + api/domain/provider/__init__.py | 17 + api/domain/provider/_providergateway.py | 24 ++ api/domain/provider/_providerrepository.py | 27 ++ api/domain/provider/entities.py | 40 ++ api/domain/provider/errors.py | 18 + api/domain/router/_routerrepository.py | 8 + api/domain/router/entities.py | 24 +- api/domain/router/errors.py | 5 + api/endpoints/admin/providers.py | 35 -- api/helpers/models/_modelregistry.py | 18 +- .../fastapi/endpoints/admin/providers.py | 184 ++++++++++ .../fastapi/endpoints/exceptions.py | 34 ++ api/infrastructure/fastapi/schemas/models.py | 2 +- .../fastapi/schemas/providers.py | 115 ++++++ api/infrastructure/model/__init__.py | 3 + .../model/_modelprovidergateway.py | 28 ++ api/infrastructure/postgres/__init__.py | 10 +- .../postgres/_postgresproviderrepository.py | 78 ++++ .../postgres/_postgresrouterrepository.py | 55 ++- api/tests/integration/conftest.py | 57 ++- api/tests/integration/test_admin_providers.py | 345 ++++++++++++++++++ .../test_postgresrouterrepository.py | 2 +- .../test_modelregistry/test_providers.py | 2 +- api/use_cases/admin/providers/__init__.py | 5 + .../admin/providers/_createproviderusecase.py | 144 ++++++++ api/use_cases/models/_getmodelsusecase.py | 2 +- playground/app/features/providers/models.py | 2 +- playground/app/features/providers/state.py | 6 +- 34 files changed, 1271 insertions(+), 105 deletions(-) create mode 100644 api/domain/model/__init__.py create mode 100644 api/domain/model/entities.py create mode 100644 api/domain/model/errors.py create mode 100644 api/domain/provider/__init__.py create mode 100644 api/domain/provider/_providergateway.py create mode 100644 api/domain/provider/_providerrepository.py create mode 100644 api/domain/provider/entities.py create mode 100644 api/domain/provider/errors.py create mode 100644 api/infrastructure/fastapi/endpoints/admin/providers.py create mode 100644 api/infrastructure/fastapi/schemas/providers.py create mode 100644 api/infrastructure/model/__init__.py create mode 100644 api/infrastructure/model/_modelprovidergateway.py create mode 100644 api/infrastructure/postgres/_postgresproviderrepository.py create mode 100644 api/tests/integration/test_admin_providers.py create mode 100644 api/use_cases/admin/providers/__init__.py create mode 100644 api/use_cases/admin/providers/_createproviderusecase.py diff --git a/.gitignore b/.gitignore index 285b58bde..6474fa8d1 100644 --- a/.gitignore +++ b/.gitignore @@ -222,5 +222,4 @@ playground/.gitignore playground/requirements.txt run.sh .claude - bruno \ No newline at end of file diff --git a/api/clients/model/_basemodelprovider.py b/api/clients/model/_basemodelprovider.py index 9f43e0a78..3978c6fed 100644 --- a/api/clients/model/_basemodelprovider.py +++ b/api/clients/model/_basemodelprovider.py @@ -76,7 +76,7 @@ def import_module(type: ProviderType) -> "type[BaseModelProvider]": return getattr(module, f"{type.capitalize()}ModelProvider") @staticmethod - async def get_max_context_length(self) -> int | None: + async def get_max_context_length() -> int | None: """ Get the max context length of the model provider to store in the database. Useful to check provider consistency. diff --git a/api/dependencies.py b/api/dependencies.py index d1877a9a1..b9e92610f 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -5,22 +5,17 @@ from sqlalchemy.ext.asyncio import AsyncSession from api.domain.key import KeyRepository -from api.infrastructure.postgres import PostgresKeyRepository, PostgresRouterRepository, PostgresUserInfoRepository +from api.infrastructure.model import ModelProviderGateway +from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository from api.schemas.core.context import RequestContext from api.use_cases.admin import CreateRouterUseCase +from api.use_cases.admin.providers import CreateProviderUseCase from api.use_cases.models import GetModelsUseCase from api.utils.configuration import configuration from api.utils.context import global_context, request_context async def get_postgres_session() -> AsyncGenerator[AsyncSession]: - """ - Get a PostgreSQL postgres_session from the global context. - - Returns: - AsyncSession: A PostgreSQL postgres_session instance. - """ - session_factory = global_context.postgres_session_factory async with session_factory() as postgres_session: try: @@ -34,13 +29,6 @@ async def get_postgres_session() -> AsyncGenerator[AsyncSession]: def get_request_context() -> ContextVar[RequestContext]: - """ - Get the RequestContext ContextVar from the global context. - - Returns: - ContextVar[RequestContext]: The RequestContext ContextVar instance. - """ - return request_context @@ -55,6 +43,18 @@ def get_models_use_case( ) +def create_provider_use_case_factory( + postgres_session: AsyncSession = Depends(get_postgres_session), + request_context: RequestContext = Depends(get_request_context), +) -> CreateProviderUseCase: + return CreateProviderUseCase( + router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), + user_info_repository=PostgresUserInfoRepository(postgres_session=postgres_session), + provider_repository=PostgresProviderRepository(postgres_session=postgres_session), + provider_gateway=ModelProviderGateway(), + ) + + def create_router_use_case(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateRouterUseCase: return CreateRouterUseCase( router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), diff --git a/api/domain/model/__init__.py b/api/domain/model/__init__.py new file mode 100644 index 000000000..21c9905e0 --- /dev/null +++ b/api/domain/model/__init__.py @@ -0,0 +1,4 @@ +from .entities import Model, ModelCosts, ModelType +from .errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError + +__all__ = ["ModelType", "Model", "ModelCosts", "InconsistentModelMaxContextLengthError", "InconsistentModelVectorSizeError"] diff --git a/api/domain/model/entities.py b/api/domain/model/entities.py new file mode 100644 index 000000000..e2f017359 --- /dev/null +++ b/api/domain/model/entities.py @@ -0,0 +1,34 @@ +from enum import Enum + +from pydantic import BaseModel, Field + + +class Metric(str, Enum): + TTFT = "ttft" # time to first token + LATENCY = "latency" # requests latency + INFLIGHT = "inflight" # requests concurrency + PERFORMANCE = "performance" # custom performance metric + + +class ModelCosts(BaseModel): + prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)") + completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)") + + +class ModelType(str, Enum): + AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" + IMAGE_TEXT_TO_TEXT = "image-text-to-text" + IMAGE_TO_TEXT = "image-to-text" + TEXT_EMBEDDINGS_INFERENCE = "text-embeddings-inference" + TEXT_GENERATION = "text-generation" + TEXT_CLASSIFICATION = "text-classification" + + +class Model(BaseModel): + id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.") + type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off + aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off + created: int = Field(..., description="Time of creation, as Unix timestamp.") + owned_by: str = Field(..., description="The organization that owns the model.") + max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off + costs: ModelCosts = Field(..., description="Costs of the model.") diff --git a/api/domain/model/errors.py b/api/domain/model/errors.py new file mode 100644 index 000000000..a8d37c349 --- /dev/null +++ b/api/domain/model/errors.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass +class InconsistentModelVectorSizeError: + expected_vector_size: int + actual_vector_size: int + router_name: str + + +@dataclass +class InconsistentModelMaxContextLengthError: + expected_max_context_length: int + actual_max_context_length: int + router_name: str diff --git a/api/domain/provider/__init__.py b/api/domain/provider/__init__.py new file mode 100644 index 000000000..7ba949199 --- /dev/null +++ b/api/domain/provider/__init__.py @@ -0,0 +1,17 @@ +from api.domain.provider._providergateway import ProviderCapabilities, ProviderGateway +from api.domain.provider._providerrepository import ProviderRepository + +from .entities import Provider, ProviderCarbonFootprintZone, ProviderType +from .errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError + +__all__ = [ + "InvalidProviderTypeError", + "ProviderNotReachableError", + "ProviderRepository", + "ProviderGateway", + "ProviderCapabilities", + "ProviderType", + "Provider", + "ProviderCarbonFootprintZone", + "ProviderAlreadyExistsError", +] diff --git a/api/domain/provider/_providergateway.py b/api/domain/provider/_providergateway.py new file mode 100644 index 000000000..da01c26e7 --- /dev/null +++ b/api/domain/provider/_providergateway.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from api.domain.provider.entities import ProviderType +from api.domain.provider.errors import ProviderNotReachableError + + +@dataclass +class ProviderCapabilities: + max_context_length: int | None + vector_size: int | None + + +class ProviderGateway(ABC): + @abstractmethod + async def get_capabilities( + self, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + ) -> ProviderCapabilities | ProviderNotReachableError: + pass diff --git a/api/domain/provider/_providerrepository.py b/api/domain/provider/_providerrepository.py new file mode 100644 index 000000000..cf5d715cf --- /dev/null +++ b/api/domain/provider/_providerrepository.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod + +from api.domain.model.entities import Metric +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError + + +class ProviderRepository(ABC): + @abstractmethod + async def create_provider( + self, + router_id: int, + user_id: int, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + model_hosting_zone: ProviderCarbonFootprintZone, + model_total_params: int, + model_active_params: int, + qos_metric: Metric | None, + qos_limit: float | None, + vector_size: int, + max_context_length: int, + ) -> Provider | ProviderAlreadyExistsError: + pass diff --git a/api/domain/provider/entities.py b/api/domain/provider/entities.py new file mode 100644 index 000000000..48401a485 --- /dev/null +++ b/api/domain/provider/entities.py @@ -0,0 +1,40 @@ +from enum import Enum +from typing import Literal + +import pycountry +from pydantic import Field, constr + +from api.schemas import BaseModel +from api.schemas.core.models import Metric + +# Add world as a country code, default value of the carbon footprint computation framework +country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"] +country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))} +ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str) + + +class ProviderType(str, Enum): + ALBERT = "albert" + OPENAI = "openai" + MISTRAL = "mistral" + TEI = "tei" + VLLM = "vllm" + + +class Provider(BaseModel): + object: Literal["provider"] = "provider" + id: int = Field(..., description="Provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="Provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="Provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off diff --git a/api/domain/provider/errors.py b/api/domain/provider/errors.py new file mode 100644 index 000000000..63a63786f --- /dev/null +++ b/api/domain/provider/errors.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + + +@dataclass +class InvalidProviderTypeError: + type: str + + +@dataclass +class ProviderNotReachableError: + model_name: str + + +@dataclass +class ProviderAlreadyExistsError: + model_name: str + url: str + router_id: int diff --git a/api/domain/router/_routerrepository.py b/api/domain/router/_routerrepository.py index f89225fc8..8bbc60083 100644 --- a/api/domain/router/_routerrepository.py +++ b/api/domain/router/_routerrepository.py @@ -13,6 +13,14 @@ async def get_organization_name(self, user_id) -> str: async def get_all_routers(self) -> list[Router]: pass + @abstractmethod + async def get_router_by_id(self, router_id: int) -> Router | None: + pass + + @abstractmethod + async def get_aliases_by_router_id(self, router_id: int) -> Router | None: + pass + @abstractmethod async def create_router( self, diff --git a/api/domain/router/entities.py b/api/domain/router/entities.py index 57fbbac9c..6abea7b5b 100644 --- a/api/domain/router/entities.py +++ b/api/domain/router/entities.py @@ -2,29 +2,7 @@ from pydantic import BaseModel, Field - -class ModelCosts(BaseModel): - prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)") - completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)") - - -class ModelType(str, Enum): - AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" - IMAGE_TEXT_TO_TEXT = "image-text-to-text" - IMAGE_TO_TEXT = "image-to-text" - TEXT_EMBEDDINGS_INFERENCE = "text-embeddings-inference" - TEXT_GENERATION = "text-generation" - TEXT_CLASSIFICATION = "text-classification" - - -class Model(BaseModel): - id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.") - type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off - aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off - created: int = Field(..., description="Time of creation, as Unix timestamp.") - owned_by: str = Field(..., description="The organization that owns the model.") - max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off - costs: ModelCosts = Field(..., description="Costs of the model.") +from api.domain.model import ModelType class RouterLoadBalancingStrategy(str, Enum): diff --git a/api/domain/router/errors.py b/api/domain/router/errors.py index b031e3073..6ff5a988a 100644 --- a/api/domain/router/errors.py +++ b/api/domain/router/errors.py @@ -9,3 +9,8 @@ class RouterAliasAlreadyExistsError: @dataclass class RouterNameAlreadyExistsError: name: str + + +@dataclass +class RouterNotFoundError: + router_id: int diff --git a/api/endpoints/admin/providers.py b/api/endpoints/admin/providers.py index b5bd4d0c1..7369bd276 100644 --- a/api/endpoints/admin/providers.py +++ b/api/endpoints/admin/providers.py @@ -8,50 +8,15 @@ from api.helpers._accesscontroller import AccessController from api.helpers.models import ModelRegistry from api.schemas.admin.providers import ( - CreateProvider, - CreateProviderResponse, Provider, Providers, UpdateProvider, ) from api.schemas.admin.roles import PermissionType -from api.utils.context import request_context from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute -@router.post( - path=EndpointRoute.ADMIN_PROVIDERS, - dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))], - status_code=201, -) -async def create_provider( - request: Request, - body: CreateProvider, - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> CreateProviderResponse: - """ - Create a model provider. - """ - provider_id = await model_registry.create_provider( - router_id=body.router, - user_id=request_context.get().user_info.id, - type=body.type, - url=body.url, - key=body.key, - timeout=body.timeout, - model_name=body.model_name, - model_hosting_zone=body.model_hosting_zone, - model_total_params=body.model_total_params, - model_active_params=body.model_active_params, - qos_metric=body.qos_metric, - qos_limit=body.qos_limit, - postgres_session=postgres_session, - ) - return JSONResponse(status_code=201, content=CreateProviderResponse(id=provider_id).model_dump()) - - @router.delete( path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))], diff --git a/api/helpers/models/_modelregistry.py b/api/helpers/models/_modelregistry.py index fbb268ab2..b8b10fae3 100644 --- a/api/helpers/models/_modelregistry.py +++ b/api/helpers/models/_modelregistry.py @@ -150,16 +150,16 @@ async def setup(self, models: list[ModelConfiguration], postgres_session: AsyncS postgres_session=postgres_session, ) except ProviderAlreadyExistsException: - logger.warning(f"Provider {provider.model_name} already exists for router {model.name} (skipping)") + logger.warning(f"provider {provider.model_name} already exists for router {model.name} (skipping)") continue except ProviderNotReachableException: - logger.warning(f"Provider {provider.model_name} is not reachable for router {model.name} (skipping)") + logger.warning(f"provider {provider.model_name} is not reachable for router {model.name} (skipping)") continue except Exception as e: await postgres_session.rollback() - logger.error(f"Provider {provider.model_name} failed to be created for router {model.name} ({e})") + logger.error(f"provider {provider.model_name} failed to be created for router {model.name} ({e})") raise e - logging.info(f"Provider {provider.model_name} successfully created for router {model.name} (id: {provider_id})") + logging.info(f"provider {provider.model_name} successfully created for router {model.name} (id: {provider_id})") if self.queuing_enabled: routers = await self.get_routers(router_id=None, name=None, postgres_session=postgres_session) @@ -459,9 +459,9 @@ async def create_provider( Args: router_id(int): The model router ID user_id(int): The user ID of owner of the provider - type(ProviderType): Provider type - url(str): Provider URL - key(str | None): Provider API key + type(ProviderType): provider type + url(str): provider URL + key(str | None): provider API key timeout(int): Request timeout model_name(str): Model name model_hosting_zone(ProviderCarbonFootprintZone): ProviderCarbonFootprintZone @@ -498,7 +498,7 @@ async def create_provider( vector_size = None except AssertionError as e: - logger.debug(f"Provider {provider.model_name} not reachable: {e}", exc_info=True) + logger.debug(f"provider {provider.model_name} not reachable: {e}", exc_info=True) raise ProviderNotReachableException() # consistency check @@ -724,7 +724,7 @@ async def update_provider( await postgres_session.commit() except IntegrityError: await postgres_session.rollback() - raise ProviderAlreadyExistsException("Provider already exists for the new router.") + raise ProviderAlreadyExistsException("provider already exists for the new router.") async def get_models(self, name: str | None, user_info: UserInfo, postgres_session: AsyncSession) -> list[Model]: """ diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py new file mode 100644 index 000000000..d0d53e3ea --- /dev/null +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -0,0 +1,184 @@ +import logging +from typing import Literal + +from fastapi import APIRouter, Body, Depends, Path, Query, Request, Security +from fastapi.responses import JSONResponse, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from api.dependencies import create_provider_use_case_factory, get_request_context +from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider import InvalidProviderTypeError, ProviderNotReachableError +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.domain.router.errors import RouterNotFoundError +from api.helpers.models import ModelRegistry +from api.infrastructure.fastapi.access import get_current_key +from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.endpoints.exceptions import ( + InconsistentModelMaxContextLengthHTTPException, + InconsistentModelVectorSizeHTTPException, + InternalServerHTTPException, + InvalidProviderTypeHTTPException, + ProviderAlreadyExistsHTTPException, + ProviderNotReachableHTTPException, + RouterNotFoundHTTPException, +) +from api.infrastructure.fastapi.schemas.providers import ( + CreateProvider, + CreateProviderResponse, + Provider, + Providers, + UpdateProvider, +) +from api.use_cases.admin.providers import CreateProviderUseCase +from api.use_cases.admin.providers._createproviderusecase import CreateProviderUseCaseSuccess +from api.utils.dependencies import get_model_registry, get_postgres_session +from api.utils.variables import ENDPOINT__ADMIN_PROVIDERS, ROUTER__ADMIN + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/v1", tags=[ROUTER__ADMIN.title()]) + + +@router.post( + path=ENDPOINT__ADMIN_PROVIDERS, + dependencies=[Security(dependency=get_current_key)], + status_code=201, +) +async def create_provider( + request: Request, + body: CreateProvider, + create_provider_use_case: CreateProviderUseCase = Depends(create_provider_use_case_factory), + request_context: RequestContext = Depends(get_request_context), +) -> CreateProviderResponse: + try: + result = await create_provider_use_case.execute( + router_id=body.router, + user_id=request_context.get().user_id, + provider_type=body.type, + url=body.url, + key=body.key, + timeout=body.timeout, + model_name=body.model_name, + model_hosting_zone=body.model_hosting_zone, + model_total_params=body.model_total_params, + model_active_params=body.model_active_params, + qos_metric=body.qos_metric, + qos_limit=body.qos_limit, + ) + except Exception as e: + logger.exception( + "Unexpected error while executing create_router use case", + extra={ + "user_id": request_context.get().user_id, + "router_name": body.name, + "error_type": type(e).__name__, + }, + ) + raise InternalServerHTTPException() + + match result: + case CreateProviderUseCaseSuccess(created_provider): + return CreateProviderResponse.model_validate(created_provider, from_attributes=True) + case InvalidProviderTypeError(provider_type): + raise InvalidProviderTypeHTTPException(provider_type) + case ProviderNotReachableError(name): + raise ProviderNotReachableHTTPException(name) + case ProviderAlreadyExistsError(model_name, url, router_id): + raise ProviderAlreadyExistsHTTPException(model_name, url, router_id) + case InconsistentModelMaxContextLengthError(actual_max_context_length, expected_max_context_length, router_name): + raise InconsistentModelMaxContextLengthHTTPException( + input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name + ) + case InconsistentModelVectorSizeError(actual_vector_size, expected_vector_size, router_name): + raise InconsistentModelVectorSizeHTTPException(actual_vector_size, expected_vector_size, router_name) + case RouterNotFoundError(router_id): + raise RouterNotFoundHTTPException(router_id) + + +@router.delete( + path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=204, +) +async def delete_provider( + request: Request, + provider: int = Path(description="The ID of the provider to delete."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> Response: + await model_registry.delete_provider(provider_id=provider, postgres_session=postgres_session) + + return Response(status_code=204) + + +@router.patch( + path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=204, +) +async def update_provider( + request: Request, + provider: int = Path(description="The ID of the provider to update."), + body: UpdateProvider = Body(description="The provider update request."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> Response: + await model_registry.update_provider( + provider_id=provider, + router_id=body.router, + timeout=body.timeout, + model_hosting_zone=body.model_hosting_zone, + model_total_params=body.model_total_params, + model_active_params=body.model_active_params, + qos_metric=body.qos_metric, + qos_limit=body.qos_limit, + postgres_session=postgres_session, + ) + + return Response(status_code=204) + + +@router.get( + path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=200, + response_model=Provider, +) +async def get_provider( + request: Request, + provider: int = Path(description="The ID of the provider to get."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> JSONResponse: + providers = await model_registry.get_providers(router_id=None, provider_id=provider, postgres_session=postgres_session) + provider = providers[0] + + return JSONResponse(status_code=200, content=provider.model_dump()) + + +@router.get( + path=ENDPOINT__ADMIN_PROVIDERS, + dependencies=[Security(dependency=get_current_key)], + status_code=200, + response_model=Providers, +) +async def get_providers( + request: Request, + router: int | None = Query(default=None, description="Filter providers by router ID."), + offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."), + limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."), + order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."), + order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> JSONResponse: + providers = await model_registry.get_providers( + router_id=router, + provider_id=None, + postgres_session=postgres_session, + offset=offset, + limit=limit, + order_by=order_by, + order_direction=order_direction, + ) + + return JSONResponse(status_code=200, content=Providers(data=providers).model_dump()) diff --git a/api/infrastructure/fastapi/endpoints/exceptions.py b/api/infrastructure/fastapi/endpoints/exceptions.py index 7036b3e8e..09a94e1f0 100644 --- a/api/infrastructure/fastapi/endpoints/exceptions.py +++ b/api/infrastructure/fastapi/endpoints/exceptions.py @@ -1,6 +1,12 @@ from fastapi import HTTPException + # 400 +class InvalidProviderTypeHTTPException(HTTPException): + def __init__(self, incorrect_provider_type: str) -> None: + super().__init__( + status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for this model router type. Allowed types are: " + ) # 401 @@ -22,12 +28,32 @@ def __init__(self, detail: str = "Insufficient rights.") -> None: super().__init__(status_code=403, detail=detail) +class InconsistentModelMaxContextLengthHTTPException(HTTPException): + def __init__(self, input_max_context_length: int, model_max_context_length: int, model_name: str) -> None: + super().__init__( + status_code=403, + detail=f"Inconsistent max context length for {model_name}. Expected: {model_max_context_length}. Actual: {input_max_context_length}", + ) + + +class InconsistentModelVectorSizeHTTPException(HTTPException): + def __init__(self, input_vector_size: int, model_vector_size: int, model_name: str) -> None: + super().__init__( + status_code=403, detail=f"Inconsistent vector size for {model_name}. Expected: {model_vector_size}. Actual: {input_vector_size}" + ) + + # 404 class ModelNotFoundHTTPException(HTTPException): def __init__(self, detail: str = "Model not found.") -> None: super().__init__(status_code=404, detail=detail) +class RouterNotFoundHTTPException(HTTPException): + def __init__(self, router_id: int) -> None: + super().__init__(status_code=404, detail=f"Model router {router_id} not found.") + + # 409 class RouterAliasAlreadyExistsHTTPException(HTTPException): def __init__(self, aliases: list[str]): @@ -39,6 +65,11 @@ def __init__(self, name: str): super().__init__(status_code=409, detail=f"Router '{name}' already exists.") +class ProviderAlreadyExistsHTTPException(HTTPException): + def __init__(self, model_name: str, url: str, router_id: int) -> None: + super().__init__(status_code=409, detail=f"Model provider {model_name} for url {url} already exists for router {router_id}.") + + # 413 @@ -46,6 +77,9 @@ def __init__(self, name: str): # 424 +class ProviderNotReachableHTTPException(HTTPException): + def __init__(self, name: str) -> None: + super().__init__(status_code=424, detail=f"Model provider {name} not reachable.") # 429 diff --git a/api/infrastructure/fastapi/schemas/models.py b/api/infrastructure/fastapi/schemas/models.py index 17b185c75..94d80496e 100644 --- a/api/infrastructure/fastapi/schemas/models.py +++ b/api/infrastructure/fastapi/schemas/models.py @@ -3,7 +3,7 @@ from pydantic import Field -from api.domain.router.entities import Model as ModelEntity +from api.domain.model import Model as ModelEntity from api.schemas import BaseModel diff --git a/api/infrastructure/fastapi/schemas/providers.py b/api/infrastructure/fastapi/schemas/providers.py new file mode 100644 index 000000000..e92d22405 --- /dev/null +++ b/api/infrastructure/fastapi/schemas/providers.py @@ -0,0 +1,115 @@ +from enum import Enum +from typing import Literal + +import pycountry +from pydantic import Field, constr, model_validator + +from api.schemas import BaseModel +from api.schemas.core.models import Metric +from api.utils.variables import DEFAULT_TIMEOUT + +# Add world as a country code, default value of the carbon footprint computation framework +country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"] +country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))} +ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str) + + +class ProviderType(str, Enum): + ALBERT = "albert" + OPENAI = "openai" + MISTRAL = "mistral" + TEI = "tei" + VLLM = "vllm" + + +class CreateProvider(BaseModel): + router: int = Field(..., description="ID of the model to create the provider for (router ID, eg. 123).") # fmt: off + type: ProviderType = Field(..., description="Model provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1) | None = Field(default=None, description="Model provider API url. The url must only contain the domain name (without `/v1` suffix for example). Depends of the model provider type, the url can be optional (Albert, OpenAI).") # fmt: off + key: constr(strip_whitespace=True, min_length=1) | None = Field(default=None, description="Model provider API key.") # fmt: off + timeout: int = Field(default=DEFAULT_TIMEOUT, description="Timeout for the model provider requests, after user receive an 503 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai") # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(default=None, description="The metric to use for the quality of service policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + + @model_validator(mode="after") + def format_provider(self): + if self.qos_metric is not None and self.qos_limit is None: + raise ValueError("QoS value is required if QoS metric is provided.") + + if self.url is None: + if self.type == ProviderType.ALBERT: + self.url = "https://albert.api.etalab.gouv.fr/" + elif self.type == ProviderType.MISTRAL: + self.url = "https://albert.api.etalab.gouv.fr/" + elif self.type == ProviderType.OPENAI: + self.url = "https://api.openai.com/" + else: + raise ValueError("URL is required for this model provider type.") + + elif not self.url.endswith("/"): + self.url = f"{self.url}/" + + return self + + +class CreateProviderResponse(BaseModel): + id: int = Field(..., description="Provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="Provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="Provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off + + +class UpdateProvider(BaseModel): + router: int | None = Field(default=None, description="The ID of the new router to assign to the provider.") # fmt: off + timeout: int | None = Field(default=None, description="Timeout for the model provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone | None = Field(default=None, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai") # fmt: off + model_total_params: int | None = Field(default=None, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int | None = Field(default=None, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(default=None, description="The metric to use for the quality of service policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + + @model_validator(mode="after") + def validate_model(self): + if self.qos_metric is not None and self.qos_limit is None: + raise ValueError("QoS value is required if QoS metric is provided.") + + return self + + +class Provider(BaseModel): + object: Literal["provider"] = "provider" + id: int = Field(..., description="provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off + + +class Providers(BaseModel): + object: Literal["list"] = "list" + data: list[Provider] diff --git a/api/infrastructure/model/__init__.py b/api/infrastructure/model/__init__.py new file mode 100644 index 000000000..6597b06c7 --- /dev/null +++ b/api/infrastructure/model/__init__.py @@ -0,0 +1,3 @@ +from api.infrastructure.model._modelprovidergateway import ModelProviderGateway + +__all__ = ["ModelProviderGateway"] diff --git a/api/infrastructure/model/_modelprovidergateway.py b/api/infrastructure/model/_modelprovidergateway.py new file mode 100644 index 000000000..f882a7806 --- /dev/null +++ b/api/infrastructure/model/_modelprovidergateway.py @@ -0,0 +1,28 @@ +from api.clients.model import BaseModelProvider +from api.domain.provider import ProviderCapabilities, ProviderGateway, ProviderNotReachableError + + +class ModelProviderGateway(ProviderGateway): + async def get_capabilities(self, provider_type, url, key, timeout, model_name): + try: + client = self._build_client(provider_type, url, key, timeout, model_name) + max_context_length = await client.get_max_context_length() + vector_size = await client.get_vector_size() + return ProviderCapabilities( + max_context_length=max_context_length, + vector_size=vector_size, + ) + except Exception as e: + return ProviderNotReachableError(model_name) + + def _build_client(self, provider_type, url, key, timeout, model_name): + cls = BaseModelProvider.import_module(type=provider_type) + return cls( + url=url, + key=key, + timeout=timeout, + model_name=model_name, + model_hosting_zone=None, + model_total_params=0, + model_active_params=0, + ) diff --git a/api/infrastructure/postgres/__init__.py b/api/infrastructure/postgres/__init__.py index 5f0b1a76b..f4f72e897 100644 --- a/api/infrastructure/postgres/__init__.py +++ b/api/infrastructure/postgres/__init__.py @@ -1,7 +1,15 @@ from ._postgreskeyrepository import PostgresKeyRepository +from ._postgresproviderrepository import PostgresProviderRepository from ._postgresrolesrepository import PostgresRolesRepository from ._postgresrouterrepository import PostgresRouterRepository from ._postgresuserinforepository import PostgresUserInfoRepository from ._postgresusersrepository import PostgresUserRepository -__all__ = ["PostgresKeyRepository", "PostgresUserInfoRepository", "PostgresUserRepository", "PostgresRolesRepository", "PostgresRouterRepository"] +__all__ = [ + "PostgresKeyRepository", + "PostgresUserInfoRepository", + "PostgresUserRepository", + "PostgresRolesRepository", + "PostgresRouterRepository", + "PostgresProviderRepository", +] diff --git a/api/infrastructure/postgres/_postgresproviderrepository.py b/api/infrastructure/postgres/_postgresproviderrepository.py new file mode 100644 index 000000000..681ab97ab --- /dev/null +++ b/api/infrastructure/postgres/_postgresproviderrepository.py @@ -0,0 +1,78 @@ +from sqlalchemy import insert +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from api.domain.model.entities import Metric +from api.domain.provider import ProviderRepository +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.sql.models import Provider as ProviderTable + + +class PostgresProviderRepository(ProviderRepository): + def __init__(self, postgres_session: AsyncSession): + self.postgres_session = postgres_session + + async def create_provider( + self, + router_id: int, + user_id: int, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + model_hosting_zone: ProviderCarbonFootprintZone, + model_total_params: int, + model_active_params: int, + qos_metric: Metric | None, + qos_limit: float | None, + vector_size: int, + max_context_length: int, + ) -> Provider | ProviderAlreadyExistsError: + try: + user_id = None if user_id == 0 else user_id # 0 corresponds to master user ID + qos_metric = qos_metric.value if qos_metric is not None else None + query = ( + insert(ProviderTable) + .values( + router_id=router_id, + user_id=user_id, + type=provider_type.value, + url=url, + key=key, + timeout=timeout, + model_name=model_name, + model_hosting_zone=model_hosting_zone, + model_total_params=model_total_params, + model_active_params=model_active_params, + qos_metric=qos_metric, + qos_limit=qos_limit, + max_context_length=max_context_length, + vector_size=vector_size, + ) + .returning(ProviderTable) + ) + result = await self.postgres_session.execute(query) + row = result.scalar_one() + return Provider( + router_id=row.router_id, + user_id=row.user_id, + type=row.type, + url=row.url, + key=row.key, + timeout=row.timeout, + model_name=row.model_name, + model_hosting_zone=row.model_hosting_zone, + model_total_params=row.model_total_params, + model_active_params=row.model_active_params, + qos_metric=row.qos_metric, + qos_limit=row.qos_limit, + max_context_length=row.max_context_length, + vector_size=row.vector_size, + id=row.id, + ) + except IntegrityError as e: + if "unique_provider_router_id_url_model_name" in str(e.orig): + return ProviderAlreadyExistsError(model_name=model_name, url=url, router_id=router_id) + raise diff --git a/api/infrastructure/postgres/_postgresrouterrepository.py b/api/infrastructure/postgres/_postgresrouterrepository.py index 50d166c76..e3b3ac119 100644 --- a/api/infrastructure/postgres/_postgresrouterrepository.py +++ b/api/infrastructure/postgres/_postgresrouterrepository.py @@ -14,6 +14,57 @@ class PostgresRouterRepository(RouterRepository): + async def get_aliases_by_router_id(self, router_id: int) -> list[str]: + query = select(RouterAliasTable.value).where(RouterAliasTable.router_id == router_id) + result = await self.postgres_session.execute(query) + return [row[0] for row in result.all()] + + async def get_router_by_id(self, router_id: int) -> Router | None: + provider_count_subquery = ( + select(func.count(ProviderTable.id)).where(ProviderTable.router_id == RouterTable.id).correlate(RouterTable).scalar_subquery() + ) + query = ( + select( + RouterTable.id, + RouterTable.name, + RouterTable.user_id, + RouterTable.type, + RouterTable.load_balancing_strategy, + RouterTable.cost_prompt_tokens, + RouterTable.cost_completion_tokens, + ProviderTable.max_context_length, + ProviderTable.vector_size, + provider_count_subquery.label("providers"), + cast(func.extract("epoch", RouterTable.created), Integer).label("created"), + cast(func.extract("epoch", RouterTable.updated), Integer).label("updated"), + ) + .where(RouterTable.id == router_id) + .join(ProviderTable, ProviderTable.router_id == RouterTable.id, isouter=True) + .limit(1) + ) + + result = await self.postgres_session.execute(query) + row = result.one_or_none() + if row is None: + return None + user_id = MASTER_USER_ID if row.user_id is None else row.user_id + aliases = await self.get_aliases_by_router_id(router_id) + return Router( + id=row.id, + name=row.name, + user_id=user_id, + type=ModelType(row.type), + aliases=aliases, + load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy), + vector_size=row.vector_size, + max_context_length=row.max_context_length, + cost_prompt_tokens=row.cost_prompt_tokens or 0.0, + cost_completion_tokens=row.cost_completion_tokens or 0.0, + providers=row.providers, + created=row.created, + updated=row.updated, + ) + def __init__(self, postgres_session: AsyncSession, app_title: str): self.postgres_session = postgres_session self.app_title = app_title @@ -56,7 +107,7 @@ async def get_all_routers(self) -> list[Router]: result = await self.postgres_session.execute(query) router_results = [row._asdict() for row in result.all()] - aliases = await self.get_aliases_by_router_id() + aliases = await self.get_all_aliases_grouped_by_router() for row in router_results: user_id = MASTER_USER_ID if row["user_id"] is None else row["user_id"] @@ -79,7 +130,7 @@ async def get_all_routers(self) -> list[Router]: ) return routers - async def get_aliases_by_router_id(self) -> dict[str, list[str]]: + async def get_all_aliases_grouped_by_router(self) -> dict[str, list[str]]: aliases_query = select(RouterAliasTable.router_id.label("router_id"), RouterAliasTable.value) aliases_result = await self.postgres_session.execute(aliases_query) aliases = {} diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 8be3a0277..6b17937d9 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -5,13 +5,18 @@ from httpx import ASGITransport, AsyncClient import pytest import pytest_asyncio +from sqlalchemy import event from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from api.app import create_app from api.dependencies import get_postgres_session +from api.helpers.models import ModelRegistry +from api.main import app from api.sql.models import Base from api.tests.integration import factories +from api.utils.dependencies import get_model_registry +from api.utils.dependencies import get_postgres_session as get_postgres_session_utils TEST_DATABASE_URL = "postgresql+asyncpg://postgres:changeme@localhost:5432/test_db" @@ -69,23 +74,53 @@ async def test_session_factory(test_engine): @pytest_asyncio.fixture(scope="function") -async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: - async with test_session_factory() as session: +async def db_session(test_engine) -> AsyncGenerator[AsyncSession]: + """Provide a transactional scope for each test. + + Uses the recommended SQLAlchemy pattern: an outer transaction that is never + committed, with SAVEPOINTs for the test code. When code under test calls + session.commit() or session.rollback(), the SAVEPOINT is released/rolled back + and automatically restarted, so the outer transaction stays open and can be + rolled back at the end to undo everything. + """ + async with test_engine.connect() as connection: + transaction = await connection.begin() + + session = AsyncSession(bind=connection, expire_on_commit=False) + await session.begin_nested() + all_sql_factories = factories.BaseSQLFactory.__subclasses__() - session.expire_on_commit = False + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = session + + # Restart a SAVEPOINT whenever code under test commits or rolls back, + # so the outer transaction is never affected. + @event.listens_for(session.sync_session, "after_transaction_end") + def restart_savepoint(sess, trans): + if trans.nested and not trans._parent.nested: + sess.begin_nested() + try: - async with session.begin_nested(): - for factory in all_sql_factories: - factory._meta.sqlalchemy_session = session - yield session + yield session finally: - if session.in_transaction(): - await session.rollback() await session.close() + await transaction.rollback() + + +@pytest_asyncio.fixture(scope="session") +def model_registry(): + """Create a real ModelRegistry for integration tests.""" + return ModelRegistry( + app_title="test", + queuing_enabled=False, + max_priority=0, + max_retries=0, + retry_countdown=0, + ) @pytest_asyncio.fixture(scope="function") -async def client(db_session, test_configuration) -> AsyncGenerator[AsyncClient, None]: +async def client(db_session, model_registry, test_configuration) -> AsyncGenerator[AsyncClient, None]: app = create_app(test_configuration, skip_lifespan=True) async def override_get_postgres_session(): @@ -99,6 +134,8 @@ async def override_get_postgres_session(): raise app.dependency_overrides[get_postgres_session] = override_get_postgres_session + app.dependency_overrides[get_postgres_session_utils] = override_get_postgres_session + app.dependency_overrides[get_model_registry] = lambda: model_registry try: async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: diff --git a/api/tests/integration/test_admin_providers.py b/api/tests/integration/test_admin_providers.py new file mode 100644 index 000000000..1c87c70cc --- /dev/null +++ b/api/tests/integration/test_admin_providers.py @@ -0,0 +1,345 @@ +from collections.abc import AsyncGenerator +from unittest.mock import patch + +from fastapi import FastAPI, Request +from httpx import ASGITransport, AsyncClient +import pytest +import pytest_asyncio +from sqlalchemy import select + +from api.dependencies import get_postgres_session +from api.infrastructure.fastapi.endpoints.admin.providers import router as providers_router +from api.schemas.core.context import RequestContext +from api.schemas.models import ModelType +from api.schemas.usage import Usage +from api.sql.models import Provider as ProviderTable +from api.tests.helpers import create_token +from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory +from api.utils.context import request_context +from api.utils.dependencies import get_model_registry +from api.utils.dependencies import get_postgres_session as get_postgres_session_utils +from api.utils.variables import ENDPOINT__ADMIN_PROVIDERS + +URL = f"/v1{ENDPOINT__ADMIN_PROVIDERS}" + + +def _valid_body(router_id=1, **overrides) -> dict: + """Return a minimal valid provider creation body, with optional overrides.""" + body = { + "router": router_id, + "type": "albert", + "model_name": "my-model", + } + body.update(overrides) + return body + + +# --------------------------------------------------------------------------- +# Fake providers – the ONLY mock: external HTTP boundary +# --------------------------------------------------------------------------- + + +class FakeProvider: + """Simulates an external model provider (health check calls).""" + + def __init__(self, url, key, timeout, model_name, model_hosting_zone, model_total_params, model_active_params): + self.model_name = model_name + + async def get_max_context_length(self): + return 4096 + + async def get_vector_size(self): + return 768 + + +class UnreachableFakeProvider(FakeProvider): + """provider whose health check fails.""" + + async def get_max_context_length(self): + raise AssertionError("provider not reachable") + + async def get_vector_size(self): + raise AssertionError("provider not reachable") + + +class FakeProviderWithDifferentVectorSizeAndMaxContentLength(FakeProvider): + """provider whose health check fails.""" + + async def get_max_context_length(self): + return 1234 + + async def get_vector_size(self): + return 1234 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture(scope="function") +async def client(db_session, model_registry) -> AsyncGenerator[AsyncClient, None]: + """Test client using a minimal app with only the new infrastructure providers router.""" + test_app = FastAPI() + + @test_app.middleware("http") + async def set_request_context(request: Request, call_next): + request_context.set(RequestContext(method=request.method, endpoint=request.url.path, usage=Usage())) + return await call_next(request) + + test_app.include_router(providers_router) + + async def override_get_postgres_session(): + try: + yield db_session + if db_session.in_transaction(): + await db_session.flush() + except Exception: + if db_session.in_transaction(): + await db_session.rollback() + raise + + test_app.dependency_overrides[get_postgres_session] = override_get_postgres_session + test_app.dependency_overrides[get_postgres_session_utils] = override_get_postgres_session + test_app.dependency_overrides[get_model_registry] = lambda: model_registry + + async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as ac: + yield ac + + +@pytest.fixture +def mock_import_module(): + """Patch ModelProvider.import_module so no real HTTP call is made.""" + with patch("api.helpers.models._modelregistry.ModelProvider.import_module") as mock: + mock.return_value = FakeProvider + yield mock + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +class TestCreateProvider: + async def test_happy_path(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 201, response.text + assert isinstance(response.json()["id"], int) + + async def test_no_auth_token(self, client: AsyncClient): + response = await client.post(url=URL, json=_valid_body()) + + assert response.status_code == 401 + + async def test_missing_required_field(self, client: AsyncClient, db_session): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + + body = {"type": "albert", "model_name": "my-model"} # missing "router" + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=body, + ) + + assert response.status_code == 422 + + async def test_invalid_provider_type(self, client: AsyncClient, db_session): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(type="not_a_real_provider"), + ) + + assert response.status_code == 422 + + async def test_qos_metric_without_limit(self, client: AsyncClient, db_session): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(qos_metric="ttft"), + ) + + assert response.status_code == 422 + + async def test_tei_type_requires_url(self, client: AsyncClient, db_session): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(type="tei"), + ) + + assert response.status_code == 422 + + async def test_incompatible_provider_type(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id, type="tei", url="https://tei.example.com/"), + ) + + assert response.status_code == 400 + + async def test_provider_not_reachable(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + mock_import_module.return_value = UnreachableFakeProvider + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 424 + + async def test_provider_already_exists(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + ProviderSQLFactory( + router=router, + user=admin_user, + url="https://albert.api.etalab.gouv.fr/", + model_name="my-model", + max_context_length=4096, + vector_size=None, + ) + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 409 + + async def test_provider_mismatch_max_context_length(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_EMBEDDINGS_INFERENCE, name="test_router") + ProviderSQLFactory( + router=router, + user=admin_user, + url="https://albert.api.etalab.gouv.fr/", + model_name="my-model", + max_context_length=4096, + vector_size=1234, + ) + mock_import_module.return_value = FakeProviderWithDifferentVectorSizeAndMaxContentLength + + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 403 + assert response.json().get("detail") == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096" + + async def test_provider_mismatch_vector_size(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token( + db_session, + name="admin_token", + user=admin_user, + ) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION, name="test_router") + ProviderSQLFactory( + router=router, + user=admin_user, + url="https://albert.api.etalab.gouv.fr/", + model_name="my-model", + max_context_length=4096, + vector_size=1234, + ) + mock_import_module.return_value = FakeProviderWithDifferentVectorSizeAndMaxContentLength + + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 403 + assert response.json().get("detail") == "Inconsistent vector size for test_router. Expected: None. Actual: 1234" + + async def test_router_not_found(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(999999), + ) + + assert response.status_code == 404 + + async def test_url_trailing_slash(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id, url="https://my-provider.example.com"), + ) + + assert response.status_code == 201, response.text + provider_id = response.json()["id"] + result = await db_session.execute(select(ProviderTable.url).where(ProviderTable.id == provider_id)) + assert result.scalar_one() == "https://my-provider.example.com/" + + async def test_default_url_for_albert(self, client: AsyncClient, db_session, mock_import_module): + admin_user = UserSQLFactory(admin_user=True) + token = await create_token(db_session, name="admin_token", user=admin_user) + router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 201, response.text + provider_id = response.json()["id"] + result = await db_session.execute(select(ProviderTable.url).where(ProviderTable.id == provider_id)) + assert result.scalar_one() == "https://albert.api.etalab.gouv.fr/" diff --git a/api/tests/integration/test_postgresrouterrepository.py b/api/tests/integration/test_postgresrouterrepository.py index 4b77ed797..a9ace92b7 100644 --- a/api/tests/integration/test_postgresrouterrepository.py +++ b/api/tests/integration/test_postgresrouterrepository.py @@ -92,7 +92,7 @@ async def test_get_all_aliases_should_return_all_aliases(self, repository, db_se # Act await db_session.flush() - aliases = await repository.get_aliases_by_router_id() + aliases = await repository.get_all_aliases_grouped_by_router() # Assert assert aliases == { router_1.id: ["alias1_m1", "alias2_m1"], diff --git a/api/tests/unit/test_helpers/test_modelregistry/test_providers.py b/api/tests/unit/test_helpers/test_modelregistry/test_providers.py index 2168e3a23..9aedf79ca 100644 --- a/api/tests/unit/test_helpers/test_modelregistry/test_providers.py +++ b/api/tests/unit/test_helpers/test_modelregistry/test_providers.py @@ -760,7 +760,7 @@ async def test_update_provider_change_router_invalid_type(postgres_session: Asyn ) # Mock get_providers to return provider (will be called with router_id=None, provider_id=1) - # Note: get_providers creates Provider with type from DB (string), which Pydantic converts to enum + # Note: get_providers creates provider with type from DB (string), which Pydantic converts to enum provider_result = _Result() provider_result.mappings = lambda: _MappingsResult( [ diff --git a/api/use_cases/admin/providers/__init__.py b/api/use_cases/admin/providers/__init__.py new file mode 100644 index 000000000..c18f9293a --- /dev/null +++ b/api/use_cases/admin/providers/__init__.py @@ -0,0 +1,5 @@ +from ._createproviderusecase import CreateProviderUseCase + +__all__ = [ + "CreateProviderUseCase", +] diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py new file mode 100644 index 000000000..5d55d9ab0 --- /dev/null +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -0,0 +1,144 @@ +from dataclasses import dataclass + +from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider import InvalidProviderTypeError, ProviderGateway, ProviderNotReachableError, ProviderRepository +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.domain.router import RouterRepository +from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo import UserInfoRepository +from api.infrastructure.fastapi.schemas.models import ModelType +from api.schemas.core.models import Metric + + +@dataclass +class CreateProviderUseCaseSuccess: + provider: Provider + + +type CreateProviderUseCaseResult = ( + CreateProviderUseCaseSuccess + | InvalidProviderTypeError + | ProviderNotReachableError + | InconsistentModelMaxContextLengthError + | InconsistentModelVectorSizeError + | RouterNotFoundError + | ProviderAlreadyExistsError +) + +MODEL_TYPE_TO_MODEL_PROVIDER_TYPE_MAPPING = { + ModelType.AUTOMATIC_SPEECH_RECOGNITION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.IMAGE_TEXT_TO_TEXT: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_EMBEDDINGS_INFERENCE: [ + ProviderType.ALBERT.value, + ProviderType.OPENAI.value, + ProviderType.TEI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_GENERATION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_CLASSIFICATION: [ + ProviderType.ALBERT.value, + ProviderType.TEI.value, + ], + ModelType.IMAGE_TO_TEXT: [ + ProviderType.MISTRAL.value, + ], +} + + +class CreateProviderUseCase: + def __init__( + self, + router_repository: RouterRepository, + provider_repository: ProviderRepository, + user_info_repository: UserInfoRepository, + provider_gateway: ProviderGateway, + ): + self.router_repository = router_repository + self.provider_repository = provider_repository + self.user_info_repository = user_info_repository + self.provider_gateway = provider_gateway + + async def execute( + self, + router_id: int, + user_id: int, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + model_hosting_zone: ProviderCarbonFootprintZone, + model_total_params: int, + model_active_params: int, + qos_metric: Metric | None, + qos_limit: float | None, + ) -> CreateProviderUseCaseResult: + router = await self.router_repository.get_router_by_id(router_id=router_id) + if router is None: + return RouterNotFoundError(router_id) + + if provider_type.value not in MODEL_TYPE_TO_MODEL_PROVIDER_TYPE_MAPPING[router.type]: + return InvalidProviderTypeError(provider_type.value) + + result = await self.provider_gateway.get_capabilities(provider_type=provider_type, url=url, key=key, timeout=timeout, model_name=model_name) + + match result: + case ProviderNotReachableError() as error: + return error + case provider_capabilities: + pass + + max_context_length = provider_capabilities.max_context_length + if router.type == ModelType.TEXT_EMBEDDINGS_INFERENCE: + vector_size = provider_capabilities.vector_size + else: + vector_size = None + + if router.providers > 0: + if router.vector_size != vector_size: + return InconsistentModelVectorSizeError( + actual_vector_size=vector_size, expected_vector_size=router.vector_size, router_name=router.name + ) + if router.max_context_length != max_context_length: + return InconsistentModelMaxContextLengthError( + actual_max_context_length=max_context_length, expected_max_context_length=router.max_context_length, router_name=router.name + ) + + result = await self.provider_repository.create_provider( + router_id=router_id, + user_id=user_id, + provider_type=provider_type, + url=url, + key=key, + timeout=timeout, + model_name=model_name, + model_hosting_zone=model_hosting_zone, + model_total_params=model_total_params, + model_active_params=model_active_params, + qos_metric=qos_metric, + qos_limit=qos_limit, + max_context_length=max_context_length, + vector_size=vector_size, + ) + + match result: + case Provider() as provider: + return CreateProviderUseCaseSuccess(provider) + case error: + return error diff --git a/api/use_cases/models/_getmodelsusecase.py b/api/use_cases/models/_getmodelsusecase.py index 69ed93888..40ca49ef1 100644 --- a/api/use_cases/models/_getmodelsusecase.py +++ b/api/use_cases/models/_getmodelsusecase.py @@ -1,7 +1,7 @@ from dataclasses import dataclass +from api.domain.model import Model, ModelCosts from api.domain.router import RouterRepository -from api.domain.router.entities import Model, ModelCosts from api.domain.userinfo import UserInfoRepository diff --git a/playground/app/features/providers/models.py b/playground/app/features/providers/models.py index 4f243e9a3..7da297636 100644 --- a/playground/app/features/providers/models.py +++ b/playground/app/features/providers/models.py @@ -2,7 +2,7 @@ class Provider(Entity): - """Provider model.""" + """provider model.""" id: int | None = None router: str | None = None diff --git a/playground/app/features/providers/state.py b/playground/app/features/providers/state.py index b93757938..1a33cef4d 100644 --- a/playground/app/features/providers/state.py +++ b/playground/app/features/providers/state.py @@ -218,7 +218,7 @@ async def delete_entity(self): response.raise_for_status() self.handle_delete_entity_dialog_change(is_open=False) - yield rx.toast.success("Provider deleted successfully", position="bottom-right") + yield rx.toast.success("provider deleted successfully", position="bottom-right") async for _ in self.load_entities(): yield @@ -289,7 +289,7 @@ async def create_entity(self): ) response.raise_for_status() - yield rx.toast.success("Provider created successfully", position="bottom-right") + yield rx.toast.success("provider created successfully", position="bottom-right") async for _ in self.load_entities(): yield @@ -356,7 +356,7 @@ async def edit_entity(self): response.raise_for_status() self.handle_settings_entity_dialog_change(is_open=False) - yield rx.toast.success("Provider updated successfully", position="bottom-right") + yield rx.toast.success("provider updated successfully", position="bottom-right") async for _ in self.load_entities(): yield From 72340e508ac346d17ca96861d45efb8ca7e5f149 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Tue, 17 Feb 2026 16:22:30 +0100 Subject: [PATCH 02/13] Ajout tests unitaires --- api/domain/provider/errors.py | 3 +- .../fastapi/endpoints/admin/providers.py | 4 +- .../fastapi/endpoints/exceptions.py | 4 +- api/tests/unit/use_case/factories.py | 22 ++ .../use_case/test_createproviderusecase.py | 284 ++++++++++++++++++ api/use_cases/admin/providers/__init__.py | 6 +- .../admin/providers/_createproviderusecase.py | 2 +- 7 files changed, 315 insertions(+), 10 deletions(-) create mode 100644 api/tests/unit/use_case/test_createproviderusecase.py diff --git a/api/domain/provider/errors.py b/api/domain/provider/errors.py index 63a63786f..64cad8275 100644 --- a/api/domain/provider/errors.py +++ b/api/domain/provider/errors.py @@ -3,7 +3,8 @@ @dataclass class InvalidProviderTypeError: - type: str + provider_type: str + router_type: str @dataclass diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index d0d53e3ea..db60c5ee0 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -78,8 +78,8 @@ async def create_provider( match result: case CreateProviderUseCaseSuccess(created_provider): return CreateProviderResponse.model_validate(created_provider, from_attributes=True) - case InvalidProviderTypeError(provider_type): - raise InvalidProviderTypeHTTPException(provider_type) + case InvalidProviderTypeError(provider_type, router_type): + raise InvalidProviderTypeHTTPException(provider_type, router_type) case ProviderNotReachableError(name): raise ProviderNotReachableHTTPException(name) case ProviderAlreadyExistsError(model_name, url, router_id): diff --git a/api/infrastructure/fastapi/endpoints/exceptions.py b/api/infrastructure/fastapi/endpoints/exceptions.py index 09a94e1f0..3fe4cc246 100644 --- a/api/infrastructure/fastapi/endpoints/exceptions.py +++ b/api/infrastructure/fastapi/endpoints/exceptions.py @@ -3,9 +3,9 @@ # 400 class InvalidProviderTypeHTTPException(HTTPException): - def __init__(self, incorrect_provider_type: str) -> None: + def __init__(self, incorrect_provider_type: str, router_type: str) -> None: super().__init__( - status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for this model router type. Allowed types are: " + status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router. Allowed types are: " ) diff --git a/api/tests/unit/use_case/factories.py b/api/tests/unit/use_case/factories.py index d317b3350..c1203f06d 100644 --- a/api/tests/unit/use_case/factories.py +++ b/api/tests/unit/use_case/factories.py @@ -4,6 +4,7 @@ import factory from factory import fuzzy +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType from api.domain.role.entities import Limit, LimitType, PermissionType, Role from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy from api.domain.user.entities import User @@ -70,6 +71,27 @@ class Params: with_providers = factory.Trait(providers=factory.Faker("random_int", min=1, max=5)) +class ProviderFactory(factory.Factory): + class Meta: + model = Provider + + id = factory.Sequence(lambda n: n + 1) + router_id = factory.Faker("random_int", min=1, max=1000) + user_id = factory.Faker("random_int", min=1, max=1000) + type = factory.Faker("random_element", elements=list(ProviderType)) + url = factory.Faker("url") + key = None + timeout = 30 + model_name = factory.Faker("bothify", text="model-????") + model_hosting_zone = ProviderCarbonFootprintZone.WOR + model_total_params = 0 + model_active_params = 0 + qos_metric = None + qos_limit = None + created = factory.LazyFunction(lambda: int(datetime.now(UTC).timestamp())) + updated = factory.LazyFunction(lambda: int(datetime.now(UTC).timestamp())) + + class UserFactory(factory.Factory): class Meta: model = User diff --git a/api/tests/unit/use_case/test_createproviderusecase.py b/api/tests/unit/use_case/test_createproviderusecase.py new file mode 100644 index 000000000..607e6b3f9 --- /dev/null +++ b/api/tests/unit/use_case/test_createproviderusecase.py @@ -0,0 +1,284 @@ +from unittest.mock import AsyncMock + +import pytest + +from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider import ProviderCapabilities +from api.domain.provider.entities import ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError +from api.domain.router.entities import ModelType +from api.domain.router.errors import RouterNotFoundError +from api.tests.unit.use_case.factories import ProviderFactory, RouterFactory +from api.use_cases.admin.providers import CreateProviderUseCase, CreateProviderUseCaseSuccess + + +@pytest.fixture +def use_case(): + return CreateProviderUseCase( + router_repository=AsyncMock(), + provider_repository=AsyncMock(), + user_info_repository=AsyncMock(), + provider_gateway=AsyncMock(), + ) + + +@pytest.fixture +def sample_router(): + return RouterFactory( + id=1, + name="test-router", + type=ModelType.TEXT_GENERATION, + providers=0, + ) + + +@pytest.fixture +def sample_router_with_providers(): + return RouterFactory( + id=1, + name="test-router", + type=ModelType.TEXT_GENERATION, + providers=2, + max_context_length=4096, + vector_size=None, + ) + + +@pytest.fixture +def sample_embedding_router_with_providers(): + return RouterFactory( + id=1, + name="embedding-router", + type=ModelType.TEXT_EMBEDDINGS_INFERENCE, + providers=1, + max_context_length=512, + vector_size=768, + ) + + +@pytest.fixture +def sample_provider(): + return ProviderFactory( + id=1, + router_id=1, + user_id=1, + type=ProviderType.VLLM, + url="https://example.com/", + model_name="my-model", + ) + + +@pytest.fixture +def default_execute_params(): + return dict( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + ) + + +class TestCreateProviderUseCase: + @pytest.mark.asyncio + async def test_should_create_provider_when_router_exists_without_any_provider( + self, use_case, sample_router, sample_provider, default_execute_params + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.router_repository.get_router_by_id.assert_called_once_with(router_id=1) + use_case.provider_gateway.get_capabilities.assert_called_once_with( + provider_type=ProviderType.VLLM, url="https://example.com/", key=None, timeout=30, model_name="my-model" + ) + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=4096, + vector_size=None, + ) + + @pytest.mark.asyncio + async def test_should_create_provider_when_router_has_a_different_provider( + self, use_case, sample_router_with_providers, sample_provider, default_execute_params + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=4096, + vector_size=None, + ) + + @pytest.mark.asyncio + async def test_should_create_embedding_provider_when_vector_size_matches( + self, use_case, sample_embedding_router_with_providers, sample_provider, default_execute_params + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=512, vector_size=768) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(**{**default_execute_params, "provider_type": ProviderType.TEI}) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.TEI, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=512, + vector_size=768, + ) + + @pytest.mark.asyncio + async def test_should_return_router_not_found_error_when_router_does_not_exist(self, use_case, default_execute_params): + # Arrange + use_case.router_repository.get_router_by_id.return_value = None + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, RouterNotFoundError) + assert result.router_id == 1 + use_case.provider_gateway.get_capabilities.assert_not_called() + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_invalid_provider_type_error_when_type_not_compatible(self, use_case, default_execute_params): + # Arrange + router = RouterFactory(id=1, name="tei-router", type=ModelType.TEXT_CLASSIFICATION) + use_case.router_repository.get_router_by_id.return_value = router + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, InvalidProviderTypeError) + assert result.provider_type == ProviderType.VLLM.value + assert result.router_type == ModelType.TEXT_CLASSIFICATION + use_case.provider_gateway.get_capabilities.assert_not_called() + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_provider_not_reachable_error_when_gateway_fails(self, use_case, sample_router, default_execute_params): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderNotReachableError(model_name="my-model") + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, ProviderNotReachableError) + assert result.model_name == "my-model" + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_inconsistent_max_context_length_error_when_mismatch( + self, use_case, sample_router_with_providers, default_execute_params + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=2048, vector_size=None) + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, InconsistentModelMaxContextLengthError) + assert result.actual_max_context_length == 2048 + assert result.expected_max_context_length == 4096 + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_inconsistent_vector_size_error_when_mismatch( + self, use_case, sample_embedding_router_with_providers, default_execute_params + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=512, vector_size=384) + + # Act + result = await use_case.execute(**{**default_execute_params, "provider_type": ProviderType.TEI}) + + # Assert + assert isinstance(result, InconsistentModelVectorSizeError) + assert result.actual_vector_size == 384 + assert result.expected_vector_size == 768 + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_provider_already_exists_error(self, use_case, sample_router, default_execute_params): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = ProviderAlreadyExistsError( + model_name="my-model", url="https://example.com/", router_id=1 + ) + + # Act + result = await use_case.execute(**default_execute_params) + + # Assert + assert isinstance(result, ProviderAlreadyExistsError) + assert result.model_name == "my-model" + assert result.url == "https://example.com/" + assert result.router_id == 1 diff --git a/api/use_cases/admin/providers/__init__.py b/api/use_cases/admin/providers/__init__.py index c18f9293a..f063362a0 100644 --- a/api/use_cases/admin/providers/__init__.py +++ b/api/use_cases/admin/providers/__init__.py @@ -1,5 +1,3 @@ -from ._createproviderusecase import CreateProviderUseCase +from ._createproviderusecase import CreateProviderUseCase, CreateProviderUseCaseSuccess -__all__ = [ - "CreateProviderUseCase", -] +__all__ = ["CreateProviderUseCase", "CreateProviderUseCaseSuccess"] diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index 5d55d9ab0..53f96a814 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -94,7 +94,7 @@ async def execute( return RouterNotFoundError(router_id) if provider_type.value not in MODEL_TYPE_TO_MODEL_PROVIDER_TYPE_MAPPING[router.type]: - return InvalidProviderTypeError(provider_type.value) + return InvalidProviderTypeError(provider_type=provider_type.value, router_type=router.type) result = await self.provider_gateway.get_capabilities(provider_type=provider_type, url=url, key=key, timeout=timeout, model_name=model_name) From 615fcd30104a903fa06204f327459f4f27f09649 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Tue, 17 Feb 2026 16:29:30 +0100 Subject: [PATCH 03/13] WIP --- api/domain/provider/entities.py | 36 ++++++ .../fastapi/endpoints/admin/providers.py | 6 +- .../use_case/test_createproviderusecase.py | 64 ++++++---- api/use_cases/admin/providers/__init__.py | 4 +- .../admin/providers/_createproviderusecase.py | 114 +++++++----------- 5 files changed, 121 insertions(+), 103 deletions(-) diff --git a/api/domain/provider/entities.py b/api/domain/provider/entities.py index 48401a485..9ffcbe3c5 100644 --- a/api/domain/provider/entities.py +++ b/api/domain/provider/entities.py @@ -4,6 +4,7 @@ import pycountry from pydantic import Field, constr +from api.domain.model.entities import ModelType from api.schemas import BaseModel from api.schemas.core.models import Metric @@ -21,6 +22,41 @@ class ProviderType(str, Enum): VLLM = "vllm" +COMPATIBLE_PROVIDER_TYPES: dict[ModelType, list[str]] = { + ModelType.AUTOMATIC_SPEECH_RECOGNITION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.IMAGE_TEXT_TO_TEXT: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_EMBEDDINGS_INFERENCE: [ + ProviderType.ALBERT.value, + ProviderType.OPENAI.value, + ProviderType.TEI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_GENERATION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_CLASSIFICATION: [ + ProviderType.ALBERT.value, + ProviderType.TEI.value, + ], + ModelType.IMAGE_TO_TEXT: [ + ProviderType.MISTRAL.value, + ], +} + + class Provider(BaseModel): object: Literal["provider"] = "provider" id: int = Field(..., description="Provider ID.") # fmt: off diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index db60c5ee0..6d5916e30 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -29,8 +29,7 @@ Providers, UpdateProvider, ) -from api.use_cases.admin.providers import CreateProviderUseCase -from api.use_cases.admin.providers._createproviderusecase import CreateProviderUseCaseSuccess +from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import ENDPOINT__ADMIN_PROVIDERS, ROUTER__ADMIN @@ -50,7 +49,7 @@ async def create_provider( request_context: RequestContext = Depends(get_request_context), ) -> CreateProviderResponse: try: - result = await create_provider_use_case.execute( + command = CreateProviderCommand( router_id=body.router, user_id=request_context.get().user_id, provider_type=body.type, @@ -64,6 +63,7 @@ async def create_provider( qos_metric=body.qos_metric, qos_limit=body.qos_limit, ) + result = await create_provider_use_case.execute(command) except Exception as e: logger.exception( "Unexpected error while executing create_router use case", diff --git a/api/tests/unit/use_case/test_createproviderusecase.py b/api/tests/unit/use_case/test_createproviderusecase.py index 607e6b3f9..8d84eb0b5 100644 --- a/api/tests/unit/use_case/test_createproviderusecase.py +++ b/api/tests/unit/use_case/test_createproviderusecase.py @@ -9,7 +9,7 @@ from api.domain.router.entities import ModelType from api.domain.router.errors import RouterNotFoundError from api.tests.unit.use_case.factories import ProviderFactory, RouterFactory -from api.use_cases.admin.providers import CreateProviderUseCase, CreateProviderUseCaseSuccess +from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess @pytest.fixture @@ -17,7 +17,6 @@ def use_case(): return CreateProviderUseCase( router_repository=AsyncMock(), provider_repository=AsyncMock(), - user_info_repository=AsyncMock(), provider_gateway=AsyncMock(), ) @@ -69,8 +68,8 @@ def sample_provider(): @pytest.fixture -def default_execute_params(): - return dict( +def default_command(): + return CreateProviderCommand( router_id=1, user_id=1, provider_type=ProviderType.VLLM, @@ -86,18 +85,33 @@ def default_execute_params(): ) +def with_provider_type(command: CreateProviderCommand, provider_type: ProviderType) -> CreateProviderCommand: + return CreateProviderCommand( + router_id=command.router_id, + user_id=command.user_id, + provider_type=provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + model_hosting_zone=command.model_hosting_zone, + model_total_params=command.model_total_params, + model_active_params=command.model_active_params, + qos_metric=command.qos_metric, + qos_limit=command.qos_limit, + ) + + class TestCreateProviderUseCase: @pytest.mark.asyncio - async def test_should_create_provider_when_router_exists_without_any_provider( - self, use_case, sample_router, sample_provider, default_execute_params - ): + async def test_should_create_provider_when_router_exists_without_any_provider(self, use_case, sample_router, sample_provider, default_command): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_router use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) use_case.provider_repository.create_provider.return_value = sample_provider # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, CreateProviderUseCaseSuccess) @@ -125,7 +139,7 @@ async def test_should_create_provider_when_router_exists_without_any_provider( @pytest.mark.asyncio async def test_should_create_provider_when_router_has_a_different_provider( - self, use_case, sample_router_with_providers, sample_provider, default_execute_params + self, use_case, sample_router_with_providers, sample_provider, default_command ): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers @@ -133,7 +147,7 @@ async def test_should_create_provider_when_router_has_a_different_provider( use_case.provider_repository.create_provider.return_value = sample_provider # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, CreateProviderUseCaseSuccess) @@ -157,7 +171,7 @@ async def test_should_create_provider_when_router_has_a_different_provider( @pytest.mark.asyncio async def test_should_create_embedding_provider_when_vector_size_matches( - self, use_case, sample_embedding_router_with_providers, sample_provider, default_execute_params + self, use_case, sample_embedding_router_with_providers, sample_provider, default_command ): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers @@ -165,7 +179,7 @@ async def test_should_create_embedding_provider_when_vector_size_matches( use_case.provider_repository.create_provider.return_value = sample_provider # Act - result = await use_case.execute(**{**default_execute_params, "provider_type": ProviderType.TEI}) + result = await use_case.execute(with_provider_type(default_command, ProviderType.TEI)) # Assert assert isinstance(result, CreateProviderUseCaseSuccess) @@ -188,12 +202,12 @@ async def test_should_create_embedding_provider_when_vector_size_matches( ) @pytest.mark.asyncio - async def test_should_return_router_not_found_error_when_router_does_not_exist(self, use_case, default_execute_params): + async def test_should_return_router_not_found_error_when_router_does_not_exist(self, use_case, default_command): # Arrange use_case.router_repository.get_router_by_id.return_value = None # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, RouterNotFoundError) @@ -202,13 +216,13 @@ async def test_should_return_router_not_found_error_when_router_does_not_exist(s use_case.provider_repository.create_provider.assert_not_called() @pytest.mark.asyncio - async def test_should_return_invalid_provider_type_error_when_type_not_compatible(self, use_case, default_execute_params): + async def test_should_return_invalid_provider_type_error_when_type_not_compatible(self, use_case, default_command): # Arrange router = RouterFactory(id=1, name="tei-router", type=ModelType.TEXT_CLASSIFICATION) use_case.router_repository.get_router_by_id.return_value = router # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, InvalidProviderTypeError) @@ -218,13 +232,13 @@ async def test_should_return_invalid_provider_type_error_when_type_not_compatibl use_case.provider_repository.create_provider.assert_not_called() @pytest.mark.asyncio - async def test_should_return_provider_not_reachable_error_when_gateway_fails(self, use_case, sample_router, default_execute_params): + async def test_should_return_provider_not_reachable_error_when_gateway_fails(self, use_case, sample_router, default_command): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_router use_case.provider_gateway.get_capabilities.return_value = ProviderNotReachableError(model_name="my-model") # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, ProviderNotReachableError) @@ -232,15 +246,13 @@ async def test_should_return_provider_not_reachable_error_when_gateway_fails(sel use_case.provider_repository.create_provider.assert_not_called() @pytest.mark.asyncio - async def test_should_return_inconsistent_max_context_length_error_when_mismatch( - self, use_case, sample_router_with_providers, default_execute_params - ): + async def test_should_return_inconsistent_max_context_length_error_when_mismatch(self, use_case, sample_router_with_providers, default_command): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=2048, vector_size=None) # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, InconsistentModelMaxContextLengthError) @@ -250,14 +262,14 @@ async def test_should_return_inconsistent_max_context_length_error_when_mismatch @pytest.mark.asyncio async def test_should_return_inconsistent_vector_size_error_when_mismatch( - self, use_case, sample_embedding_router_with_providers, default_execute_params + self, use_case, sample_embedding_router_with_providers, default_command ): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=512, vector_size=384) # Act - result = await use_case.execute(**{**default_execute_params, "provider_type": ProviderType.TEI}) + result = await use_case.execute(with_provider_type(default_command, ProviderType.TEI)) # Assert assert isinstance(result, InconsistentModelVectorSizeError) @@ -266,7 +278,7 @@ async def test_should_return_inconsistent_vector_size_error_when_mismatch( use_case.provider_repository.create_provider.assert_not_called() @pytest.mark.asyncio - async def test_should_return_provider_already_exists_error(self, use_case, sample_router, default_execute_params): + async def test_should_return_provider_already_exists_error(self, use_case, sample_router, default_command): # Arrange use_case.router_repository.get_router_by_id.return_value = sample_router use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) @@ -275,7 +287,7 @@ async def test_should_return_provider_already_exists_error(self, use_case, sampl ) # Act - result = await use_case.execute(**default_execute_params) + result = await use_case.execute(default_command) # Assert assert isinstance(result, ProviderAlreadyExistsError) diff --git a/api/use_cases/admin/providers/__init__.py b/api/use_cases/admin/providers/__init__.py index f063362a0..6ba58c5e0 100644 --- a/api/use_cases/admin/providers/__init__.py +++ b/api/use_cases/admin/providers/__init__.py @@ -1,3 +1,3 @@ -from ._createproviderusecase import CreateProviderUseCase, CreateProviderUseCaseSuccess +from ._createproviderusecase import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess -__all__ = ["CreateProviderUseCase", "CreateProviderUseCaseSuccess"] +__all__ = ["CreateProviderCommand", "CreateProviderUseCase", "CreateProviderUseCaseSuccess"] diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index 53f96a814..89939a80b 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -1,16 +1,30 @@ from dataclasses import dataclass -from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError, ModelType from api.domain.provider import InvalidProviderTypeError, ProviderGateway, ProviderNotReachableError, ProviderRepository -from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.entities import COMPATIBLE_PROVIDER_TYPES, Provider, ProviderCarbonFootprintZone, ProviderType from api.domain.provider.errors import ProviderAlreadyExistsError from api.domain.router import RouterRepository from api.domain.router.errors import RouterNotFoundError -from api.domain.userinfo import UserInfoRepository -from api.infrastructure.fastapi.schemas.models import ModelType from api.schemas.core.models import Metric +@dataclass +class CreateProviderCommand: + router_id: int + user_id: int + provider_type: ProviderType + url: str + key: str | None + timeout: int + model_name: str + model_hosting_zone: ProviderCarbonFootprintZone + model_total_params: int + model_active_params: int + qos_metric: Metric | None + qos_limit: float | None + + @dataclass class CreateProviderUseCaseSuccess: provider: Provider @@ -26,77 +40,33 @@ class CreateProviderUseCaseSuccess: | ProviderAlreadyExistsError ) -MODEL_TYPE_TO_MODEL_PROVIDER_TYPE_MAPPING = { - ModelType.AUTOMATIC_SPEECH_RECOGNITION: [ - ProviderType.ALBERT.value, - ProviderType.MISTRAL.value, - ProviderType.OPENAI.value, - ProviderType.VLLM.value, - ], - ModelType.IMAGE_TEXT_TO_TEXT: [ - ProviderType.ALBERT.value, - ProviderType.MISTRAL.value, - ProviderType.OPENAI.value, - ProviderType.VLLM.value, - ], - ModelType.TEXT_EMBEDDINGS_INFERENCE: [ - ProviderType.ALBERT.value, - ProviderType.OPENAI.value, - ProviderType.TEI.value, - ProviderType.VLLM.value, - ], - ModelType.TEXT_GENERATION: [ - ProviderType.ALBERT.value, - ProviderType.MISTRAL.value, - ProviderType.OPENAI.value, - ProviderType.VLLM.value, - ], - ModelType.TEXT_CLASSIFICATION: [ - ProviderType.ALBERT.value, - ProviderType.TEI.value, - ], - ModelType.IMAGE_TO_TEXT: [ - ProviderType.MISTRAL.value, - ], -} - class CreateProviderUseCase: def __init__( self, router_repository: RouterRepository, provider_repository: ProviderRepository, - user_info_repository: UserInfoRepository, provider_gateway: ProviderGateway, ): self.router_repository = router_repository self.provider_repository = provider_repository - self.user_info_repository = user_info_repository self.provider_gateway = provider_gateway - async def execute( - self, - router_id: int, - user_id: int, - provider_type: ProviderType, - url: str, - key: str | None, - timeout: int, - model_name: str, - model_hosting_zone: ProviderCarbonFootprintZone, - model_total_params: int, - model_active_params: int, - qos_metric: Metric | None, - qos_limit: float | None, - ) -> CreateProviderUseCaseResult: - router = await self.router_repository.get_router_by_id(router_id=router_id) + async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCaseResult: + router = await self.router_repository.get_router_by_id(router_id=command.router_id) if router is None: - return RouterNotFoundError(router_id) + return RouterNotFoundError(command.router_id) - if provider_type.value not in MODEL_TYPE_TO_MODEL_PROVIDER_TYPE_MAPPING[router.type]: - return InvalidProviderTypeError(provider_type=provider_type.value, router_type=router.type) + if command.provider_type.value not in COMPATIBLE_PROVIDER_TYPES[router.type]: + return InvalidProviderTypeError(provider_type=command.provider_type.value, router_type=router.type) - result = await self.provider_gateway.get_capabilities(provider_type=provider_type, url=url, key=key, timeout=timeout, model_name=model_name) + result = await self.provider_gateway.get_capabilities( + provider_type=command.provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + ) match result: case ProviderNotReachableError() as error: @@ -121,18 +91,18 @@ async def execute( ) result = await self.provider_repository.create_provider( - router_id=router_id, - user_id=user_id, - provider_type=provider_type, - url=url, - key=key, - timeout=timeout, - model_name=model_name, - model_hosting_zone=model_hosting_zone, - model_total_params=model_total_params, - model_active_params=model_active_params, - qos_metric=qos_metric, - qos_limit=qos_limit, + router_id=command.router_id, + user_id=command.user_id, + provider_type=command.provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + model_hosting_zone=command.model_hosting_zone, + model_total_params=command.model_total_params, + model_active_params=command.model_active_params, + qos_metric=command.qos_metric, + qos_limit=command.qos_limit, max_context_length=max_context_length, vector_size=vector_size, ) From 8b94eaa4464b4587b937e415753ee56db6c1a7a6 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Wed, 18 Feb 2026 11:33:01 +0100 Subject: [PATCH 04/13] =?UTF-8?q?Refacto=20apr=C3=A8s=20rebase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/dependencies.py | 1 - .../fastapi/endpoints/admin/providers.py | 14 ++++----- api/tests/integration/conftest.py | 30 ++++++++----------- api/tests/integration/test_admin_providers.py | 4 +-- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/api/dependencies.py b/api/dependencies.py index b9e92610f..c3b3e8bcd 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -49,7 +49,6 @@ def create_provider_use_case_factory( ) -> CreateProviderUseCase: return CreateProviderUseCase( router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), - user_info_repository=PostgresUserInfoRepository(postgres_session=postgres_session), provider_repository=PostgresProviderRepository(postgres_session=postgres_session), provider_gateway=ModelProviderGateway(), ) diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index 6d5916e30..44fc5f4b7 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -31,14 +31,14 @@ ) from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from api.utils.dependencies import get_model_registry, get_postgres_session -from api.utils.variables import ENDPOINT__ADMIN_PROVIDERS, ROUTER__ADMIN +from api.utils.variables import EndpointRoute, RouterName logger = logging.getLogger(__name__) -router = APIRouter(prefix="/v1", tags=[ROUTER__ADMIN.title()]) +router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) @router.post( - path=ENDPOINT__ADMIN_PROVIDERS, + path=EndpointRoute.ADMIN_PROVIDERS, dependencies=[Security(dependency=get_current_key)], status_code=201, ) @@ -95,7 +95,7 @@ async def create_provider( @router.delete( - path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", dependencies=[Security(dependency=get_current_key)], status_code=204, ) @@ -111,7 +111,7 @@ async def delete_provider( @router.patch( - path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", dependencies=[Security(dependency=get_current_key)], status_code=204, ) @@ -138,7 +138,7 @@ async def update_provider( @router.get( - path=ENDPOINT__ADMIN_PROVIDERS + "/{provider}", + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Provider, @@ -156,7 +156,7 @@ async def get_provider( @router.get( - path=ENDPOINT__ADMIN_PROVIDERS, + path=EndpointRoute.ADMIN_PROVIDERS, dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Providers, diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 6b17937d9..b6a87558c 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -6,13 +6,12 @@ import pytest import pytest_asyncio from sqlalchemy import event -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.pool import NullPool from api.app import create_app from api.dependencies import get_postgres_session from api.helpers.models import ModelRegistry -from api.main import app from api.sql.models import Base from api.tests.integration import factories from api.utils.dependencies import get_model_registry @@ -68,33 +67,28 @@ async def test_engine(): await engine.dispose() -@pytest_asyncio.fixture(scope="session") -async def test_session_factory(test_engine): - return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) +def _all_sql_factories(): + result = [] + stack = list(factories.BaseSQLFactory.__subclasses__()) + while stack: + cls = stack.pop() + result.append(cls) + stack.extend(cls.__subclasses__()) + return result @pytest_asyncio.fixture(scope="function") async def db_session(test_engine) -> AsyncGenerator[AsyncSession]: - """Provide a transactional scope for each test. - - Uses the recommended SQLAlchemy pattern: an outer transaction that is never - committed, with SAVEPOINTs for the test code. When code under test calls - session.commit() or session.rollback(), the SAVEPOINT is released/rolled back - and automatically restarted, so the outer transaction stays open and can be - rolled back at the end to undo everything. - """ async with test_engine.connect() as connection: transaction = await connection.begin() session = AsyncSession(bind=connection, expire_on_commit=False) await session.begin_nested() - all_sql_factories = factories.BaseSQLFactory.__subclasses__() + all_sql_factories = _all_sql_factories() for factory in all_sql_factories: factory._meta.sqlalchemy_session = session - # Restart a SAVEPOINT whenever code under test commits or rolls back, - # so the outer transaction is never affected. @event.listens_for(session.sync_session, "after_transaction_end") def restart_savepoint(sess, trans): if trans.nested and not trans._parent.nested: @@ -103,13 +97,15 @@ def restart_savepoint(sess, trans): try: yield session finally: + event.remove(session.sync_session, "after_transaction_end", restart_savepoint) + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = None await session.close() await transaction.rollback() @pytest_asyncio.fixture(scope="session") def model_registry(): - """Create a real ModelRegistry for integration tests.""" return ModelRegistry( app_title="test", queuing_enabled=False, diff --git a/api/tests/integration/test_admin_providers.py b/api/tests/integration/test_admin_providers.py index 1c87c70cc..f04f6512a 100644 --- a/api/tests/integration/test_admin_providers.py +++ b/api/tests/integration/test_admin_providers.py @@ -18,9 +18,9 @@ from api.utils.context import request_context from api.utils.dependencies import get_model_registry from api.utils.dependencies import get_postgres_session as get_postgres_session_utils -from api.utils.variables import ENDPOINT__ADMIN_PROVIDERS +from api.utils.variables import EndpointRoute -URL = f"/v1{ENDPOINT__ADMIN_PROVIDERS}" +URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" def _valid_body(router_id=1, **overrides) -> dict: From a738b67b3e0554f8a3fa660f0aca1b0190eace77 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Thu, 19 Feb 2026 15:10:10 +0100 Subject: [PATCH 05/13] Refacto tests --- api/clients/model/_albertmodelprovider.py | 1 + .../fastapi/endpoints/exceptions.py | 4 +- .../model/_modelprovidergateway.py | 2 +- .../postgres/_postgresproviderrepository.py | 7 +- api/routers/registry.py | 2 +- api/tests/__init__.py | 0 api/tests/integration/__init__.py | 0 api/tests/integration/conftest.py | 75 ++-- api/tests/integration/endpoints/__init__.py | 0 .../endpoints/test_admin_providers.py | 186 ++++++++++ .../{ => endpoints}/test_admin_router.py | 0 .../{ => endpoints}/test_models.py | 0 api/tests/integration/postgres/__init__.py | 0 .../test_postgresproviderrepository.py | 86 +++++ .../test_postgresrouterrepository.py | 0 api/tests/integration/test_admin_providers.py | 345 ------------------ .../admin/providers/_createproviderusecase.py | 2 +- 17 files changed, 329 insertions(+), 381 deletions(-) create mode 100644 api/tests/__init__.py create mode 100644 api/tests/integration/__init__.py create mode 100644 api/tests/integration/endpoints/__init__.py create mode 100644 api/tests/integration/endpoints/test_admin_providers.py rename api/tests/integration/{ => endpoints}/test_admin_router.py (100%) rename api/tests/integration/{ => endpoints}/test_models.py (100%) create mode 100644 api/tests/integration/postgres/__init__.py create mode 100644 api/tests/integration/postgres/test_postgresproviderrepository.py rename api/tests/integration/{ => postgres}/test_postgresrouterrepository.py (100%) delete mode 100644 api/tests/integration/test_admin_providers.py diff --git a/api/clients/model/_albertmodelprovider.py b/api/clients/model/_albertmodelprovider.py index 3df6dbed4..27824082d 100644 --- a/api/clients/model/_albertmodelprovider.py +++ b/api/clients/model/_albertmodelprovider.py @@ -54,6 +54,7 @@ async def get_max_context_length(self) -> int | None: response = await client.get(url=url, headers=self.headers, timeout=self.timeout) response.raise_for_status() except Exception as e: + # TODO: remove exc_info=True and return error instead of exception logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True) raise AssertionError(f"Model is not reachable ({e}).") diff --git a/api/infrastructure/fastapi/endpoints/exceptions.py b/api/infrastructure/fastapi/endpoints/exceptions.py index 3fe4cc246..67863b899 100644 --- a/api/infrastructure/fastapi/endpoints/exceptions.py +++ b/api/infrastructure/fastapi/endpoints/exceptions.py @@ -4,9 +4,7 @@ # 400 class InvalidProviderTypeHTTPException(HTTPException): def __init__(self, incorrect_provider_type: str, router_type: str) -> None: - super().__init__( - status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router. Allowed types are: " - ) + super().__init__(status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router.") # 401 diff --git a/api/infrastructure/model/_modelprovidergateway.py b/api/infrastructure/model/_modelprovidergateway.py index f882a7806..c32399b65 100644 --- a/api/infrastructure/model/_modelprovidergateway.py +++ b/api/infrastructure/model/_modelprovidergateway.py @@ -12,7 +12,7 @@ async def get_capabilities(self, provider_type, url, key, timeout, model_name): max_context_length=max_context_length, vector_size=vector_size, ) - except Exception as e: + except AssertionError as e: return ProviderNotReachableError(model_name) def _build_client(self, provider_type, url, key, timeout, model_name): diff --git a/api/infrastructure/postgres/_postgresproviderrepository.py b/api/infrastructure/postgres/_postgresproviderrepository.py index 681ab97ab..f01889c9d 100644 --- a/api/infrastructure/postgres/_postgresproviderrepository.py +++ b/api/infrastructure/postgres/_postgresproviderrepository.py @@ -31,7 +31,6 @@ async def create_provider( max_context_length: int, ) -> Provider | ProviderAlreadyExistsError: try: - user_id = None if user_id == 0 else user_id # 0 corresponds to master user ID qos_metric = qos_metric.value if qos_metric is not None else None query = ( insert(ProviderTable) @@ -53,8 +52,10 @@ async def create_provider( ) .returning(ProviderTable) ) - result = await self.postgres_session.execute(query) - row = result.scalar_one() + async with self.postgres_session.begin_nested(): + result = await self.postgres_session.execute(query) + row = result.scalar_one() + return Provider( router_id=row.router_id, user_id=row.user_id, diff --git a/api/routers/registry.py b/api/routers/registry.py index a9fd85a2e..e637a15e8 100644 --- a/api/routers/registry.py +++ b/api/routers/registry.py @@ -12,7 +12,7 @@ class RouterDefinition: ROUTER_DEFINITIONS: tuple[RouterDefinition, ...] = ( # Admin routers RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.organizations"), - RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.providers"), + RouterDefinition(name=RouterName.ADMIN, module_path="api.infrastructure.fastapi.endpoints.admin.providers"), RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.roles"), RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.routers"), RouterDefinition(name=RouterName.ADMIN, module_path="api.infrastructure.fastapi.endpoints.admin_router"), diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/__init__.py b/api/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index b6a87558c..55b57d97f 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -1,17 +1,16 @@ from collections.abc import AsyncGenerator -from types import SimpleNamespace import asyncpg from httpx import ASGITransport, AsyncClient import pytest import pytest_asyncio -from sqlalchemy import event -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from api.app import create_app from api.dependencies import get_postgres_session from api.helpers.models import ModelRegistry +from api.schemas.core.configuration import Configuration, Dependencies, Settings from api.sql.models import Base from api.tests.integration import factories from api.utils.dependencies import get_model_registry @@ -22,8 +21,8 @@ @pytest.fixture def test_configuration(): - return SimpleNamespace( - settings=SimpleNamespace( + configuration = Configuration.model_construct( + settings=Settings.model_construct( app_title="test", swagger_summary=None, swagger_version="0.0.0", @@ -39,8 +38,9 @@ def test_configuration(): hidden_routers=[], monitoring_prometheus_enabled=False, ), - dependencies=SimpleNamespace(sentry=None), + dependencies=Dependencies.model_construct(sentry=None), ) + return configuration @pytest_asyncio.fixture(scope="session") @@ -67,6 +67,11 @@ async def test_engine(): await engine.dispose() +@pytest_asyncio.fixture(scope="session") +async def test_session_factory(test_engine): + return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) + + def _all_sql_factories(): result = [] stack = list(factories.BaseSQLFactory.__subclasses__()) @@ -78,30 +83,46 @@ def _all_sql_factories(): @pytest_asyncio.fixture(scope="function") -async def db_session(test_engine) -> AsyncGenerator[AsyncSession]: - async with test_engine.connect() as connection: - transaction = await connection.begin() - - session = AsyncSession(bind=connection, expire_on_commit=False) - await session.begin_nested() - - all_sql_factories = _all_sql_factories() - for factory in all_sql_factories: - factory._meta.sqlalchemy_session = session - - @event.listens_for(session.sync_session, "after_transaction_end") - def restart_savepoint(sess, trans): - if trans.nested and not trans._parent.nested: - sess.begin_nested() - +async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: + async with test_session_factory() as session: + all_sql_factories = factories.BaseSQLFactory.__subclasses__() + session.expire_on_commit = False try: - yield session + async with session.begin_nested(): + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = session + yield session finally: - event.remove(session.sync_session, "after_transaction_end", restart_savepoint) - for factory in all_sql_factories: - factory._meta.sqlalchemy_session = None + if session.in_transaction(): + await session.rollback() await session.close() - await transaction.rollback() + + +# @pytest_asyncio.fixture(scope="function") +# async def db_session(test_engine) -> AsyncGenerator[AsyncSession]: +# async with test_engine.connect() as connection: +# transaction = await connection.begin() +# +# session = AsyncSession(bind=connection, expire_on_commit=False) +# await session.begin_nested() +# +# all_sql_factories = _all_sql_factories() +# for factory in all_sql_factories: +# factory._meta.sqlalchemy_session = session +# +# @event.listens_for(session.sync_session, "after_transaction_end") +# def restart_savepoint(sess, trans): +# if trans.nested and not trans._parent.nested: +# sess.begin_nested() +# +# try: +# yield session +# finally: +# event.remove(session.sync_session, "after_transaction_end", restart_savepoint) +# for factory in all_sql_factories: +# factory._meta.sqlalchemy_session = None +# await session.close() +# await transaction.rollback() @pytest_asyncio.fixture(scope="session") diff --git a/api/tests/integration/endpoints/__init__.py b/api/tests/integration/endpoints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/endpoints/test_admin_providers.py b/api/tests/integration/endpoints/test_admin_providers.py new file mode 100644 index 000000000..ee2c4ba42 --- /dev/null +++ b/api/tests/integration/endpoints/test_admin_providers.py @@ -0,0 +1,186 @@ +import httpx +from httpx import AsyncClient +import pytest +import pytest_asyncio +import respx + +from api.schemas.models import ModelType +from api.tests.helpers import create_token +from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory +from api.utils.variables import EndpointRoute + +URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" + +DEFAULT_PROVIDER_URL = "http://my-test-provider/" + + +def _valid_body(router_id=1, **overrides) -> dict: + """Return a minimal valid provider creation body, with optional overrides.""" + body = { + "router": router_id, + "type": "albert", + "model_name": "my-model", + "url": DEFAULT_PROVIDER_URL, + } + body.update(overrides) + return body + + +def _mock_provider_reachable(respx_mock, base_url=DEFAULT_PROVIDER_URL, max_context_length=4096, vector_size=768): + """Mock GET /v1/models and POST /v1/embeddings for a reachable albert provider.""" + base_url = base_url.rstrip("/") + respx_mock.get(f"{base_url}/v1/models").mock( + return_value=httpx.Response( + 200, + json={ + "data": [{"id": "my-model", "aliases": [], "max_context_length": max_context_length}], + }, + ) + ) + embedding = [0.0] * vector_size if vector_size else [] + respx_mock.post(f"{base_url}/v1/embeddings").mock( + return_value=httpx.Response( + 200, + json={ + "data": [{"embedding": embedding}], + }, + ) + ) + + +def _mock_provider_unreachable(respx_mock, base_url=DEFAULT_PROVIDER_URL): + """Mock a provider that cannot be reached.""" + base_url = base_url.rstrip("/") + respx_mock.get(f"{base_url}/v1/models").mock(side_effect=httpx.ConnectError("connection refused")) + respx_mock.post(f"{base_url}/v1/embeddings").mock(side_effect=httpx.ConnectError("connection refused")) + + +@pytest.mark.asyncio(loop_scope="session") +class TestCreateProvider: + @pytest_asyncio.fixture(autouse=True) + async def setup(self, db_session): + self.admin_user = UserSQLFactory(admin_user=True) + self.token = await create_token(db_session, name="admin_token", user=self.admin_user) + + @respx.mock + async def test_happy_path(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + _mock_provider_reachable(respx) + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 201, response.text + assert isinstance(response.json()["id"], int) + + @respx.mock + async def test_incompatible_provider_type(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + _mock_provider_reachable(respx, base_url="https://tei.example.com") + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id, type="tei", url="https://tei.example.com/"), + ) + + assert response.status_code == 400 + assert response.json().get("detail") == "Invalid model provider type tei for text-generation router." + + @respx.mock + async def test_provider_not_reachable(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + _mock_provider_unreachable(respx) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 424 + assert response.json().get("detail") == "Model provider my-model not reachable." + + @respx.mock + async def test_provider_already_exists(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) + ProviderSQLFactory( + router=router, + user=self.admin_user, + url=DEFAULT_PROVIDER_URL, + model_name="my-model", + max_context_length=4096, + vector_size=None, + ) + await db_session.flush() + _mock_provider_reachable(respx) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + assert response.status_code == 409 + assert response.json().get("detail") == "Model provider my-model for url http://my-test-provider/ already exists for router 4." + + @respx.mock + async def test_provider_mismatch_max_context_length(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_EMBEDDINGS_INFERENCE, name="test_router") + ProviderSQLFactory( + router=router, + user=self.admin_user, + url="https://albert.api.etalab.gouv.fr/", + model_name="my-model", + max_context_length=4096, + vector_size=1234, + ) + await db_session.flush() + _mock_provider_reachable(respx, max_context_length=1234, vector_size=1234) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 403 + assert response.json().get("detail") == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096" + + @respx.mock + async def test_provider_mismatch_vector_size(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION, name="test_router") + ProviderSQLFactory( + router=router, + user=self.admin_user, + url="https://albert.api.etalab.gouv.fr/", + model_name="my-model", + max_context_length=4096, + vector_size=1234, + ) + await db_session.flush() + _mock_provider_reachable(respx, max_context_length=1234, vector_size=1234) + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 403 + assert response.json().get("detail") == "Inconsistent vector size for test_router. Expected: None. Actual: 1234" + + @respx.mock + async def test_router_not_found(self, client: AsyncClient, db_session): + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(999999), + ) + + assert response.status_code == 404 + assert response.json().get("detail") == "Model router 999999 not found." diff --git a/api/tests/integration/test_admin_router.py b/api/tests/integration/endpoints/test_admin_router.py similarity index 100% rename from api/tests/integration/test_admin_router.py rename to api/tests/integration/endpoints/test_admin_router.py diff --git a/api/tests/integration/test_models.py b/api/tests/integration/endpoints/test_models.py similarity index 100% rename from api/tests/integration/test_models.py rename to api/tests/integration/endpoints/test_models.py diff --git a/api/tests/integration/postgres/__init__.py b/api/tests/integration/postgres/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/postgres/test_postgresproviderrepository.py b/api/tests/integration/postgres/test_postgresproviderrepository.py new file mode 100644 index 000000000..a1bb8f877 --- /dev/null +++ b/api/tests/integration/postgres/test_postgresproviderrepository.py @@ -0,0 +1,86 @@ +import pytest + +from api.domain.model.entities import Metric +from api.domain.provider import Provider, ProviderAlreadyExistsError, ProviderCarbonFootprintZone, ProviderType +from api.domain.router.entities import ModelType +from api.infrastructure.postgres import PostgresProviderRepository +from api.tests.integration.factories import ( + ProviderSQLFactory, + RouterSQLFactory, + UserSQLFactory, +) + +_EXCLUDE = {"id", "created", "updated"} + + +def _create_provider_args(user, router, **overrides): + return { + "user_id": user.id, + "router_id": router.id, + "provider_type": ProviderType.ALBERT, + "url": "http://test.com/", + "key": "model-key", + "timeout": 60, + "model_name": "my-model", + "model_hosting_zone": ProviderCarbonFootprintZone.FRA, + "model_total_params": 1000, + "model_active_params": 2000, + "qos_metric": Metric.TTFT, + "qos_limit": 12, + "vector_size": 10, + "max_context_length": 20, + **overrides, + } + + +@pytest.fixture +def repository(db_session): + return PostgresProviderRepository(db_session) + + +@pytest.mark.asyncio(loop_scope="session") +class TestCreateProvider: + async def test_create_provider_should_return_created_provider(self, repository, db_session): + # Arrange + user = UserSQLFactory(admin_user=True) + router = RouterSQLFactory(user=user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + # Act + result = await repository.create_provider(**_create_provider_args(user, router)) + + # Assert + expected = Provider( + id=result.id, + router_id=router.id, + user_id=user.id, + type=ProviderType.ALBERT, + url="http://test.com/", + key="model-key", + timeout=60, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.FRA, + model_total_params=1000, + model_active_params=2000, + qos_metric=Metric.TTFT, + qos_limit=12, + vector_size=10, + max_context_length=20, + ) + assert result.model_dump(exclude=_EXCLUDE) == expected.model_dump(exclude=_EXCLUDE) + + async def test_create_provider_should_return_provider_already_exists_when_same_url_name_and_router_are_used(self, repository, db_session): + # Arrange + user = UserSQLFactory(admin_user=True) + router = RouterSQLFactory(user=user, type=ModelType.TEXT_GENERATION) + ProviderSQLFactory(type=ProviderType.ALBERT, url="http://test.com/", model_name="duplicate-provider", router=router) + await db_session.flush() + + # Act + result = await repository.create_provider(**_create_provider_args(user, router, model_name="duplicate-provider")) + + # Assert + assert isinstance(result, ProviderAlreadyExistsError) + assert result.router_id == router.id + assert result.url == "http://test.com/" + assert result.model_name == "duplicate-provider" diff --git a/api/tests/integration/test_postgresrouterrepository.py b/api/tests/integration/postgres/test_postgresrouterrepository.py similarity index 100% rename from api/tests/integration/test_postgresrouterrepository.py rename to api/tests/integration/postgres/test_postgresrouterrepository.py diff --git a/api/tests/integration/test_admin_providers.py b/api/tests/integration/test_admin_providers.py deleted file mode 100644 index f04f6512a..000000000 --- a/api/tests/integration/test_admin_providers.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import AsyncGenerator -from unittest.mock import patch - -from fastapi import FastAPI, Request -from httpx import ASGITransport, AsyncClient -import pytest -import pytest_asyncio -from sqlalchemy import select - -from api.dependencies import get_postgres_session -from api.infrastructure.fastapi.endpoints.admin.providers import router as providers_router -from api.schemas.core.context import RequestContext -from api.schemas.models import ModelType -from api.schemas.usage import Usage -from api.sql.models import Provider as ProviderTable -from api.tests.helpers import create_token -from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory -from api.utils.context import request_context -from api.utils.dependencies import get_model_registry -from api.utils.dependencies import get_postgres_session as get_postgres_session_utils -from api.utils.variables import EndpointRoute - -URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" - - -def _valid_body(router_id=1, **overrides) -> dict: - """Return a minimal valid provider creation body, with optional overrides.""" - body = { - "router": router_id, - "type": "albert", - "model_name": "my-model", - } - body.update(overrides) - return body - - -# --------------------------------------------------------------------------- -# Fake providers – the ONLY mock: external HTTP boundary -# --------------------------------------------------------------------------- - - -class FakeProvider: - """Simulates an external model provider (health check calls).""" - - def __init__(self, url, key, timeout, model_name, model_hosting_zone, model_total_params, model_active_params): - self.model_name = model_name - - async def get_max_context_length(self): - return 4096 - - async def get_vector_size(self): - return 768 - - -class UnreachableFakeProvider(FakeProvider): - """provider whose health check fails.""" - - async def get_max_context_length(self): - raise AssertionError("provider not reachable") - - async def get_vector_size(self): - raise AssertionError("provider not reachable") - - -class FakeProviderWithDifferentVectorSizeAndMaxContentLength(FakeProvider): - """provider whose health check fails.""" - - async def get_max_context_length(self): - return 1234 - - async def get_vector_size(self): - return 1234 - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest_asyncio.fixture(scope="function") -async def client(db_session, model_registry) -> AsyncGenerator[AsyncClient, None]: - """Test client using a minimal app with only the new infrastructure providers router.""" - test_app = FastAPI() - - @test_app.middleware("http") - async def set_request_context(request: Request, call_next): - request_context.set(RequestContext(method=request.method, endpoint=request.url.path, usage=Usage())) - return await call_next(request) - - test_app.include_router(providers_router) - - async def override_get_postgres_session(): - try: - yield db_session - if db_session.in_transaction(): - await db_session.flush() - except Exception: - if db_session.in_transaction(): - await db_session.rollback() - raise - - test_app.dependency_overrides[get_postgres_session] = override_get_postgres_session - test_app.dependency_overrides[get_postgres_session_utils] = override_get_postgres_session - test_app.dependency_overrides[get_model_registry] = lambda: model_registry - - async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as ac: - yield ac - - -@pytest.fixture -def mock_import_module(): - """Patch ModelProvider.import_module so no real HTTP call is made.""" - with patch("api.helpers.models._modelregistry.ModelProvider.import_module") as mock: - mock.return_value = FakeProvider - yield mock - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio(loop_scope="session") -class TestCreateProvider: - async def test_happy_path(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 201, response.text - assert isinstance(response.json()["id"], int) - - async def test_no_auth_token(self, client: AsyncClient): - response = await client.post(url=URL, json=_valid_body()) - - assert response.status_code == 401 - - async def test_missing_required_field(self, client: AsyncClient, db_session): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - - body = {"type": "albert", "model_name": "my-model"} # missing "router" - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=body, - ) - - assert response.status_code == 422 - - async def test_invalid_provider_type(self, client: AsyncClient, db_session): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(type="not_a_real_provider"), - ) - - assert response.status_code == 422 - - async def test_qos_metric_without_limit(self, client: AsyncClient, db_session): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(qos_metric="ttft"), - ) - - assert response.status_code == 422 - - async def test_tei_type_requires_url(self, client: AsyncClient, db_session): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(type="tei"), - ) - - assert response.status_code == 422 - - async def test_incompatible_provider_type(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id, type="tei", url="https://tei.example.com/"), - ) - - assert response.status_code == 400 - - async def test_provider_not_reachable(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - - mock_import_module.return_value = UnreachableFakeProvider - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 424 - - async def test_provider_already_exists(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - ProviderSQLFactory( - router=router, - user=admin_user, - url="https://albert.api.etalab.gouv.fr/", - model_name="my-model", - max_context_length=4096, - vector_size=None, - ) - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 409 - - async def test_provider_mismatch_max_context_length(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_EMBEDDINGS_INFERENCE, name="test_router") - ProviderSQLFactory( - router=router, - user=admin_user, - url="https://albert.api.etalab.gouv.fr/", - model_name="my-model", - max_context_length=4096, - vector_size=1234, - ) - mock_import_module.return_value = FakeProviderWithDifferentVectorSizeAndMaxContentLength - - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 403 - assert response.json().get("detail") == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096" - - async def test_provider_mismatch_vector_size(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token( - db_session, - name="admin_token", - user=admin_user, - ) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION, name="test_router") - ProviderSQLFactory( - router=router, - user=admin_user, - url="https://albert.api.etalab.gouv.fr/", - model_name="my-model", - max_context_length=4096, - vector_size=1234, - ) - mock_import_module.return_value = FakeProviderWithDifferentVectorSizeAndMaxContentLength - - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 403 - assert response.json().get("detail") == "Inconsistent vector size for test_router. Expected: None. Actual: 1234" - - async def test_router_not_found(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(999999), - ) - - assert response.status_code == 404 - - async def test_url_trailing_slash(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id, url="https://my-provider.example.com"), - ) - - assert response.status_code == 201, response.text - provider_id = response.json()["id"] - result = await db_session.execute(select(ProviderTable.url).where(ProviderTable.id == provider_id)) - assert result.scalar_one() == "https://my-provider.example.com/" - - async def test_default_url_for_albert(self, client: AsyncClient, db_session, mock_import_module): - admin_user = UserSQLFactory(admin_user=True) - token = await create_token(db_session, name="admin_token", user=admin_user) - router = RouterSQLFactory(user=admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 201, response.text - provider_id = response.json()["id"] - result = await db_session.execute(select(ProviderTable.url).where(ProviderTable.id == provider_id)) - assert result.scalar_one() == "https://albert.api.etalab.gouv.fr/" diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index 89939a80b..86ca1d5ae 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -58,7 +58,7 @@ async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCase return RouterNotFoundError(command.router_id) if command.provider_type.value not in COMPATIBLE_PROVIDER_TYPES[router.type]: - return InvalidProviderTypeError(provider_type=command.provider_type.value, router_type=router.type) + return InvalidProviderTypeError(provider_type=command.provider_type.value, router_type=router.type.value) result = await self.provider_gateway.get_capabilities( provider_type=command.provider_type, From 48dc4459e55629364f023768a7c028e58edb2912 Mon Sep 17 00:00:00 2001 From: Benjamin PILIA Date: Thu, 19 Feb 2026 17:03:49 +0100 Subject: [PATCH 06/13] Refacto integration tests --- .../fastapi/endpoints/admin/providers.py | 4 +- .../postgres/_postgresproviderrepository.py | 6 +- api/tests/integration/conftest.py | 108 ++++++----- .../endpoints/test_admin_providers.py | 167 ++++++------------ 4 files changed, 125 insertions(+), 160 deletions(-) diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index 44fc5f4b7..818195ff5 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -84,11 +84,11 @@ async def create_provider( raise ProviderNotReachableHTTPException(name) case ProviderAlreadyExistsError(model_name, url, router_id): raise ProviderAlreadyExistsHTTPException(model_name, url, router_id) - case InconsistentModelMaxContextLengthError(actual_max_context_length, expected_max_context_length, router_name): + case InconsistentModelMaxContextLengthError(expected_max_context_length, actual_max_context_length, router_name): raise InconsistentModelMaxContextLengthHTTPException( input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name ) - case InconsistentModelVectorSizeError(actual_vector_size, expected_vector_size, router_name): + case InconsistentModelVectorSizeError(expected_vector_size, actual_vector_size, router_name): raise InconsistentModelVectorSizeHTTPException(actual_vector_size, expected_vector_size, router_name) case RouterNotFoundError(router_id): raise RouterNotFoundHTTPException(router_id) diff --git a/api/infrastructure/postgres/_postgresproviderrepository.py b/api/infrastructure/postgres/_postgresproviderrepository.py index f01889c9d..5e59808b8 100644 --- a/api/infrastructure/postgres/_postgresproviderrepository.py +++ b/api/infrastructure/postgres/_postgresproviderrepository.py @@ -52,10 +52,8 @@ async def create_provider( ) .returning(ProviderTable) ) - async with self.postgres_session.begin_nested(): - result = await self.postgres_session.execute(query) - row = result.scalar_one() - + result = await self.postgres_session.execute(query) + row = result.scalar_one() return Provider( router_id=row.router_id, user_id=row.user_id, diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 55b57d97f..41a77f56f 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -4,7 +4,8 @@ from httpx import ASGITransport, AsyncClient import pytest import pytest_asyncio -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.pool import NullPool from api.app import create_app @@ -67,11 +68,6 @@ async def test_engine(): await engine.dispose() -@pytest_asyncio.fixture(scope="session") -async def test_session_factory(test_engine): - return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) - - def _all_sql_factories(): result = [] stack = list(factories.BaseSQLFactory.__subclasses__()) @@ -82,47 +78,64 @@ def _all_sql_factories(): return result -@pytest_asyncio.fixture(scope="function") -async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: - async with test_session_factory() as session: - all_sql_factories = factories.BaseSQLFactory.__subclasses__() - session.expire_on_commit = False - try: - async with session.begin_nested(): - for factory in all_sql_factories: - factory._meta.sqlalchemy_session = session - yield session - finally: - if session.in_transaction(): - await session.rollback() - await session.close() - +# @pytest_asyncio.fixture(scope="session") +# async def test_session_factory(test_engine): +# return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) -# @pytest_asyncio.fixture(scope="function") -# async def db_session(test_engine) -> AsyncGenerator[AsyncSession]: -# async with test_engine.connect() as connection: -# transaction = await connection.begin() -# -# session = AsyncSession(bind=connection, expire_on_commit=False) -# await session.begin_nested() -# -# all_sql_factories = _all_sql_factories() -# for factory in all_sql_factories: -# factory._meta.sqlalchemy_session = session -# -# @event.listens_for(session.sync_session, "after_transaction_end") -# def restart_savepoint(sess, trans): -# if trans.nested and not trans._parent.nested: -# sess.begin_nested() # +# @pytest_asyncio.fixture(scope="function") +# async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: +# async with test_session_factory() as session: +# all_sql_factories = factories.BaseSQLFactory.__subclasses__() +# session.expire_on_commit = False # try: -# yield session +# async with session.begin_nested(): +# for factory in all_sql_factories: +# factory._meta.sqlalchemy_session = session +# yield session # finally: -# event.remove(session.sync_session, "after_transaction_end", restart_savepoint) -# for factory in all_sql_factories: -# factory._meta.sqlalchemy_session = None +# if session.in_transaction(): +# await session.rollback() # await session.close() -# await transaction.rollback() + + +def pytest_addoption(parser): + parser.addoption( + "--commit-db", + action="store_true", + default=False, + help="Commit DB changes after each test (for debugging with psql).", + ) + + +@pytest_asyncio.fixture(scope="function") +async def db_session(test_engine, request) -> AsyncGenerator[AsyncSession]: + async with test_engine.connect() as connection: + transaction = await connection.begin() + + session = AsyncSession(bind=connection, expire_on_commit=False) + await session.begin_nested() + + all_sql_factories = _all_sql_factories() + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = session + + @event.listens_for(session.sync_session, "after_transaction_end") + def restart_savepoint(sess, trans): + if trans.nested and not trans._parent.nested: + sess.begin_nested() + + try: + yield session + finally: + event.remove(session.sync_session, "after_transaction_end", restart_savepoint) + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = None + await session.close() + if request.config.getoption("--commit-db"): + await transaction.commit() + else: + await transaction.rollback() @pytest_asyncio.fixture(scope="session") @@ -137,7 +150,7 @@ def model_registry(): @pytest_asyncio.fixture(scope="function") -async def client(db_session, model_registry, test_configuration) -> AsyncGenerator[AsyncClient, None]: +async def app(db_session, model_registry, test_configuration): app = create_app(test_configuration, skip_lifespan=True) async def override_get_postgres_session(): @@ -155,7 +168,12 @@ async def override_get_postgres_session(): app.dependency_overrides[get_model_registry] = lambda: model_registry try: - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - yield ac + yield app finally: app.dependency_overrides.clear() + + +@pytest_asyncio.fixture(scope="function") +async def client(app) -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac diff --git a/api/tests/integration/endpoints/test_admin_providers.py b/api/tests/integration/endpoints/test_admin_providers.py index ee2c4ba42..d5039f32c 100644 --- a/api/tests/integration/endpoints/test_admin_providers.py +++ b/api/tests/integration/endpoints/test_admin_providers.py @@ -1,12 +1,18 @@ +from unittest.mock import AsyncMock + import httpx from httpx import AsyncClient import pytest import pytest_asyncio import respx +from api.dependencies import create_provider_use_case_factory +from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError +from api.domain.router.errors import RouterNotFoundError from api.schemas.models import ModelType from api.tests.helpers import create_token -from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory +from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory from api.utils.variables import EndpointRoute URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" @@ -14,7 +20,7 @@ DEFAULT_PROVIDER_URL = "http://my-test-provider/" -def _valid_body(router_id=1, **overrides) -> dict: +def _valid_body(router_id: int, **overrides) -> dict: """Return a minimal valid provider creation body, with optional overrides.""" body = { "router": router_id, @@ -48,13 +54,6 @@ def _mock_provider_reachable(respx_mock, base_url=DEFAULT_PROVIDER_URL, max_cont ) -def _mock_provider_unreachable(respx_mock, base_url=DEFAULT_PROVIDER_URL): - """Mock a provider that cannot be reached.""" - base_url = base_url.rstrip("/") - respx_mock.get(f"{base_url}/v1/models").mock(side_effect=httpx.ConnectError("connection refused")) - respx_mock.post(f"{base_url}/v1/embeddings").mock(side_effect=httpx.ConnectError("connection refused")) - - @pytest.mark.asyncio(loop_scope="session") class TestCreateProvider: @pytest_asyncio.fixture(autouse=True) @@ -62,6 +61,11 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) + @pytest_asyncio.fixture(autouse=True) + async def cleanup_overrides(self, app): + yield + app.dependency_overrides.pop(create_provider_use_case_factory, None) + @respx.mock async def test_happy_path(self, client: AsyncClient, db_session): router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) @@ -76,111 +80,56 @@ async def test_happy_path(self, client: AsyncClient, db_session): assert response.status_code == 201, response.text assert isinstance(response.json()["id"], int) - @respx.mock - async def test_incompatible_provider_type(self, client: AsyncClient, db_session): - router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - _mock_provider_reachable(respx, base_url="https://tei.example.com") - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(router.id, type="tei", url="https://tei.example.com/"), - ) - - assert response.status_code == 400 - assert response.json().get("detail") == "Invalid model provider type tei for text-generation router." - - @respx.mock - async def test_provider_not_reachable(self, client: AsyncClient, db_session): - router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) - await db_session.flush() - _mock_provider_unreachable(respx) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 424 - assert response.json().get("detail") == "Model provider my-model not reachable." - - @respx.mock - async def test_provider_already_exists(self, client: AsyncClient, db_session): - router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) - ProviderSQLFactory( - router=router, - user=self.admin_user, - url=DEFAULT_PROVIDER_URL, - model_name="my-model", - max_context_length=4096, - vector_size=None, - ) - await db_session.flush() - _mock_provider_reachable(respx) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(router.id), - ) - assert response.status_code == 409 - assert response.json().get("detail") == "Model provider my-model for url http://my-test-provider/ already exists for router 4." - - @respx.mock - async def test_provider_mismatch_max_context_length(self, client: AsyncClient, db_session): - router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_EMBEDDINGS_INFERENCE, name="test_router") - ProviderSQLFactory( - router=router, - user=self.admin_user, - url="https://albert.api.etalab.gouv.fr/", - model_name="my-model", - max_context_length=4096, - vector_size=1234, - ) - await db_session.flush() - _mock_provider_reachable(respx, max_context_length=1234, vector_size=1234) - - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(router.id), - ) - - assert response.status_code == 403 - assert response.json().get("detail") == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096" - - @respx.mock - async def test_provider_mismatch_vector_size(self, client: AsyncClient, db_session): - router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION, name="test_router") - ProviderSQLFactory( - router=router, - user=self.admin_user, - url="https://albert.api.etalab.gouv.fr/", - model_name="my-model", - max_context_length=4096, - vector_size=1234, - ) - await db_session.flush() - _mock_provider_reachable(respx, max_context_length=1234, vector_size=1234) + @pytest.mark.parametrize( + "use_case_result,expected_status,expected_detail", + [ + (RouterNotFoundError(router_id=1), 404, "Model router 1 not found."), + ( + InvalidProviderTypeError(provider_type="tei", router_type="text-generation"), + 400, + "Invalid model provider type tei for text-generation router.", + ), + (ProviderNotReachableError(model_name="my-model"), 424, "Model provider my-model not reachable."), + ( + ProviderAlreadyExistsError(model_name="my-model", url=DEFAULT_PROVIDER_URL, router_id=1), + 409, + f"Model provider my-model for url {DEFAULT_PROVIDER_URL} already exists for router 1.", + ), + ( + InconsistentModelMaxContextLengthError(expected_max_context_length=4096, actual_max_context_length=2048, router_name="my-router"), + 403, + "Inconsistent max context length for my-router. Expected: 4096. Actual: 2048", + ), + ( + InconsistentModelVectorSizeError(expected_vector_size=768, actual_vector_size=384, router_name="my-router"), + 403, + "Inconsistent vector size for my-router. Expected: 768. Actual: 384", + ), + ], + ) + async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail): + mock_use_case = AsyncMock() + mock_use_case.execute.return_value = use_case_result + app.dependency_overrides[create_provider_use_case_factory] = lambda: mock_use_case response = await client.post( url=URL, headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(router.id), + json=_valid_body(router_id=1), ) - assert response.status_code == 403 - assert response.json().get("detail") == "Inconsistent vector size for test_router. Expected: None. Actual: 1234" + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail - @respx.mock - async def test_router_not_found(self, client: AsyncClient, db_session): - response = await client.post( - url=URL, - headers={"Authorization": f"Bearer {self.token.token}"}, - json=_valid_body(999999), - ) + @pytest.mark.parametrize( + "headers,expected_status,expected_detail", + [ + ({}, 401, "Not authenticated"), + ({"Authorization": "Bearer invalid-token"}, 403, "Invalid API key."), + ], + ) + async def test_auth(self, client: AsyncClient, headers, expected_status, expected_detail): + response = await client.post(url=URL, headers=headers, json=_valid_body(router_id=1)) - assert response.status_code == 404 - assert response.json().get("detail") == "Model router 999999 not found." + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail From 308a36954fd1c0331100417db208e4a0b3b0a300 Mon Sep 17 00:00:00 2001 From: benjaminpilia Date: Fri, 20 Feb 2026 10:13:21 +0000 Subject: [PATCH 07/13] Update unit coverage badge --- .github/badges/coverage.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/badges/coverage.json b/.github/badges/coverage.json index f96bab692..b9757f0da 100644 --- a/.github/badges/coverage.json +++ b/.github/badges/coverage.json @@ -1 +1 @@ -{"schemaVersion":1,"label":"coverage","message":"50.85%","color":"red"} +{"schemaVersion":1,"label":"coverage","message":"50.68%","color":"red"} From 64afa2b6fc69a8f3ef60fffb3dc2cd74d3d28e97 Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 10:54:11 +0100 Subject: [PATCH 08/13] fix(providers): import /v1/admin/providers endpoint & check router type before reach provider API --- api/app.py | 4 ++++ api/clients/model/_mistralmodelprovider.py | 1 + api/domain/provider/_providergateway.py | 2 ++ api/domain/router/_routerrepository.py | 5 +++-- api/domain/router/entities.py | 4 ++-- api/infrastructure/model/_modelprovidergateway.py | 8 ++++++-- .../postgres/_postgresrouterrepository.py | 11 ++++++----- api/use_cases/admin/_createrouterusecase.py | 5 +++-- .../admin/providers/_createproviderusecase.py | 1 + 9 files changed, 28 insertions(+), 13 deletions(-) diff --git a/api/app.py b/api/app.py index 0d2bcb9d9..c776ccff6 100644 --- a/api/app.py +++ b/api/app.py @@ -80,6 +80,10 @@ def _register_routers(app: FastAPI, configuration: Configuration) -> None: module = import_module("api.infrastructure.fastapi.endpoints.admin_router") app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) + if RouterName.ADMIN not in disabled_routers: + module = import_module("api.infrastructure.fastapi.endpoints.admin.providers") + app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) + def _setup_monitoring(app: FastAPI, configuration: Configuration) -> None: if RouterName.MONITORING in configuration.settings.disabled_routers: diff --git a/api/clients/model/_mistralmodelprovider.py b/api/clients/model/_mistralmodelprovider.py index 9d436791a..c7c688f58 100644 --- a/api/clients/model/_mistralmodelprovider.py +++ b/api/clients/model/_mistralmodelprovider.py @@ -53,6 +53,7 @@ async def get_max_context_length(self) -> int | None: async with httpx.AsyncClient() as client: response = await client.get(url=url, headers=self.headers, timeout=self.timeout) response.raise_for_status() + except Exception as e: logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True) raise AssertionError(f"Model is not reachable ({e}).") diff --git a/api/domain/provider/_providergateway.py b/api/domain/provider/_providergateway.py index da01c26e7..7f26df361 100644 --- a/api/domain/provider/_providergateway.py +++ b/api/domain/provider/_providergateway.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from api.domain.model import ModelType as RouterType from api.domain.provider.entities import ProviderType from api.domain.provider.errors import ProviderNotReachableError @@ -15,6 +16,7 @@ class ProviderGateway(ABC): @abstractmethod async def get_capabilities( self, + router_type: RouterType, provider_type: ProviderType, url: str, key: str | None, diff --git a/api/domain/router/_routerrepository.py b/api/domain/router/_routerrepository.py index 8bbc60083..58ec5eedb 100644 --- a/api/domain/router/_routerrepository.py +++ b/api/domain/router/_routerrepository.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError @@ -25,7 +26,7 @@ async def get_aliases_by_router_id(self, router_id: int) -> Router | None: async def create_router( self, name: str, - router_type: ModelType, + router_type: RouterType, load_balancing_strategy: RouterLoadBalancingStrategy, cost_prompt_tokens: float, cost_completion_tokens: float, diff --git a/api/domain/router/entities.py b/api/domain/router/entities.py index 6abea7b5b..49a2d47dd 100644 --- a/api/domain/router/entities.py +++ b/api/domain/router/entities.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from api.domain.model import ModelType +from api.domain.model import ModelType as RouterType class RouterLoadBalancingStrategy(str, Enum): @@ -14,7 +14,7 @@ class Router(BaseModel): id: int = Field(..., description="ID of the router.") # fmt: off name: str = Field(..., description="Name of the router.") # fmt: off user_id: int = Field(..., description="ID of the user that owns the router.") # fmt: off - type: ModelType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off + type: RouterType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off load_balancing_strategy: RouterLoadBalancingStrategy = Field(..., description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off vector_size: int | None = Field(default=None, description="Dimension of the vectors, if the models are embeddings. Make sure it is the same for all models.") # fmt: off diff --git a/api/infrastructure/model/_modelprovidergateway.py b/api/infrastructure/model/_modelprovidergateway.py index c32399b65..4d47c1804 100644 --- a/api/infrastructure/model/_modelprovidergateway.py +++ b/api/infrastructure/model/_modelprovidergateway.py @@ -1,13 +1,17 @@ from api.clients.model import BaseModelProvider +from api.domain.model import ModelType as RouterType from api.domain.provider import ProviderCapabilities, ProviderGateway, ProviderNotReachableError class ModelProviderGateway(ProviderGateway): - async def get_capabilities(self, provider_type, url, key, timeout, model_name): + async def get_capabilities(self, router_type, provider_type, url, key, timeout, model_name): try: client = self._build_client(provider_type, url, key, timeout, model_name) max_context_length = await client.get_max_context_length() - vector_size = await client.get_vector_size() + if router_type == RouterType.TEXT_EMBEDDINGS_INFERENCE: + vector_size = await client.get_vector_size() + else: + vector_size = None return ProviderCapabilities( max_context_length=max_context_length, vector_size=vector_size, diff --git a/api/infrastructure/postgres/_postgresrouterrepository.py b/api/infrastructure/postgres/_postgresrouterrepository.py index e3b3ac119..4c3ab19c3 100644 --- a/api/infrastructure/postgres/_postgresrouterrepository.py +++ b/api/infrastructure/postgres/_postgresrouterrepository.py @@ -3,8 +3,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from api.domain.key.entities import MASTER_USER_ID +from api.domain.model import ModelType as RouterType from api.domain.router import RouterRepository -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.sql.models import Organization as OrganizationTable from api.sql.models import Provider as ProviderTable @@ -53,7 +54,7 @@ async def get_router_by_id(self, router_id: int) -> Router | None: id=row.id, name=row.name, user_id=user_id, - type=ModelType(row.type), + type=RouterType(row.type), aliases=aliases, load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy), vector_size=row.vector_size, @@ -116,7 +117,7 @@ async def get_all_routers(self) -> list[Router]: id=row["id"], name=row["name"], user_id=user_id, - type=ModelType(row["type"]), + type=RouterType(row["type"]), aliases=aliases.get(row["id"], []), load_balancing_strategy=RouterLoadBalancingStrategy(row["load_balancing_strategy"]), vector_size=row["vector_size"], @@ -143,7 +144,7 @@ async def get_all_aliases_grouped_by_router(self) -> dict[str, list[str]]: async def create_router( self, name: str, - router_type: ModelType, + router_type: RouterType, load_balancing_strategy: RouterLoadBalancingStrategy, cost_prompt_tokens: float, cost_completion_tokens: float, @@ -199,7 +200,7 @@ async def create_router( id=row.id, name=row.name, user_id=user_id, - type=ModelType(row.type), + type=RouterType(row.type), aliases=aliases, load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy), vector_size=None, diff --git a/api/use_cases/admin/_createrouterusecase.py b/api/use_cases/admin/_createrouterusecase.py index a0e25b694..0f979d7da 100644 --- a/api/use_cases/admin/_createrouterusecase.py +++ b/api/use_cases/admin/_createrouterusecase.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from api.domain.model import ModelType as RouterType from api.domain.router import RouterRepository -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo import UserInfoRepository from api.domain.userinfo.errors import InsufficientPermissionError @@ -26,7 +27,7 @@ async def execute( self, user_id: int, name: str, - router_type: ModelType, + router_type: RouterType, aliases: list[str], load_balancing_strategy: RouterLoadBalancingStrategy, cost_prompt_tokens: float, diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index 86ca1d5ae..276cb723a 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -61,6 +61,7 @@ async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCase return InvalidProviderTypeError(provider_type=command.provider_type.value, router_type=router.type.value) result = await self.provider_gateway.get_capabilities( + router_type=router.type, provider_type=command.provider_type, url=command.url, key=command.key, From 289c2e102654576f6d21bd4ea6e0987ec5ee3c48 Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 12:03:59 +0100 Subject: [PATCH 09/13] fix(providers): add exception auto documentation & add permissions check --- api/dependencies.py | 2 +- api/endpoints/admin/providers.py | 6 +- .../fastapi/endpoints/admin/providers.py | 55 +++++++------ .../fastapi/endpoints/admin_router.py | 17 +++- .../fastapi/endpoints/exceptions.py | 78 ++++++++++++++----- .../fastapi/endpoints/models.py | 7 +- api/infrastructure/fastapi/utils.py | 28 +++++++ .../endpoints/test_admin_router.py | 17 ++-- api/use_cases/admin/__init__.py | 7 +- api/use_cases/admin/_createrouterusecase.py | 35 +++++---- .../admin/providers/_createproviderusecase.py | 10 +++ 11 files changed, 175 insertions(+), 87 deletions(-) create mode 100644 api/infrastructure/fastapi/utils.py diff --git a/api/dependencies.py b/api/dependencies.py index c3b3e8bcd..5cc2bafb9 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -45,12 +45,12 @@ def get_models_use_case( def create_provider_use_case_factory( postgres_session: AsyncSession = Depends(get_postgres_session), - request_context: RequestContext = Depends(get_request_context), ) -> CreateProviderUseCase: return CreateProviderUseCase( router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), provider_repository=PostgresProviderRepository(postgres_session=postgres_session), provider_gateway=ModelProviderGateway(), + user_info_repository=PostgresUserInfoRepository(postgres_session=postgres_session), ) diff --git a/api/endpoints/admin/providers.py b/api/endpoints/admin/providers.py index 7369bd276..0a774c646 100644 --- a/api/endpoints/admin/providers.py +++ b/api/endpoints/admin/providers.py @@ -7,11 +7,7 @@ from api.endpoints.admin import router from api.helpers._accesscontroller import AccessController from api.helpers.models import ModelRegistry -from api.schemas.admin.providers import ( - Provider, - Providers, - UpdateProvider, -) +from api.schemas.admin.providers import Provider, Providers, UpdateProvider from api.schemas.admin.roles import PermissionType from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index 818195ff5..1502eaecd 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -10,25 +10,22 @@ from api.domain.provider import InvalidProviderTypeError, ProviderNotReachableError from api.domain.provider.errors import ProviderAlreadyExistsError from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo.errors import InsufficientPermissionError from api.helpers.models import ModelRegistry from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext from api.infrastructure.fastapi.endpoints.exceptions import ( InconsistentModelMaxContextLengthHTTPException, InconsistentModelVectorSizeHTTPException, + InsufficientPermissionHTTPException, InternalServerHTTPException, InvalidProviderTypeHTTPException, ProviderAlreadyExistsHTTPException, ProviderNotReachableHTTPException, RouterNotFoundHTTPException, ) -from api.infrastructure.fastapi.schemas.providers import ( - CreateProvider, - CreateProviderResponse, - Provider, - Providers, - UpdateProvider, -) +from api.infrastructure.fastapi.schemas.providers import CreateProvider, CreateProviderResponse, Provider, Providers, UpdateProvider +from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute, RouterName @@ -41,6 +38,15 @@ path=EndpointRoute.ADMIN_PROVIDERS, dependencies=[Security(dependency=get_current_key)], status_code=201, + responses=get_documentation_responses([ + InconsistentModelMaxContextLengthHTTPException, + InconsistentModelVectorSizeHTTPException, + InvalidProviderTypeHTTPException, + ProviderNotReachableHTTPException, + ProviderAlreadyExistsHTTPException, + RouterNotFoundHTTPException, + InsufficientPermissionHTTPException, + ]), ) async def create_provider( request: Request, @@ -66,10 +72,12 @@ async def create_provider( result = await create_provider_use_case.execute(command) except Exception as e: logger.exception( - "Unexpected error while executing create_router use case", + "Unexpected error while executing create_provider use case", extra={ "user_id": request_context.get().user_id, - "router_name": body.name, + "provider_router_id": body.router, + "provider_url": body.url, + "provider_model_name": body.model_name, "error_type": type(e).__name__, }, ) @@ -78,20 +86,21 @@ async def create_provider( match result: case CreateProviderUseCaseSuccess(created_provider): return CreateProviderResponse.model_validate(created_provider, from_attributes=True) - case InvalidProviderTypeError(provider_type, router_type): - raise InvalidProviderTypeHTTPException(provider_type, router_type) - case ProviderNotReachableError(name): - raise ProviderNotReachableHTTPException(name) - case ProviderAlreadyExistsError(model_name, url, router_id): - raise ProviderAlreadyExistsHTTPException(model_name, url, router_id) - case InconsistentModelMaxContextLengthError(expected_max_context_length, actual_max_context_length, router_name): - raise InconsistentModelMaxContextLengthHTTPException( - input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name - ) - case InconsistentModelVectorSizeError(expected_vector_size, actual_vector_size, router_name): - raise InconsistentModelVectorSizeHTTPException(actual_vector_size, expected_vector_size, router_name) - case RouterNotFoundError(router_id): - raise RouterNotFoundHTTPException(router_id) + + case InconsistentModelMaxContextLengthError(expected_max_context_length=expected_max_context_length, actual_max_context_length=actual_max_context_length, router_name=router_name): # fmt: off + raise InconsistentModelMaxContextLengthHTTPException(input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name) # fmt: off + case InconsistentModelVectorSizeError(expected_vector_size=expected_vector_size, actual_vector_size=actual_vector_size, router_name=router_name): # fmt: off + raise InconsistentModelVectorSizeHTTPException(input_vector_size=actual_vector_size, model_vector_size=expected_vector_size, model_name=router_name) # fmt: off + case InvalidProviderTypeError(provider_type=provider_type, router_type=router_type): + raise InvalidProviderTypeHTTPException(incorrect_provider_type=provider_type, router_type=router_type) + case ProviderNotReachableError(model_name=name): + raise ProviderNotReachableHTTPException(name=name) + case ProviderAlreadyExistsError(model_name=model_name, url=url, router_id=router_id): + raise ProviderAlreadyExistsHTTPException(model_name=model_name, url=url, router_id=router_id) + case RouterNotFoundError(router_id=router_id): + raise RouterNotFoundHTTPException(router_id=router_id) + case InsufficientPermissionError(): + raise InsufficientPermissionHTTPException() @router.delete( diff --git a/api/infrastructure/fastapi/endpoints/admin_router.py b/api/infrastructure/fastapi/endpoints/admin_router.py index f1b226bb9..7e33c90d6 100644 --- a/api/infrastructure/fastapi/endpoints/admin_router.py +++ b/api/infrastructure/fastapi/endpoints/admin_router.py @@ -14,14 +14,24 @@ RouterAlreadyExistsHTTPException, ) from api.infrastructure.fastapi.schemas.routers import CreateRouter, CreateRouterResponse -from api.use_cases.admin import CreateRouterUseCase, CreateRouterUseCaseSuccess +from api.infrastructure.fastapi.utils import get_documentation_responses +from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess from api.utils.variables import EndpointRoute, RouterName logger = logging.getLogger(__name__) router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) -@router.post(path=EndpointRoute.ADMIN_ROUTERS, dependencies=[Security(dependency=get_current_key)], status_code=201) +@router.post( + path=EndpointRoute.ADMIN_ROUTERS, + dependencies=[Security(dependency=get_current_key)], + status_code=201, + responses=get_documentation_responses([ + RouterAliasAlreadyExistsHTTPException, + RouterAlreadyExistsHTTPException, + InsufficientPermissionHTTPException, + ]), +) async def create_router( body: CreateRouter = Body(description="The router creation request."), create_router_use_case: CreateRouterUseCase = Depends(create_router_use_case), @@ -31,7 +41,7 @@ async def create_router( Create a router (without any providers). """ try: - result = await create_router_use_case.execute( + command = CreateRouterCommand( user_id=request_context.get().user_id, name=body.name, router_type=body.type, @@ -40,6 +50,7 @@ async def create_router( cost_prompt_tokens=body.cost_prompt_tokens, cost_completion_tokens=body.cost_completion_tokens, ) + result = await create_router_use_case.execute(command) except Exception as e: logger.exception( "Unexpected error while executing create_router use case", diff --git a/api/infrastructure/fastapi/endpoints/exceptions.py b/api/infrastructure/fastapi/endpoints/exceptions.py index 67863b899..59da1715e 100644 --- a/api/infrastructure/fastapi/endpoints/exceptions.py +++ b/api/infrastructure/fastapi/endpoints/exceptions.py @@ -3,67 +3,99 @@ # 400 class InvalidProviderTypeHTTPException(HTTPException): + status_code = 400 + detail = "Invalid model provider type {input_type} for {expected_type} router." + def __init__(self, incorrect_provider_type: str, router_type: str) -> None: - super().__init__(status_code=400, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router.") + super().__init__(status_code=self.status_code, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router.") # 401 +class InvalidAPIKeyException(HTTPException): + status_code = 401 + detail = "Invalid API key." + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) -# 403 -class InvalidAuthenticationSchemeException(HTTPException): - def __init__(self, detail: str = "Invalid authentication scheme.") -> None: - super().__init__(status_code=403, detail=detail) +class InvalidAuthenticationSchemeException(HTTPException): + status_code = 401 + detail = "Invalid authentication scheme." -class InvalidAPIKeyException(HTTPException): - def __init__(self, detail: str = "Invalid API key.") -> None: - super().__init__(status_code=403, detail=detail) + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) +# 403 class InsufficientPermissionHTTPException(HTTPException): - def __init__(self, detail: str = "Insufficient rights.") -> None: - super().__init__(status_code=403, detail=detail) + status_code = 403 + detail = "Insufficient rights." + + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) class InconsistentModelMaxContextLengthHTTPException(HTTPException): + status_code = 403 + detail = "Inconsistent max context length for {model_name}. Expected: {expected_length}. Actual: {actual_length}" + def __init__(self, input_max_context_length: int, model_max_context_length: int, model_name: str) -> None: super().__init__( - status_code=403, + status_code=self.status_code, detail=f"Inconsistent max context length for {model_name}. Expected: {model_max_context_length}. Actual: {input_max_context_length}", ) class InconsistentModelVectorSizeHTTPException(HTTPException): + status_code = 403 + detail = "Inconsistent vector size for {model_name}. Expected: {expected_size}. Actual: {actual_size}" + def __init__(self, input_vector_size: int, model_vector_size: int, model_name: str) -> None: super().__init__( - status_code=403, detail=f"Inconsistent vector size for {model_name}. Expected: {model_vector_size}. Actual: {input_vector_size}" + status_code=self.status_code, + detail=f"Inconsistent vector size for {model_name}. Expected: {model_vector_size}. Actual: {input_vector_size}", ) # 404 class ModelNotFoundHTTPException(HTTPException): - def __init__(self, detail: str = "Model not found.") -> None: - super().__init__(status_code=404, detail=detail) + status_code = 404 + detail = "Model not found." + + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) class RouterNotFoundHTTPException(HTTPException): + status_code = 404 + detail = "Model router {router_id} not found." + def __init__(self, router_id: int) -> None: - super().__init__(status_code=404, detail=f"Model router {router_id} not found.") + super().__init__(status_code=self.status_code, detail=f"Model router {router_id} not found.") # 409 class RouterAliasAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Following aliases already exist: '{router_aliases}'" + def __init__(self, aliases: list[str]): - super().__init__(status_code=409, detail=f"Following aliases already exist: '{aliases}'") + super().__init__(status_code=self.status_code, detail=f"Following aliases already exist: '{aliases}'") class RouterAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Router {router_name} already exists." + def __init__(self, name: str): - super().__init__(status_code=409, detail=f"Router '{name}' already exists.") + super().__init__(status_code=self.status_code, detail=f"Router {name} already exists.") class ProviderAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Model provider {model_name} for url {url} already exists for router {router_id}." + def __init__(self, model_name: str, url: str, router_id: int) -> None: super().__init__(status_code=409, detail=f"Model provider {model_name} for url {url} already exists for router {router_id}.") @@ -76,8 +108,11 @@ def __init__(self, model_name: str, url: str, router_id: int) -> None: # 424 class ProviderNotReachableHTTPException(HTTPException): + status_code = 424 + detail = "Model provider {provider_name} not reachable." + def __init__(self, name: str) -> None: - super().__init__(status_code=424, detail=f"Model provider {name} not reachable.") + super().__init__(status_code=self.status_code, detail=f"Model provider {name} not reachable.") # 429 @@ -85,10 +120,11 @@ def __init__(self, name: str) -> None: # 500 class InternalServerHTTPException(HTTPException): - """Exception for unexpected internal errors.""" + status_code = 500 + detail = "An unexpected error occurred" - def __init__(self, detail: str = "An unexpected error occurred"): - super().__init__(status_code=500, detail=detail) + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) # 503 diff --git a/api/infrastructure/fastapi/endpoints/models.py b/api/infrastructure/fastapi/endpoints/models.py index db2118b8e..6a9324d17 100644 --- a/api/infrastructure/fastapi/endpoints/models.py +++ b/api/infrastructure/fastapi/endpoints/models.py @@ -5,10 +5,9 @@ from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.endpoints.exceptions import ModelNotFoundHTTPException from api.infrastructure.fastapi.schemas.models import Model, Models -from api.schemas.exception import HTTPExceptionModel +from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.models import GetModelsUseCase from api.use_cases.models._getmodelsusecase import ModelNotFound, Success -from api.utils.exceptions import ModelNotFoundException from api.utils.variables import EndpointRoute, RouterName router = APIRouter(prefix="/v1", tags=[RouterName.MODELS.title()]) @@ -19,7 +18,7 @@ dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Model, - responses={ModelNotFoundException().status_code: {"model": HTTPExceptionModel, "description": {ModelNotFoundException().detail}}}, + responses=get_documentation_responses([ModelNotFoundHTTPException]), ) async def get_model( request: Request, @@ -45,7 +44,7 @@ async def get_model( dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Models, - responses={ModelNotFoundException().status_code: {"model": HTTPExceptionModel, "description": {ModelNotFoundException().detail}}}, + responses=get_documentation_responses([ModelNotFoundHTTPException]), ) async def get_models( request: Request, diff --git a/api/infrastructure/fastapi/utils.py b/api/infrastructure/fastapi/utils.py new file mode 100644 index 000000000..148ecfe92 --- /dev/null +++ b/api/infrastructure/fastapi/utils.py @@ -0,0 +1,28 @@ +from fastapi import HTTPException +from pydantic import BaseModel + +from api.infrastructure.fastapi.endpoints.exceptions import ( + InvalidAPIKeyException, + InvalidAuthenticationSchemeException, +) + + +class HTTPExceptionModel(BaseModel): + status_code: int + detail: str + headers: dict[str, str] | None = None + + +def get_documentation_responses(exceptions: list[HTTPException]): + """ + Generate a dictionary of responses for a list of HTTP exceptions in Redoc and Swagger documentation. + """ + exceptions.extend([InvalidAuthenticationSchemeException, InvalidAPIKeyException]) + responses = {} + for exception in exceptions: + if exception.status_code not in responses: + responses[exception.status_code] = {"model": HTTPExceptionModel, "description": exception.detail} + else: + responses[exception.status_code]["description"] += f"
{exception.detail}" + + return responses diff --git a/api/tests/integration/endpoints/test_admin_router.py b/api/tests/integration/endpoints/test_admin_router.py index 4a356be37..8a71045e6 100644 --- a/api/tests/integration/endpoints/test_admin_router.py +++ b/api/tests/integration/endpoints/test_admin_router.py @@ -2,12 +2,9 @@ import pytest from api.domain.router.entities import RouterLoadBalancingStrategy -from api.schemas.models import ModelType +from api.schemas.models import ModelType as RouterType from api.tests.helpers import create_token -from api.tests.integration.factories import ( - RouterSQLFactory, - UserSQLFactory, -) +from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory from api.utils.variables import EndpointRoute @@ -75,7 +72,7 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s RouterSQLFactory( user=admin_user, name=duplicate_name, - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, ) await db_session.flush() @@ -83,7 +80,7 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s router_data = { "name": duplicate_name, - "type": ModelType.TEXT_GENERATION, + "type": RouterType.TEXT_GENERATION, "aliases": [], "load_balancing_strategy": "shuffle", "cost_prompt_tokens": 0.001, @@ -99,20 +96,20 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s # Assert assert response.status_code in [400, 409], f"Expected 400 or 409, got {response.status_code}" - assert response.json().get("detail") == f"Router '{duplicate_name}' already exists." + assert response.json().get("detail") == f"Router {duplicate_name} already exists." async def test_create_router_with_duplicate_alias(self, client: AsyncClient, db_session): # Arrange admin_user = UserSQLFactory(admin_user=True) duplicate_alias = "duplicate-alias" - RouterSQLFactory(user=admin_user, name="existing-router", type=ModelType.TEXT_GENERATION, alias=[duplicate_alias]) + RouterSQLFactory(user=admin_user, name="existing-router", type=RouterType.TEXT_GENERATION, alias=[duplicate_alias]) await db_session.flush() token = await create_token(db_session, name="admin_token", user=admin_user) router_data = { "name": "new-router", - "type": ModelType.TEXT_GENERATION.value, + "type": RouterType.TEXT_GENERATION.value, "aliases": [duplicate_alias], "load_balancing_strategy": RouterLoadBalancingStrategy.SHUFFLE.value, "cost_prompt_tokens": 0.001, diff --git a/api/use_cases/admin/__init__.py b/api/use_cases/admin/__init__.py index e9e497d35..7f55f81f9 100644 --- a/api/use_cases/admin/__init__.py +++ b/api/use_cases/admin/__init__.py @@ -1,6 +1,3 @@ -from ._createrouterusecase import CreateRouterUseCase, CreateRouterUseCaseSuccess +from ._createrouterusecase import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess -__all__ = [ - "CreateRouterUseCase", - "CreateRouterUseCaseSuccess", -] +__all__ = ["CreateRouterCommand", "CreateRouterUseCase", "CreateRouterUseCaseSuccess"] diff --git a/api/use_cases/admin/_createrouterusecase.py b/api/use_cases/admin/_createrouterusecase.py index 0f979d7da..0afabba5d 100644 --- a/api/use_cases/admin/_createrouterusecase.py +++ b/api/use_cases/admin/_createrouterusecase.py @@ -8,6 +8,17 @@ from api.domain.userinfo.errors import InsufficientPermissionError +@dataclass +class CreateRouterCommand: + user_id: int + name: str + router_type: RouterType + aliases: list[str] + load_balancing_strategy: RouterLoadBalancingStrategy + cost_prompt_tokens: float + cost_completion_tokens: float + + @dataclass class CreateRouterUseCaseSuccess: router: Router @@ -25,27 +36,21 @@ def __init__(self, router_repository: RouterRepository, user_info_repository: Us async def execute( self, - user_id: int, - name: str, - router_type: RouterType, - aliases: list[str], - load_balancing_strategy: RouterLoadBalancingStrategy, - cost_prompt_tokens: float, - cost_completion_tokens: float, + command: CreateRouterCommand, ) -> CreateRouterUseCaseResult: - user_info = await self.user_info_repository.get_user_info(user_id=user_id) + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) if not user_info.is_admin: return InsufficientPermissionError() result = await self.router_repository.create_router( - name=name, - router_type=router_type, - load_balancing_strategy=load_balancing_strategy, - cost_prompt_tokens=cost_prompt_tokens, - cost_completion_tokens=cost_completion_tokens, - user_id=user_id, - aliases=aliases, + name=command.name, + router_type=command.router_type, + load_balancing_strategy=command.load_balancing_strategy, + cost_prompt_tokens=command.cost_prompt_tokens, + cost_completion_tokens=command.cost_completion_tokens, + user_id=command.user_id, + aliases=command.aliases, ) match result: diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index 276cb723a..d8d32dbb3 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -6,6 +6,8 @@ from api.domain.provider.errors import ProviderAlreadyExistsError from api.domain.router import RouterRepository from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo import UserInfoRepository +from api.domain.userinfo.errors import InsufficientPermissionError from api.schemas.core.models import Metric @@ -38,6 +40,7 @@ class CreateProviderUseCaseSuccess: | InconsistentModelVectorSizeError | RouterNotFoundError | ProviderAlreadyExistsError + | InsufficientPermissionError ) @@ -47,12 +50,19 @@ def __init__( router_repository: RouterRepository, provider_repository: ProviderRepository, provider_gateway: ProviderGateway, + user_info_repository: UserInfoRepository, ): self.router_repository = router_repository self.provider_repository = provider_repository self.provider_gateway = provider_gateway + self.user_info_repository = user_info_repository async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCaseResult: + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) + + if not user_info.is_admin: + return InsufficientPermissionError() + router = await self.router_repository.get_router_by_id(router_id=command.router_id) if router is None: return RouterNotFoundError(command.router_id) From 6cec09c37da50336a5d5ec7c327a0bff8007d25e Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 12:11:28 +0100 Subject: [PATCH 10/13] fix(provider): move admin_router to admin folder --- api/app.py | 7 +++---- api/infrastructure/fastapi/endpoints/admin/__init__.py | 7 +++++++ api/infrastructure/fastapi/endpoints/admin/providers.py | 6 +++--- .../endpoints/{admin_router.py => admin/routers.py} | 6 +++--- api/utils/variables.py | 2 +- 5 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 api/infrastructure/fastapi/endpoints/admin/__init__.py rename api/infrastructure/fastapi/endpoints/{admin_router.py => admin/routers.py} (94%) diff --git a/api/app.py b/api/app.py index c776ccff6..c290e4769 100644 --- a/api/app.py +++ b/api/app.py @@ -74,14 +74,13 @@ def _register_routers(app: FastAPI, configuration: Configuration) -> None: include_in_schema = enabled_router not in hidden_routers app.include_router(router=router, include_in_schema=include_in_schema) - # @TODO: legacy import before total clean archi migration - # @TODO: create admin folder in infrastructure.fastapi.endpoints with router declaration in __init__.py + # @TODO: legacy import, remove after total clean archi migration if RouterName.ADMIN not in disabled_routers: - module = import_module("api.infrastructure.fastapi.endpoints.admin_router") + module = import_module("api.endpoints.admin.routers") app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) if RouterName.ADMIN not in disabled_routers: - module = import_module("api.infrastructure.fastapi.endpoints.admin.providers") + module = import_module("api.endpoints.admin.providers") app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) diff --git a/api/infrastructure/fastapi/endpoints/admin/__init__.py b/api/infrastructure/fastapi/endpoints/admin/__init__.py new file mode 100644 index 000000000..c74ae4a85 --- /dev/null +++ b/api/infrastructure/fastapi/endpoints/admin/__init__.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +from api.utils.variables import RouterName + +router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) + +from . import providers, routers # noqa: F401 E402 diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index 1502eaecd..adc9d0099 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -1,7 +1,7 @@ import logging from typing import Literal -from fastapi import APIRouter, Body, Depends, Path, Query, Request, Security +from fastapi import Body, Depends, Path, Query, Request, Security from fastapi.responses import JSONResponse, Response from sqlalchemy.ext.asyncio import AsyncSession @@ -14,6 +14,7 @@ from api.helpers.models import ModelRegistry from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.endpoints.admin import router from api.infrastructure.fastapi.endpoints.exceptions import ( InconsistentModelMaxContextLengthHTTPException, InconsistentModelVectorSizeHTTPException, @@ -28,10 +29,9 @@ from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from api.utils.dependencies import get_model_registry, get_postgres_session -from api.utils.variables import EndpointRoute, RouterName +from api.utils.variables import EndpointRoute logger = logging.getLogger(__name__) -router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) @router.post( diff --git a/api/infrastructure/fastapi/endpoints/admin_router.py b/api/infrastructure/fastapi/endpoints/admin/routers.py similarity index 94% rename from api/infrastructure/fastapi/endpoints/admin_router.py rename to api/infrastructure/fastapi/endpoints/admin/routers.py index 7e33c90d6..77ac51840 100644 --- a/api/infrastructure/fastapi/endpoints/admin_router.py +++ b/api/infrastructure/fastapi/endpoints/admin/routers.py @@ -1,12 +1,13 @@ import logging -from fastapi import APIRouter, Body, Depends, Security +from fastapi import Body, Depends, Security from api.dependencies import create_router_use_case, get_request_context from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo.errors import InsufficientPermissionError from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.endpoints.admin import router from api.infrastructure.fastapi.endpoints.exceptions import ( InsufficientPermissionHTTPException, InternalServerHTTPException, @@ -16,10 +17,9 @@ from api.infrastructure.fastapi.schemas.routers import CreateRouter, CreateRouterResponse from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess -from api.utils.variables import EndpointRoute, RouterName +from api.utils.variables import EndpointRoute logger = logging.getLogger(__name__) -router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) @router.post( diff --git a/api/utils/variables.py b/api/utils/variables.py index 81617f416..e58482f6e 100644 --- a/api/utils/variables.py +++ b/api/utils/variables.py @@ -11,7 +11,7 @@ class RouterName(StrEnum): - ADMIN = ("admin", "api.endpoints.admin") + ADMIN = ("admin", "api.infrastructure.fastapi.endpoints.admin") AUDIO = ("audio", "api.endpoints.audio") AUTH = ("auth", "api.endpoints.auth") CHAT = ("chat", "api.endpoints.chat") From 1e9b511ba7b65360ad3b5ccb2061903ce47b830e Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 12:11:34 +0100 Subject: [PATCH 11/13] fix(provider): move admin_router to admin folder --- api/tests/integ/test_swagger.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 api/tests/integ/test_swagger.py diff --git a/api/tests/integ/test_swagger.py b/api/tests/integ/test_swagger.py deleted file mode 100644 index 0857b6d98..000000000 --- a/api/tests/integ/test_swagger.py +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi.testclient import TestClient -import pytest - -from api.utils.configuration import configuration - - -@pytest.mark.usefixtures("client") -class TestSwagger: - def test_swagger(self, client: TestClient): - """Test the GET /swagger response status code.""" - response = client.get_without_permissions(url=configuration.settings.swagger_docs_url) - assert response.status_code == 200, response.text - - response = client.get_without_permissions(url=configuration.settings.swagger_openapi_url) - assert response.status_code == 200, response.text From d19a1cca4bf9059af72ffc70a1aa5549be017c86 Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 16:11:03 +0100 Subject: [PATCH 12/13] fix: unit tests --- .../fastapi/{utils.py => documentation.py} | 0 .../fastapi/endpoints/admin/providers.py | 2 +- .../fastapi/endpoints/admin/routers.py | 2 +- .../fastapi/endpoints/models.py | 2 +- .../endpoints/test_admin_providers.py | 18 +- .../postgres/test_postgresrouterrepository.py | 33 ++-- api/tests/unit/use_case/factories.py | 7 +- .../use_case/test_createproviderusecase.py | 34 +++- .../unit/use_case/test_createrouterusecase.py | 170 ++++++++---------- .../unit/use_case/test_getmodelsusecase.py | 2 +- .../admin/providers/_createproviderusecase.py | 1 + 11 files changed, 143 insertions(+), 128 deletions(-) rename api/infrastructure/fastapi/{utils.py => documentation.py} (100%) diff --git a/api/infrastructure/fastapi/utils.py b/api/infrastructure/fastapi/documentation.py similarity index 100% rename from api/infrastructure/fastapi/utils.py rename to api/infrastructure/fastapi/documentation.py diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index adc9d0099..d402b2ae6 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -14,6 +14,7 @@ from api.helpers.models import ModelRegistry from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.documentation import get_documentation_responses from api.infrastructure.fastapi.endpoints.admin import router from api.infrastructure.fastapi.endpoints.exceptions import ( InconsistentModelMaxContextLengthHTTPException, @@ -26,7 +27,6 @@ RouterNotFoundHTTPException, ) from api.infrastructure.fastapi.schemas.providers import CreateProvider, CreateProviderResponse, Provider, Providers, UpdateProvider -from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute diff --git a/api/infrastructure/fastapi/endpoints/admin/routers.py b/api/infrastructure/fastapi/endpoints/admin/routers.py index 77ac51840..02eb47f57 100644 --- a/api/infrastructure/fastapi/endpoints/admin/routers.py +++ b/api/infrastructure/fastapi/endpoints/admin/routers.py @@ -7,6 +7,7 @@ from api.domain.userinfo.errors import InsufficientPermissionError from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.documentation import get_documentation_responses from api.infrastructure.fastapi.endpoints.admin import router from api.infrastructure.fastapi.endpoints.exceptions import ( InsufficientPermissionHTTPException, @@ -15,7 +16,6 @@ RouterAlreadyExistsHTTPException, ) from api.infrastructure.fastapi.schemas.routers import CreateRouter, CreateRouterResponse -from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess from api.utils.variables import EndpointRoute diff --git a/api/infrastructure/fastapi/endpoints/models.py b/api/infrastructure/fastapi/endpoints/models.py index 6a9324d17..9839ac65a 100644 --- a/api/infrastructure/fastapi/endpoints/models.py +++ b/api/infrastructure/fastapi/endpoints/models.py @@ -3,9 +3,9 @@ from api.dependencies import get_models_use_case from api.infrastructure.fastapi.access import get_current_key +from api.infrastructure.fastapi.documentation import get_documentation_responses from api.infrastructure.fastapi.endpoints.exceptions import ModelNotFoundHTTPException from api.infrastructure.fastapi.schemas.models import Model, Models -from api.infrastructure.fastapi.utils import get_documentation_responses from api.use_cases.models import GetModelsUseCase from api.use_cases.models._getmodelsusecase import ModelNotFound, Success from api.utils.variables import EndpointRoute, RouterName diff --git a/api/tests/integration/endpoints/test_admin_providers.py b/api/tests/integration/endpoints/test_admin_providers.py index d5039f32c..691d4b3ff 100644 --- a/api/tests/integration/endpoints/test_admin_providers.py +++ b/api/tests/integration/endpoints/test_admin_providers.py @@ -10,6 +10,7 @@ from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo.errors import InsufficientPermissionError from api.schemas.models import ModelType from api.tests.helpers import create_token from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory @@ -83,13 +84,21 @@ async def test_happy_path(self, client: AsyncClient, db_session): @pytest.mark.parametrize( "use_case_result,expected_status,expected_detail", [ - (RouterNotFoundError(router_id=1), 404, "Model router 1 not found."), + ( + RouterNotFoundError(router_id=1), + 404, + "Model router 1 not found.", + ), ( InvalidProviderTypeError(provider_type="tei", router_type="text-generation"), 400, "Invalid model provider type tei for text-generation router.", ), - (ProviderNotReachableError(model_name="my-model"), 424, "Model provider my-model not reachable."), + ( + ProviderNotReachableError(model_name="my-model"), + 424, + "Model provider my-model not reachable.", + ), ( ProviderAlreadyExistsError(model_name="my-model", url=DEFAULT_PROVIDER_URL, router_id=1), 409, @@ -105,6 +114,11 @@ async def test_happy_path(self, client: AsyncClient, db_session): 403, "Inconsistent vector size for my-router. Expected: 768. Actual: 384", ), + ( + InsufficientPermissionError(), + 403, + "Insufficient rights.", + ), ], ) async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail): diff --git a/api/tests/integration/postgres/test_postgresrouterrepository.py b/api/tests/integration/postgres/test_postgresrouterrepository.py index a9ace92b7..c77c7d532 100644 --- a/api/tests/integration/postgres/test_postgresrouterrepository.py +++ b/api/tests/integration/postgres/test_postgresrouterrepository.py @@ -1,14 +1,11 @@ import pytest from api.domain.key.entities import MASTER_USER_ID -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.infrastructure.postgres import PostgresRouterRepository -from api.tests.integration.factories import ( - OrganizationSQLFactory, - RouterSQLFactory, - UserSQLFactory, -) +from api.tests.integration.factories import OrganizationSQLFactory, RouterSQLFactory, UserSQLFactory @pytest.fixture @@ -29,13 +26,13 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se user_2 = UserSQLFactory() router_1 = RouterSQLFactory( - user=user_1, name="router_1", type=ModelType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 + user=user_1, name="router_1", type=RouterType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 ) router_2 = RouterSQLFactory( - user=user_1, name="router_2", type=ModelType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 + user=user_1, name="router_2", type=RouterType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 ) router_3 = RouterSQLFactory( - user=user_2, name="router_3", type=ModelType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 + user=user_2, name="router_3", type=RouterType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 ) # Act @@ -48,7 +45,7 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se assert router_names == {router_1.name, router_2.name, router_3.name} result_router_1 = result_routers[0] - assert result_router_1.type == ModelType.TEXT_GENERATION + assert result_router_1.type == RouterType.TEXT_GENERATION assert result_router_1.providers == 2 assert result_router_1.cost_prompt_tokens == 0.001 assert result_router_1.cost_completion_tokens == 0.002 @@ -58,7 +55,7 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se async def test_get_all_routers_should_return_routers_with_master_id_user(self, repository, db_session): # Arrange RouterSQLFactory( - user=None, name="router_1", type=ModelType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 + user=None, name="router_1", type=RouterType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 ) # Act @@ -112,7 +109,7 @@ async def test_create_router_should_return_created_router_without_alias(self, re # Act result = await repository.create_router( name="test-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, @@ -122,7 +119,7 @@ async def test_create_router_should_return_created_router_without_alias(self, re # Assert assert isinstance(result, Router) assert result.name == "test-router" - assert result.type == ModelType.TEXT_GENERATION + assert result.type == RouterType.TEXT_GENERATION assert result.load_balancing_strategy == RouterLoadBalancingStrategy.SHUFFLE assert result.cost_prompt_tokens == 0.001 assert result.cost_completion_tokens == 0.002 @@ -142,7 +139,7 @@ async def test_create_router_should_return_router_name_already_exists_when_name_ # Act result = await repository.create_router( name="duplicate-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -160,7 +157,7 @@ async def test_create_router_with_master_user_id_should_set_db_user_id_to_null(s # Act result = await repository.create_router( name="master-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -180,7 +177,7 @@ async def test_create_router_with_aliases_should_insert_aliases(self, repository # Act result = await repository.create_router( name="router-with-aliases", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -207,7 +204,7 @@ async def test_create_router_should_return_router_alias_already_exists_when_one_ # Act result = await repository.create_router( name="router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, aliases=[duplicate_alias], load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, @@ -229,7 +226,7 @@ async def test_create_router_should_return_router_alias_already_exists_when_seve # Act result = await repository.create_router( name="router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, aliases=duplicate_aliases, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, diff --git a/api/tests/unit/use_case/factories.py b/api/tests/unit/use_case/factories.py index c1203f06d..f247f071f 100644 --- a/api/tests/unit/use_case/factories.py +++ b/api/tests/unit/use_case/factories.py @@ -4,9 +4,10 @@ import factory from factory import fuzzy +from api.domain.model import ModelType as RouterType from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType from api.domain.role.entities import Limit, LimitType, PermissionType, Role -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.user.entities import User from api.domain.userinfo.entities import UserInfo @@ -43,7 +44,7 @@ class Meta: id = factory.Sequence(lambda n: n + 1) name = factory.Faker("bothify", text="router_####") user_id = factory.Faker("random_int", min=1, max=1000) - type = factory.Faker("random_element", elements=list(ModelType)) + type = factory.Faker("random_element", elements=list(RouterType)) aliases = None load_balancing_strategy = factory.Faker("random_element", elements=list(RouterLoadBalancingStrategy)) vector_size = None @@ -63,7 +64,7 @@ class Params: ) embedding = factory.Trait( - type=ModelType.TEXT_EMBEDDINGS_INFERENCE, + type=RouterType.TEXT_EMBEDDINGS_INFERENCE, vector_size=factory.Faker("random_element", elements=[384, 768, 1536, 3072]), max_context_length=factory.Faker("random_element", elements=[512, 1024, 2048, 8192]), ) diff --git a/api/tests/unit/use_case/test_createproviderusecase.py b/api/tests/unit/use_case/test_createproviderusecase.py index 8d84eb0b5..dc0e484e7 100644 --- a/api/tests/unit/use_case/test_createproviderusecase.py +++ b/api/tests/unit/use_case/test_createproviderusecase.py @@ -2,13 +2,14 @@ import pytest +from api.domain.model import ModelType as RouterType from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError from api.domain.provider import ProviderCapabilities from api.domain.provider.entities import ProviderCarbonFootprintZone, ProviderType from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError -from api.domain.router.entities import ModelType from api.domain.router.errors import RouterNotFoundError -from api.tests.unit.use_case.factories import ProviderFactory, RouterFactory +from api.domain.userinfo.errors import InsufficientPermissionError +from api.tests.unit.use_case.factories import ProviderFactory, RouterFactory, UserInfoFactory from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess @@ -18,6 +19,7 @@ def use_case(): router_repository=AsyncMock(), provider_repository=AsyncMock(), provider_gateway=AsyncMock(), + user_info_repository=AsyncMock(), ) @@ -26,7 +28,7 @@ def sample_router(): return RouterFactory( id=1, name="test-router", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, providers=0, ) @@ -36,7 +38,7 @@ def sample_router_with_providers(): return RouterFactory( id=1, name="test-router", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, providers=2, max_context_length=4096, vector_size=None, @@ -48,7 +50,7 @@ def sample_embedding_router_with_providers(): return RouterFactory( id=1, name="embedding-router", - type=ModelType.TEXT_EMBEDDINGS_INFERENCE, + type=RouterType.TEXT_EMBEDDINGS_INFERENCE, providers=1, max_context_length=512, vector_size=768, @@ -118,7 +120,12 @@ async def test_should_create_provider_when_router_exists_without_any_provider(se assert result.provider == sample_provider use_case.router_repository.get_router_by_id.assert_called_once_with(router_id=1) use_case.provider_gateway.get_capabilities.assert_called_once_with( - provider_type=ProviderType.VLLM, url="https://example.com/", key=None, timeout=30, model_name="my-model" + router_type=RouterType.TEXT_GENERATION, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", ) use_case.provider_repository.create_provider.assert_called_once_with( router_id=1, @@ -218,7 +225,7 @@ async def test_should_return_router_not_found_error_when_router_does_not_exist(s @pytest.mark.asyncio async def test_should_return_invalid_provider_type_error_when_type_not_compatible(self, use_case, default_command): # Arrange - router = RouterFactory(id=1, name="tei-router", type=ModelType.TEXT_CLASSIFICATION) + router = RouterFactory(id=1, name="tei-router", type=RouterType.TEXT_CLASSIFICATION) use_case.router_repository.get_router_by_id.return_value = router # Act @@ -227,7 +234,7 @@ async def test_should_return_invalid_provider_type_error_when_type_not_compatibl # Assert assert isinstance(result, InvalidProviderTypeError) assert result.provider_type == ProviderType.VLLM.value - assert result.router_type == ModelType.TEXT_CLASSIFICATION + assert result.router_type == RouterType.TEXT_CLASSIFICATION use_case.provider_gateway.get_capabilities.assert_not_called() use_case.provider_repository.create_provider.assert_not_called() @@ -294,3 +301,14 @@ async def test_should_return_provider_already_exists_error(self, use_case, sampl assert result.model_name == "my-model" assert result.url == "https://example.com/" assert result.router_id == 1 + + @pytest.mark.asyncio + async def test_should_return_insufficient_permission_error_when_user_not_admin(self, use_case, default_command): + # Arrange + use_case.user_info_repository.get_user_info.return_value = UserInfoFactory(id=1, without_permission=True, limits=[]) + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, InsufficientPermissionError) diff --git a/api/tests/unit/use_case/test_createrouterusecase.py b/api/tests/unit/use_case/test_createrouterusecase.py index bf95c3cde..9ba32d21e 100644 --- a/api/tests/unit/use_case/test_createrouterusecase.py +++ b/api/tests/unit/use_case/test_createrouterusecase.py @@ -2,25 +2,17 @@ import pytest -from api.domain.router.entities import ModelType, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo.errors import InsufficientPermissionError from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory -from api.use_cases.admin import ( - CreateRouterUseCase, -) +from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase @pytest.fixture -def router_repository(): - repo = AsyncMock() - return repo - - -@pytest.fixture -def user_info_repository(): - repo = AsyncMock() - return repo +def use_case(): + return CreateRouterUseCase(router_repository=AsyncMock(), user_info_repository=AsyncMock()) @pytest.fixture @@ -38,7 +30,7 @@ def sample_router_with_aliases(): return RouterFactory( id=1, name="test-model", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, aliases=["alias1", "alias2"], user_id=1, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, @@ -48,41 +40,33 @@ def sample_router_with_aliases(): ) -@pytest.fixture -def use_case(router_repository, user_info_repository): - return CreateRouterUseCase( - router_repository=router_repository, - user_info_repository=user_info_repository, - ) - - class TestCreateRouterUseCase: @pytest.mark.asyncio - async def test_should_create_router_with_aliases_when_aliases_are_given( - self, router_repository, user_info_repository, admin_user_info, sample_router_with_aliases, use_case - ): + async def test_should_create_router_with_aliases_when_aliases_are_given(self, use_case, admin_user_info, sample_router_with_aliases): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = sample_router_with_aliases + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = sample_router_with_aliases # Act result = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.01, - cost_completion_tokens=0.02, + command=CreateRouterCommand( + user_id=1, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.01, + cost_completion_tokens=0.02, + ) ) # Assert assert result.router == sample_router_with_aliases - user_info_repository.get_user_info.assert_called_once_with(user_id=1) - router_repository.create_router.assert_called_once_with( + use_case.user_info_repository.get_user_info.assert_called_once_with(user_id=1) + use_case.router_repository.create_router.assert_called_once_with( name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.01, cost_completion_tokens=0.02, @@ -91,36 +75,36 @@ async def test_should_create_router_with_aliases_when_aliases_are_given( ) @pytest.mark.asyncio - async def test_should_create_router_without_aliases_if_no_alias_is_given( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_create_router_without_aliases_if_no_alias_is_given(self, use_case, admin_user_info): # Arrange router_without_aliases = RouterFactory( id=2, name="model-no-alias", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, aliases=[], user_id=1, ) - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = router_without_aliases + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = router_without_aliases # Act result = await use_case.execute( - user_id=1, - name="model-no-alias", - router_type=ModelType.TEXT_GENERATION, - aliases=[], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=1, + name="model-no-alias", + router_type=RouterType.TEXT_GENERATION, + aliases=[], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert result.router == router_without_aliases - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( name="model-no-alias", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -129,85 +113,85 @@ async def test_should_create_router_without_aliases_if_no_alias_is_given( ) @pytest.mark.asyncio - async def test_should_return_insufficient_permission_error_if_user_not_admin( - self, router_repository, user_info_repository, non_admin_user_info, use_case - ): + async def test_should_return_insufficient_permission_error_if_user_not_admin(self, use_case, non_admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = non_admin_user_info + use_case.user_info_repository.get_user_info.return_value = non_admin_user_info # Act error = await use_case.execute( - user_id=2, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=[], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=2, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=[], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, InsufficientPermissionError) - router_repository.create_router.assert_not_called() + use_case.router_repository.create_router.assert_not_called() @pytest.mark.asyncio - async def test_should_return_router_alias_already_exists_when_alias_already_exists( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_return_router_alias_already_exists_when_alias_already_exists(self, use_case, admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = RouterAliasAlreadyExistsError(aliases=["alias1"]) + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = RouterAliasAlreadyExistsError(aliases=["alias1"]) # Act error = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=1, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, RouterAliasAlreadyExistsError) - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( aliases=["alias1", "alias2"], cost_completion_tokens=0.0, cost_prompt_tokens=0.0, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, user_id=1, ) @pytest.mark.asyncio - async def test_should_return_router_name_already_exists_when_name_already_exists( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_return_router_name_already_exists_when_name_already_exists(self, use_case, admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = RouterNameAlreadyExistsError(name="existing-router") + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = RouterNameAlreadyExistsError(name="existing-router") # Act error = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=admin_user_info.id, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, RouterNameAlreadyExistsError) - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( + user_id=admin_user_info.id, name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, - user_id=admin_user_info.id, aliases=["alias1", "alias2"], ) diff --git a/api/tests/unit/use_case/test_getmodelsusecase.py b/api/tests/unit/use_case/test_getmodelsusecase.py index a46458382..7dccc4e2f 100644 --- a/api/tests/unit/use_case/test_getmodelsusecase.py +++ b/api/tests/unit/use_case/test_getmodelsusecase.py @@ -3,8 +3,8 @@ import pytest +from api.domain.model import ModelType from api.domain.role.entities import LimitType -from api.domain.router.entities import ModelType from api.domain.userinfo.entities import Limit from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory from api.use_cases.models import GetModelsUseCase diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py index d8d32dbb3..e4068430a 100644 --- a/api/use_cases/admin/providers/_createproviderusecase.py +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -78,6 +78,7 @@ async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCase timeout=command.timeout, model_name=command.model_name, ) + # @ TODO: separate health check logic from get_capabilities match result: case ProviderNotReachableError() as error: From e944694481591327008272847f94fec863628d70 Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Mon, 23 Feb 2026 15:12:02 +0000 Subject: [PATCH 13/13] Update unit coverage badge --- .github/badges/coverage.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/badges/coverage.json b/.github/badges/coverage.json index b9757f0da..3e7151f5a 100644 --- a/.github/badges/coverage.json +++ b/.github/badges/coverage.json @@ -1 +1 @@ -{"schemaVersion":1,"label":"coverage","message":"50.68%","color":"red"} +{"schemaVersion":1,"label":"coverage","message":"49.87%","color":"red"}