diff --git a/.github/badges/coverage.json b/.github/badges/coverage.json index f96bab692..3e7151f5a 100644 --- a/.github/badges/coverage.json +++ b/.github/badges/coverage.json @@ -1 +1 @@ -{"schemaVersion":1,"label":"coverage","message":"50.85%","color":"red"} +{"schemaVersion":1,"label":"coverage","message":"49.87%","color":"red"} diff --git a/.gitignore b/.gitignore index 285b58bde..6474fa8d1 100644 --- a/.gitignore +++ b/.gitignore @@ -222,5 +222,4 @@ playground/.gitignore playground/requirements.txt run.sh .claude - bruno \ No newline at end of file diff --git a/api/app.py b/api/app.py index 0d2bcb9d9..c290e4769 100644 --- a/api/app.py +++ b/api/app.py @@ -74,10 +74,13 @@ def _register_routers(app: FastAPI, configuration: Configuration) -> None: include_in_schema = enabled_router not in hidden_routers app.include_router(router=router, include_in_schema=include_in_schema) - # @TODO: legacy import before total clean archi migration - # @TODO: create admin folder in infrastructure.fastapi.endpoints with router declaration in __init__.py + # @TODO: legacy import, remove after total clean archi migration if RouterName.ADMIN not in disabled_routers: - module = import_module("api.infrastructure.fastapi.endpoints.admin_router") + module = import_module("api.endpoints.admin.routers") + app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) + + if RouterName.ADMIN not in disabled_routers: + module = import_module("api.endpoints.admin.providers") app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers) diff --git a/api/clients/model/_albertmodelprovider.py b/api/clients/model/_albertmodelprovider.py index 3df6dbed4..27824082d 100644 --- a/api/clients/model/_albertmodelprovider.py +++ b/api/clients/model/_albertmodelprovider.py @@ -54,6 +54,7 @@ async def get_max_context_length(self) -> int | None: response = await client.get(url=url, headers=self.headers, timeout=self.timeout) response.raise_for_status() except Exception as e: + # TODO: remove exc_info=True and return error instead of exception logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True) raise AssertionError(f"Model is not reachable ({e}).") diff --git a/api/clients/model/_basemodelprovider.py b/api/clients/model/_basemodelprovider.py index 9f43e0a78..3978c6fed 100644 --- a/api/clients/model/_basemodelprovider.py +++ b/api/clients/model/_basemodelprovider.py @@ -76,7 +76,7 @@ def import_module(type: ProviderType) -> "type[BaseModelProvider]": return getattr(module, f"{type.capitalize()}ModelProvider") @staticmethod - async def get_max_context_length(self) -> int | None: + async def get_max_context_length() -> int | None: """ Get the max context length of the model provider to store in the database. Useful to check provider consistency. diff --git a/api/clients/model/_mistralmodelprovider.py b/api/clients/model/_mistralmodelprovider.py index 9d436791a..c7c688f58 100644 --- a/api/clients/model/_mistralmodelprovider.py +++ b/api/clients/model/_mistralmodelprovider.py @@ -53,6 +53,7 @@ async def get_max_context_length(self) -> int | None: async with httpx.AsyncClient() as client: response = await client.get(url=url, headers=self.headers, timeout=self.timeout) response.raise_for_status() + except Exception as e: logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True) raise AssertionError(f"Model is not reachable ({e}).") diff --git a/api/dependencies.py b/api/dependencies.py index d1877a9a1..5cc2bafb9 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -5,22 +5,17 @@ from sqlalchemy.ext.asyncio import AsyncSession from api.domain.key import KeyRepository -from api.infrastructure.postgres import PostgresKeyRepository, PostgresRouterRepository, PostgresUserInfoRepository +from api.infrastructure.model import ModelProviderGateway +from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository from api.schemas.core.context import RequestContext from api.use_cases.admin import CreateRouterUseCase +from api.use_cases.admin.providers import CreateProviderUseCase from api.use_cases.models import GetModelsUseCase from api.utils.configuration import configuration from api.utils.context import global_context, request_context async def get_postgres_session() -> AsyncGenerator[AsyncSession]: - """ - Get a PostgreSQL postgres_session from the global context. - - Returns: - AsyncSession: A PostgreSQL postgres_session instance. - """ - session_factory = global_context.postgres_session_factory async with session_factory() as postgres_session: try: @@ -34,13 +29,6 @@ async def get_postgres_session() -> AsyncGenerator[AsyncSession]: def get_request_context() -> ContextVar[RequestContext]: - """ - Get the RequestContext ContextVar from the global context. - - Returns: - ContextVar[RequestContext]: The RequestContext ContextVar instance. - """ - return request_context @@ -55,6 +43,17 @@ def get_models_use_case( ) +def create_provider_use_case_factory( + postgres_session: AsyncSession = Depends(get_postgres_session), +) -> CreateProviderUseCase: + return CreateProviderUseCase( + router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), + provider_repository=PostgresProviderRepository(postgres_session=postgres_session), + provider_gateway=ModelProviderGateway(), + user_info_repository=PostgresUserInfoRepository(postgres_session=postgres_session), + ) + + def create_router_use_case(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateRouterUseCase: return CreateRouterUseCase( router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title), diff --git a/api/domain/model/__init__.py b/api/domain/model/__init__.py new file mode 100644 index 000000000..21c9905e0 --- /dev/null +++ b/api/domain/model/__init__.py @@ -0,0 +1,4 @@ +from .entities import Model, ModelCosts, ModelType +from .errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError + +__all__ = ["ModelType", "Model", "ModelCosts", "InconsistentModelMaxContextLengthError", "InconsistentModelVectorSizeError"] diff --git a/api/domain/model/entities.py b/api/domain/model/entities.py new file mode 100644 index 000000000..e2f017359 --- /dev/null +++ b/api/domain/model/entities.py @@ -0,0 +1,34 @@ +from enum import Enum + +from pydantic import BaseModel, Field + + +class Metric(str, Enum): + TTFT = "ttft" # time to first token + LATENCY = "latency" # requests latency + INFLIGHT = "inflight" # requests concurrency + PERFORMANCE = "performance" # custom performance metric + + +class ModelCosts(BaseModel): + prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)") + completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)") + + +class ModelType(str, Enum): + AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" + IMAGE_TEXT_TO_TEXT = "image-text-to-text" + IMAGE_TO_TEXT = "image-to-text" + TEXT_EMBEDDINGS_INFERENCE = "text-embeddings-inference" + TEXT_GENERATION = "text-generation" + TEXT_CLASSIFICATION = "text-classification" + + +class Model(BaseModel): + id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.") + type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off + aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off + created: int = Field(..., description="Time of creation, as Unix timestamp.") + owned_by: str = Field(..., description="The organization that owns the model.") + max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off + costs: ModelCosts = Field(..., description="Costs of the model.") diff --git a/api/domain/model/errors.py b/api/domain/model/errors.py new file mode 100644 index 000000000..a8d37c349 --- /dev/null +++ b/api/domain/model/errors.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass +class InconsistentModelVectorSizeError: + expected_vector_size: int + actual_vector_size: int + router_name: str + + +@dataclass +class InconsistentModelMaxContextLengthError: + expected_max_context_length: int + actual_max_context_length: int + router_name: str diff --git a/api/domain/provider/__init__.py b/api/domain/provider/__init__.py new file mode 100644 index 000000000..7ba949199 --- /dev/null +++ b/api/domain/provider/__init__.py @@ -0,0 +1,17 @@ +from api.domain.provider._providergateway import ProviderCapabilities, ProviderGateway +from api.domain.provider._providerrepository import ProviderRepository + +from .entities import Provider, ProviderCarbonFootprintZone, ProviderType +from .errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError + +__all__ = [ + "InvalidProviderTypeError", + "ProviderNotReachableError", + "ProviderRepository", + "ProviderGateway", + "ProviderCapabilities", + "ProviderType", + "Provider", + "ProviderCarbonFootprintZone", + "ProviderAlreadyExistsError", +] diff --git a/api/domain/provider/_providergateway.py b/api/domain/provider/_providergateway.py new file mode 100644 index 000000000..7f26df361 --- /dev/null +++ b/api/domain/provider/_providergateway.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from api.domain.model import ModelType as RouterType +from api.domain.provider.entities import ProviderType +from api.domain.provider.errors import ProviderNotReachableError + + +@dataclass +class ProviderCapabilities: + max_context_length: int | None + vector_size: int | None + + +class ProviderGateway(ABC): + @abstractmethod + async def get_capabilities( + self, + router_type: RouterType, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + ) -> ProviderCapabilities | ProviderNotReachableError: + pass diff --git a/api/domain/provider/_providerrepository.py b/api/domain/provider/_providerrepository.py new file mode 100644 index 000000000..cf5d715cf --- /dev/null +++ b/api/domain/provider/_providerrepository.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod + +from api.domain.model.entities import Metric +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError + + +class ProviderRepository(ABC): + @abstractmethod + async def create_provider( + self, + router_id: int, + user_id: int, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + model_hosting_zone: ProviderCarbonFootprintZone, + model_total_params: int, + model_active_params: int, + qos_metric: Metric | None, + qos_limit: float | None, + vector_size: int, + max_context_length: int, + ) -> Provider | ProviderAlreadyExistsError: + pass diff --git a/api/domain/provider/entities.py b/api/domain/provider/entities.py new file mode 100644 index 000000000..9ffcbe3c5 --- /dev/null +++ b/api/domain/provider/entities.py @@ -0,0 +1,76 @@ +from enum import Enum +from typing import Literal + +import pycountry +from pydantic import Field, constr + +from api.domain.model.entities import ModelType +from api.schemas import BaseModel +from api.schemas.core.models import Metric + +# Add world as a country code, default value of the carbon footprint computation framework +country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"] +country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))} +ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str) + + +class ProviderType(str, Enum): + ALBERT = "albert" + OPENAI = "openai" + MISTRAL = "mistral" + TEI = "tei" + VLLM = "vllm" + + +COMPATIBLE_PROVIDER_TYPES: dict[ModelType, list[str]] = { + ModelType.AUTOMATIC_SPEECH_RECOGNITION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.IMAGE_TEXT_TO_TEXT: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_EMBEDDINGS_INFERENCE: [ + ProviderType.ALBERT.value, + ProviderType.OPENAI.value, + ProviderType.TEI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_GENERATION: [ + ProviderType.ALBERT.value, + ProviderType.MISTRAL.value, + ProviderType.OPENAI.value, + ProviderType.VLLM.value, + ], + ModelType.TEXT_CLASSIFICATION: [ + ProviderType.ALBERT.value, + ProviderType.TEI.value, + ], + ModelType.IMAGE_TO_TEXT: [ + ProviderType.MISTRAL.value, + ], +} + + +class Provider(BaseModel): + object: Literal["provider"] = "provider" + id: int = Field(..., description="Provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="Provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="Provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off diff --git a/api/domain/provider/errors.py b/api/domain/provider/errors.py new file mode 100644 index 000000000..64cad8275 --- /dev/null +++ b/api/domain/provider/errors.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + + +@dataclass +class InvalidProviderTypeError: + provider_type: str + router_type: str + + +@dataclass +class ProviderNotReachableError: + model_name: str + + +@dataclass +class ProviderAlreadyExistsError: + model_name: str + url: str + router_id: int diff --git a/api/domain/router/_routerrepository.py b/api/domain/router/_routerrepository.py index f89225fc8..58ec5eedb 100644 --- a/api/domain/router/_routerrepository.py +++ b/api/domain/router/_routerrepository.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError @@ -13,11 +14,19 @@ async def get_organization_name(self, user_id) -> str: async def get_all_routers(self) -> list[Router]: pass + @abstractmethod + async def get_router_by_id(self, router_id: int) -> Router | None: + pass + + @abstractmethod + async def get_aliases_by_router_id(self, router_id: int) -> Router | None: + pass + @abstractmethod async def create_router( self, name: str, - router_type: ModelType, + router_type: RouterType, load_balancing_strategy: RouterLoadBalancingStrategy, cost_prompt_tokens: float, cost_completion_tokens: float, diff --git a/api/domain/router/entities.py b/api/domain/router/entities.py index 57fbbac9c..49a2d47dd 100644 --- a/api/domain/router/entities.py +++ b/api/domain/router/entities.py @@ -2,29 +2,7 @@ from pydantic import BaseModel, Field - -class ModelCosts(BaseModel): - prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)") - completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)") - - -class ModelType(str, Enum): - AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" - IMAGE_TEXT_TO_TEXT = "image-text-to-text" - IMAGE_TO_TEXT = "image-to-text" - TEXT_EMBEDDINGS_INFERENCE = "text-embeddings-inference" - TEXT_GENERATION = "text-generation" - TEXT_CLASSIFICATION = "text-classification" - - -class Model(BaseModel): - id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.") - type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off - aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off - created: int = Field(..., description="Time of creation, as Unix timestamp.") - owned_by: str = Field(..., description="The organization that owns the model.") - max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off - costs: ModelCosts = Field(..., description="Costs of the model.") +from api.domain.model import ModelType as RouterType class RouterLoadBalancingStrategy(str, Enum): @@ -36,7 +14,7 @@ class Router(BaseModel): id: int = Field(..., description="ID of the router.") # fmt: off name: str = Field(..., description="Name of the router.") # fmt: off user_id: int = Field(..., description="ID of the user that owns the router.") # fmt: off - type: ModelType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off + type: RouterType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off load_balancing_strategy: RouterLoadBalancingStrategy = Field(..., description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off vector_size: int | None = Field(default=None, description="Dimension of the vectors, if the models are embeddings. Make sure it is the same for all models.") # fmt: off diff --git a/api/domain/router/errors.py b/api/domain/router/errors.py index b031e3073..6ff5a988a 100644 --- a/api/domain/router/errors.py +++ b/api/domain/router/errors.py @@ -9,3 +9,8 @@ class RouterAliasAlreadyExistsError: @dataclass class RouterNameAlreadyExistsError: name: str + + +@dataclass +class RouterNotFoundError: + router_id: int diff --git a/api/endpoints/admin/providers.py b/api/endpoints/admin/providers.py index b5bd4d0c1..0a774c646 100644 --- a/api/endpoints/admin/providers.py +++ b/api/endpoints/admin/providers.py @@ -7,51 +7,12 @@ from api.endpoints.admin import router from api.helpers._accesscontroller import AccessController from api.helpers.models import ModelRegistry -from api.schemas.admin.providers import ( - CreateProvider, - CreateProviderResponse, - Provider, - Providers, - UpdateProvider, -) +from api.schemas.admin.providers import Provider, Providers, UpdateProvider from api.schemas.admin.roles import PermissionType -from api.utils.context import request_context from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute -@router.post( - path=EndpointRoute.ADMIN_PROVIDERS, - dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))], - status_code=201, -) -async def create_provider( - request: Request, - body: CreateProvider, - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> CreateProviderResponse: - """ - Create a model provider. - """ - provider_id = await model_registry.create_provider( - router_id=body.router, - user_id=request_context.get().user_info.id, - type=body.type, - url=body.url, - key=body.key, - timeout=body.timeout, - model_name=body.model_name, - model_hosting_zone=body.model_hosting_zone, - model_total_params=body.model_total_params, - model_active_params=body.model_active_params, - qos_metric=body.qos_metric, - qos_limit=body.qos_limit, - postgres_session=postgres_session, - ) - return JSONResponse(status_code=201, content=CreateProviderResponse(id=provider_id).model_dump()) - - @router.delete( path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))], diff --git a/api/helpers/models/_modelregistry.py b/api/helpers/models/_modelregistry.py index fbb268ab2..b8b10fae3 100644 --- a/api/helpers/models/_modelregistry.py +++ b/api/helpers/models/_modelregistry.py @@ -150,16 +150,16 @@ async def setup(self, models: list[ModelConfiguration], postgres_session: AsyncS postgres_session=postgres_session, ) except ProviderAlreadyExistsException: - logger.warning(f"Provider {provider.model_name} already exists for router {model.name} (skipping)") + logger.warning(f"provider {provider.model_name} already exists for router {model.name} (skipping)") continue except ProviderNotReachableException: - logger.warning(f"Provider {provider.model_name} is not reachable for router {model.name} (skipping)") + logger.warning(f"provider {provider.model_name} is not reachable for router {model.name} (skipping)") continue except Exception as e: await postgres_session.rollback() - logger.error(f"Provider {provider.model_name} failed to be created for router {model.name} ({e})") + logger.error(f"provider {provider.model_name} failed to be created for router {model.name} ({e})") raise e - logging.info(f"Provider {provider.model_name} successfully created for router {model.name} (id: {provider_id})") + logging.info(f"provider {provider.model_name} successfully created for router {model.name} (id: {provider_id})") if self.queuing_enabled: routers = await self.get_routers(router_id=None, name=None, postgres_session=postgres_session) @@ -459,9 +459,9 @@ async def create_provider( Args: router_id(int): The model router ID user_id(int): The user ID of owner of the provider - type(ProviderType): Provider type - url(str): Provider URL - key(str | None): Provider API key + type(ProviderType): provider type + url(str): provider URL + key(str | None): provider API key timeout(int): Request timeout model_name(str): Model name model_hosting_zone(ProviderCarbonFootprintZone): ProviderCarbonFootprintZone @@ -498,7 +498,7 @@ async def create_provider( vector_size = None except AssertionError as e: - logger.debug(f"Provider {provider.model_name} not reachable: {e}", exc_info=True) + logger.debug(f"provider {provider.model_name} not reachable: {e}", exc_info=True) raise ProviderNotReachableException() # consistency check @@ -724,7 +724,7 @@ async def update_provider( await postgres_session.commit() except IntegrityError: await postgres_session.rollback() - raise ProviderAlreadyExistsException("Provider already exists for the new router.") + raise ProviderAlreadyExistsException("provider already exists for the new router.") async def get_models(self, name: str | None, user_info: UserInfo, postgres_session: AsyncSession) -> list[Model]: """ diff --git a/api/infrastructure/fastapi/documentation.py b/api/infrastructure/fastapi/documentation.py new file mode 100644 index 000000000..148ecfe92 --- /dev/null +++ b/api/infrastructure/fastapi/documentation.py @@ -0,0 +1,28 @@ +from fastapi import HTTPException +from pydantic import BaseModel + +from api.infrastructure.fastapi.endpoints.exceptions import ( + InvalidAPIKeyException, + InvalidAuthenticationSchemeException, +) + + +class HTTPExceptionModel(BaseModel): + status_code: int + detail: str + headers: dict[str, str] | None = None + + +def get_documentation_responses(exceptions: list[HTTPException]): + """ + Generate a dictionary of responses for a list of HTTP exceptions in Redoc and Swagger documentation. + """ + exceptions.extend([InvalidAuthenticationSchemeException, InvalidAPIKeyException]) + responses = {} + for exception in exceptions: + if exception.status_code not in responses: + responses[exception.status_code] = {"model": HTTPExceptionModel, "description": exception.detail} + else: + responses[exception.status_code]["description"] += f"
{exception.detail}" + + return responses diff --git a/api/infrastructure/fastapi/endpoints/admin/__init__.py b/api/infrastructure/fastapi/endpoints/admin/__init__.py new file mode 100644 index 000000000..c74ae4a85 --- /dev/null +++ b/api/infrastructure/fastapi/endpoints/admin/__init__.py @@ -0,0 +1,7 @@ +from fastapi import APIRouter + +from api.utils.variables import RouterName + +router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) + +from . import providers, routers # noqa: F401 E402 diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py new file mode 100644 index 000000000..d402b2ae6 --- /dev/null +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -0,0 +1,193 @@ +import logging +from typing import Literal + +from fastapi import Body, Depends, Path, Query, Request, Security +from fastapi.responses import JSONResponse, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from api.dependencies import create_provider_use_case_factory, get_request_context +from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider import InvalidProviderTypeError, ProviderNotReachableError +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo.errors import InsufficientPermissionError +from api.helpers.models import ModelRegistry +from api.infrastructure.fastapi.access import get_current_key +from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.documentation import get_documentation_responses +from api.infrastructure.fastapi.endpoints.admin import router +from api.infrastructure.fastapi.endpoints.exceptions import ( + InconsistentModelMaxContextLengthHTTPException, + InconsistentModelVectorSizeHTTPException, + InsufficientPermissionHTTPException, + InternalServerHTTPException, + InvalidProviderTypeHTTPException, + ProviderAlreadyExistsHTTPException, + ProviderNotReachableHTTPException, + RouterNotFoundHTTPException, +) +from api.infrastructure.fastapi.schemas.providers import CreateProvider, CreateProviderResponse, Provider, Providers, UpdateProvider +from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess +from api.utils.dependencies import get_model_registry, get_postgres_session +from api.utils.variables import EndpointRoute + +logger = logging.getLogger(__name__) + + +@router.post( + path=EndpointRoute.ADMIN_PROVIDERS, + dependencies=[Security(dependency=get_current_key)], + status_code=201, + responses=get_documentation_responses([ + InconsistentModelMaxContextLengthHTTPException, + InconsistentModelVectorSizeHTTPException, + InvalidProviderTypeHTTPException, + ProviderNotReachableHTTPException, + ProviderAlreadyExistsHTTPException, + RouterNotFoundHTTPException, + InsufficientPermissionHTTPException, + ]), +) +async def create_provider( + request: Request, + body: CreateProvider, + create_provider_use_case: CreateProviderUseCase = Depends(create_provider_use_case_factory), + request_context: RequestContext = Depends(get_request_context), +) -> CreateProviderResponse: + try: + command = CreateProviderCommand( + router_id=body.router, + user_id=request_context.get().user_id, + provider_type=body.type, + url=body.url, + key=body.key, + timeout=body.timeout, + model_name=body.model_name, + model_hosting_zone=body.model_hosting_zone, + model_total_params=body.model_total_params, + model_active_params=body.model_active_params, + qos_metric=body.qos_metric, + qos_limit=body.qos_limit, + ) + result = await create_provider_use_case.execute(command) + except Exception as e: + logger.exception( + "Unexpected error while executing create_provider use case", + extra={ + "user_id": request_context.get().user_id, + "provider_router_id": body.router, + "provider_url": body.url, + "provider_model_name": body.model_name, + "error_type": type(e).__name__, + }, + ) + raise InternalServerHTTPException() + + match result: + case CreateProviderUseCaseSuccess(created_provider): + return CreateProviderResponse.model_validate(created_provider, from_attributes=True) + + case InconsistentModelMaxContextLengthError(expected_max_context_length=expected_max_context_length, actual_max_context_length=actual_max_context_length, router_name=router_name): # fmt: off + raise InconsistentModelMaxContextLengthHTTPException(input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name) # fmt: off + case InconsistentModelVectorSizeError(expected_vector_size=expected_vector_size, actual_vector_size=actual_vector_size, router_name=router_name): # fmt: off + raise InconsistentModelVectorSizeHTTPException(input_vector_size=actual_vector_size, model_vector_size=expected_vector_size, model_name=router_name) # fmt: off + case InvalidProviderTypeError(provider_type=provider_type, router_type=router_type): + raise InvalidProviderTypeHTTPException(incorrect_provider_type=provider_type, router_type=router_type) + case ProviderNotReachableError(model_name=name): + raise ProviderNotReachableHTTPException(name=name) + case ProviderAlreadyExistsError(model_name=model_name, url=url, router_id=router_id): + raise ProviderAlreadyExistsHTTPException(model_name=model_name, url=url, router_id=router_id) + case RouterNotFoundError(router_id=router_id): + raise RouterNotFoundHTTPException(router_id=router_id) + case InsufficientPermissionError(): + raise InsufficientPermissionHTTPException() + + +@router.delete( + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=204, +) +async def delete_provider( + request: Request, + provider: int = Path(description="The ID of the provider to delete."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> Response: + await model_registry.delete_provider(provider_id=provider, postgres_session=postgres_session) + + return Response(status_code=204) + + +@router.patch( + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=204, +) +async def update_provider( + request: Request, + provider: int = Path(description="The ID of the provider to update."), + body: UpdateProvider = Body(description="The provider update request."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> Response: + await model_registry.update_provider( + provider_id=provider, + router_id=body.router, + timeout=body.timeout, + model_hosting_zone=body.model_hosting_zone, + model_total_params=body.model_total_params, + model_active_params=body.model_active_params, + qos_metric=body.qos_metric, + qos_limit=body.qos_limit, + postgres_session=postgres_session, + ) + + return Response(status_code=204) + + +@router.get( + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", + dependencies=[Security(dependency=get_current_key)], + status_code=200, + response_model=Provider, +) +async def get_provider( + request: Request, + provider: int = Path(description="The ID of the provider to get."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> JSONResponse: + providers = await model_registry.get_providers(router_id=None, provider_id=provider, postgres_session=postgres_session) + provider = providers[0] + + return JSONResponse(status_code=200, content=provider.model_dump()) + + +@router.get( + path=EndpointRoute.ADMIN_PROVIDERS, + dependencies=[Security(dependency=get_current_key)], + status_code=200, + response_model=Providers, +) +async def get_providers( + request: Request, + router: int | None = Query(default=None, description="Filter providers by router ID."), + offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."), + limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."), + order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."), + order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."), + postgres_session: AsyncSession = Depends(get_postgres_session), + model_registry: ModelRegistry = Depends(get_model_registry), +) -> JSONResponse: + providers = await model_registry.get_providers( + router_id=router, + provider_id=None, + postgres_session=postgres_session, + offset=offset, + limit=limit, + order_by=order_by, + order_direction=order_direction, + ) + + return JSONResponse(status_code=200, content=Providers(data=providers).model_dump()) diff --git a/api/infrastructure/fastapi/endpoints/admin_router.py b/api/infrastructure/fastapi/endpoints/admin/routers.py similarity index 75% rename from api/infrastructure/fastapi/endpoints/admin_router.py rename to api/infrastructure/fastapi/endpoints/admin/routers.py index f1b226bb9..02eb47f57 100644 --- a/api/infrastructure/fastapi/endpoints/admin_router.py +++ b/api/infrastructure/fastapi/endpoints/admin/routers.py @@ -1,12 +1,14 @@ import logging -from fastapi import APIRouter, Body, Depends, Security +from fastapi import Body, Depends, Security from api.dependencies import create_router_use_case, get_request_context from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo.errors import InsufficientPermissionError from api.infrastructure.fastapi.access import get_current_key from api.infrastructure.fastapi.context import RequestContext +from api.infrastructure.fastapi.documentation import get_documentation_responses +from api.infrastructure.fastapi.endpoints.admin import router from api.infrastructure.fastapi.endpoints.exceptions import ( InsufficientPermissionHTTPException, InternalServerHTTPException, @@ -14,14 +16,22 @@ RouterAlreadyExistsHTTPException, ) from api.infrastructure.fastapi.schemas.routers import CreateRouter, CreateRouterResponse -from api.use_cases.admin import CreateRouterUseCase, CreateRouterUseCaseSuccess -from api.utils.variables import EndpointRoute, RouterName +from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess +from api.utils.variables import EndpointRoute logger = logging.getLogger(__name__) -router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()]) -@router.post(path=EndpointRoute.ADMIN_ROUTERS, dependencies=[Security(dependency=get_current_key)], status_code=201) +@router.post( + path=EndpointRoute.ADMIN_ROUTERS, + dependencies=[Security(dependency=get_current_key)], + status_code=201, + responses=get_documentation_responses([ + RouterAliasAlreadyExistsHTTPException, + RouterAlreadyExistsHTTPException, + InsufficientPermissionHTTPException, + ]), +) async def create_router( body: CreateRouter = Body(description="The router creation request."), create_router_use_case: CreateRouterUseCase = Depends(create_router_use_case), @@ -31,7 +41,7 @@ async def create_router( Create a router (without any providers). """ try: - result = await create_router_use_case.execute( + command = CreateRouterCommand( user_id=request_context.get().user_id, name=body.name, router_type=body.type, @@ -40,6 +50,7 @@ async def create_router( cost_prompt_tokens=body.cost_prompt_tokens, cost_completion_tokens=body.cost_completion_tokens, ) + result = await create_router_use_case.execute(command) except Exception as e: logger.exception( "Unexpected error while executing create_router use case", diff --git a/api/infrastructure/fastapi/endpoints/exceptions.py b/api/infrastructure/fastapi/endpoints/exceptions.py index 7036b3e8e..59da1715e 100644 --- a/api/infrastructure/fastapi/endpoints/exceptions.py +++ b/api/infrastructure/fastapi/endpoints/exceptions.py @@ -1,42 +1,103 @@ from fastapi import HTTPException + # 400 +class InvalidProviderTypeHTTPException(HTTPException): + status_code = 400 + detail = "Invalid model provider type {input_type} for {expected_type} router." + + def __init__(self, incorrect_provider_type: str, router_type: str) -> None: + super().__init__(status_code=self.status_code, detail=f"Invalid model provider type {incorrect_provider_type} for {router_type} router.") # 401 +class InvalidAPIKeyException(HTTPException): + status_code = 401 + detail = "Invalid API key." + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) -# 403 -class InvalidAuthenticationSchemeException(HTTPException): - def __init__(self, detail: str = "Invalid authentication scheme.") -> None: - super().__init__(status_code=403, detail=detail) +class InvalidAuthenticationSchemeException(HTTPException): + status_code = 401 + detail = "Invalid authentication scheme." -class InvalidAPIKeyException(HTTPException): - def __init__(self, detail: str = "Invalid API key.") -> None: - super().__init__(status_code=403, detail=detail) + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) +# 403 class InsufficientPermissionHTTPException(HTTPException): - def __init__(self, detail: str = "Insufficient rights.") -> None: - super().__init__(status_code=403, detail=detail) + status_code = 403 + detail = "Insufficient rights." + + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) + + +class InconsistentModelMaxContextLengthHTTPException(HTTPException): + status_code = 403 + detail = "Inconsistent max context length for {model_name}. Expected: {expected_length}. Actual: {actual_length}" + + def __init__(self, input_max_context_length: int, model_max_context_length: int, model_name: str) -> None: + super().__init__( + status_code=self.status_code, + detail=f"Inconsistent max context length for {model_name}. Expected: {model_max_context_length}. Actual: {input_max_context_length}", + ) + + +class InconsistentModelVectorSizeHTTPException(HTTPException): + status_code = 403 + detail = "Inconsistent vector size for {model_name}. Expected: {expected_size}. Actual: {actual_size}" + + def __init__(self, input_vector_size: int, model_vector_size: int, model_name: str) -> None: + super().__init__( + status_code=self.status_code, + detail=f"Inconsistent vector size for {model_name}. Expected: {model_vector_size}. Actual: {input_vector_size}", + ) # 404 class ModelNotFoundHTTPException(HTTPException): - def __init__(self, detail: str = "Model not found.") -> None: - super().__init__(status_code=404, detail=detail) + status_code = 404 + detail = "Model not found." + + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) + + +class RouterNotFoundHTTPException(HTTPException): + status_code = 404 + detail = "Model router {router_id} not found." + + def __init__(self, router_id: int) -> None: + super().__init__(status_code=self.status_code, detail=f"Model router {router_id} not found.") # 409 class RouterAliasAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Following aliases already exist: '{router_aliases}'" + def __init__(self, aliases: list[str]): - super().__init__(status_code=409, detail=f"Following aliases already exist: '{aliases}'") + super().__init__(status_code=self.status_code, detail=f"Following aliases already exist: '{aliases}'") class RouterAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Router {router_name} already exists." + def __init__(self, name: str): - super().__init__(status_code=409, detail=f"Router '{name}' already exists.") + super().__init__(status_code=self.status_code, detail=f"Router {name} already exists.") + + +class ProviderAlreadyExistsHTTPException(HTTPException): + status_code = 409 + detail = "Model provider {model_name} for url {url} already exists for router {router_id}." + + def __init__(self, model_name: str, url: str, router_id: int) -> None: + super().__init__(status_code=409, detail=f"Model provider {model_name} for url {url} already exists for router {router_id}.") # 413 @@ -46,6 +107,12 @@ def __init__(self, name: str): # 424 +class ProviderNotReachableHTTPException(HTTPException): + status_code = 424 + detail = "Model provider {provider_name} not reachable." + + def __init__(self, name: str) -> None: + super().__init__(status_code=self.status_code, detail=f"Model provider {name} not reachable.") # 429 @@ -53,10 +120,11 @@ def __init__(self, name: str): # 500 class InternalServerHTTPException(HTTPException): - """Exception for unexpected internal errors.""" + status_code = 500 + detail = "An unexpected error occurred" - def __init__(self, detail: str = "An unexpected error occurred"): - super().__init__(status_code=500, detail=detail) + def __init__(self) -> None: + super().__init__(status_code=self.status_code, detail=self.detail) # 503 diff --git a/api/infrastructure/fastapi/endpoints/models.py b/api/infrastructure/fastapi/endpoints/models.py index db2118b8e..9839ac65a 100644 --- a/api/infrastructure/fastapi/endpoints/models.py +++ b/api/infrastructure/fastapi/endpoints/models.py @@ -3,12 +3,11 @@ from api.dependencies import get_models_use_case from api.infrastructure.fastapi.access import get_current_key +from api.infrastructure.fastapi.documentation import get_documentation_responses from api.infrastructure.fastapi.endpoints.exceptions import ModelNotFoundHTTPException from api.infrastructure.fastapi.schemas.models import Model, Models -from api.schemas.exception import HTTPExceptionModel from api.use_cases.models import GetModelsUseCase from api.use_cases.models._getmodelsusecase import ModelNotFound, Success -from api.utils.exceptions import ModelNotFoundException from api.utils.variables import EndpointRoute, RouterName router = APIRouter(prefix="/v1", tags=[RouterName.MODELS.title()]) @@ -19,7 +18,7 @@ dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Model, - responses={ModelNotFoundException().status_code: {"model": HTTPExceptionModel, "description": {ModelNotFoundException().detail}}}, + responses=get_documentation_responses([ModelNotFoundHTTPException]), ) async def get_model( request: Request, @@ -45,7 +44,7 @@ async def get_model( dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Models, - responses={ModelNotFoundException().status_code: {"model": HTTPExceptionModel, "description": {ModelNotFoundException().detail}}}, + responses=get_documentation_responses([ModelNotFoundHTTPException]), ) async def get_models( request: Request, diff --git a/api/infrastructure/fastapi/schemas/models.py b/api/infrastructure/fastapi/schemas/models.py index 17b185c75..94d80496e 100644 --- a/api/infrastructure/fastapi/schemas/models.py +++ b/api/infrastructure/fastapi/schemas/models.py @@ -3,7 +3,7 @@ from pydantic import Field -from api.domain.router.entities import Model as ModelEntity +from api.domain.model import Model as ModelEntity from api.schemas import BaseModel diff --git a/api/infrastructure/fastapi/schemas/providers.py b/api/infrastructure/fastapi/schemas/providers.py new file mode 100644 index 000000000..e92d22405 --- /dev/null +++ b/api/infrastructure/fastapi/schemas/providers.py @@ -0,0 +1,115 @@ +from enum import Enum +from typing import Literal + +import pycountry +from pydantic import Field, constr, model_validator + +from api.schemas import BaseModel +from api.schemas.core.models import Metric +from api.utils.variables import DEFAULT_TIMEOUT + +# Add world as a country code, default value of the carbon footprint computation framework +country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"] +country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))} +ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str) + + +class ProviderType(str, Enum): + ALBERT = "albert" + OPENAI = "openai" + MISTRAL = "mistral" + TEI = "tei" + VLLM = "vllm" + + +class CreateProvider(BaseModel): + router: int = Field(..., description="ID of the model to create the provider for (router ID, eg. 123).") # fmt: off + type: ProviderType = Field(..., description="Model provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1) | None = Field(default=None, description="Model provider API url. The url must only contain the domain name (without `/v1` suffix for example). Depends of the model provider type, the url can be optional (Albert, OpenAI).") # fmt: off + key: constr(strip_whitespace=True, min_length=1) | None = Field(default=None, description="Model provider API key.") # fmt: off + timeout: int = Field(default=DEFAULT_TIMEOUT, description="Timeout for the model provider requests, after user receive an 503 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai") # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(default=None, description="The metric to use for the quality of service policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + + @model_validator(mode="after") + def format_provider(self): + if self.qos_metric is not None and self.qos_limit is None: + raise ValueError("QoS value is required if QoS metric is provided.") + + if self.url is None: + if self.type == ProviderType.ALBERT: + self.url = "https://albert.api.etalab.gouv.fr/" + elif self.type == ProviderType.MISTRAL: + self.url = "https://albert.api.etalab.gouv.fr/" + elif self.type == ProviderType.OPENAI: + self.url = "https://api.openai.com/" + else: + raise ValueError("URL is required for this model provider type.") + + elif not self.url.endswith("/"): + self.url = f"{self.url}/" + + return self + + +class CreateProviderResponse(BaseModel): + id: int = Field(..., description="Provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="Provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="Provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off + + +class UpdateProvider(BaseModel): + router: int | None = Field(default=None, description="The ID of the new router to assign to the provider.") # fmt: off + timeout: int | None = Field(default=None, description="Timeout for the model provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone | None = Field(default=None, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai") # fmt: off + model_total_params: int | None = Field(default=None, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int | None = Field(default=None, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(default=None, description="The metric to use for the quality of service policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + + @model_validator(mode="after") + def validate_model(self): + if self.qos_metric is not None and self.qos_limit is None: + raise ValueError("QoS value is required if QoS metric is provided.") + + return self + + +class Provider(BaseModel): + object: Literal["provider"] = "provider" + id: int = Field(..., description="provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off + + +class Providers(BaseModel): + object: Literal["list"] = "list" + data: list[Provider] diff --git a/api/infrastructure/model/__init__.py b/api/infrastructure/model/__init__.py new file mode 100644 index 000000000..6597b06c7 --- /dev/null +++ b/api/infrastructure/model/__init__.py @@ -0,0 +1,3 @@ +from api.infrastructure.model._modelprovidergateway import ModelProviderGateway + +__all__ = ["ModelProviderGateway"] diff --git a/api/infrastructure/model/_modelprovidergateway.py b/api/infrastructure/model/_modelprovidergateway.py new file mode 100644 index 000000000..4d47c1804 --- /dev/null +++ b/api/infrastructure/model/_modelprovidergateway.py @@ -0,0 +1,32 @@ +from api.clients.model import BaseModelProvider +from api.domain.model import ModelType as RouterType +from api.domain.provider import ProviderCapabilities, ProviderGateway, ProviderNotReachableError + + +class ModelProviderGateway(ProviderGateway): + async def get_capabilities(self, router_type, provider_type, url, key, timeout, model_name): + try: + client = self._build_client(provider_type, url, key, timeout, model_name) + max_context_length = await client.get_max_context_length() + if router_type == RouterType.TEXT_EMBEDDINGS_INFERENCE: + vector_size = await client.get_vector_size() + else: + vector_size = None + return ProviderCapabilities( + max_context_length=max_context_length, + vector_size=vector_size, + ) + except AssertionError as e: + return ProviderNotReachableError(model_name) + + def _build_client(self, provider_type, url, key, timeout, model_name): + cls = BaseModelProvider.import_module(type=provider_type) + return cls( + url=url, + key=key, + timeout=timeout, + model_name=model_name, + model_hosting_zone=None, + model_total_params=0, + model_active_params=0, + ) diff --git a/api/infrastructure/postgres/__init__.py b/api/infrastructure/postgres/__init__.py index 5f0b1a76b..f4f72e897 100644 --- a/api/infrastructure/postgres/__init__.py +++ b/api/infrastructure/postgres/__init__.py @@ -1,7 +1,15 @@ from ._postgreskeyrepository import PostgresKeyRepository +from ._postgresproviderrepository import PostgresProviderRepository from ._postgresrolesrepository import PostgresRolesRepository from ._postgresrouterrepository import PostgresRouterRepository from ._postgresuserinforepository import PostgresUserInfoRepository from ._postgresusersrepository import PostgresUserRepository -__all__ = ["PostgresKeyRepository", "PostgresUserInfoRepository", "PostgresUserRepository", "PostgresRolesRepository", "PostgresRouterRepository"] +__all__ = [ + "PostgresKeyRepository", + "PostgresUserInfoRepository", + "PostgresUserRepository", + "PostgresRolesRepository", + "PostgresRouterRepository", + "PostgresProviderRepository", +] diff --git a/api/infrastructure/postgres/_postgresproviderrepository.py b/api/infrastructure/postgres/_postgresproviderrepository.py new file mode 100644 index 000000000..5e59808b8 --- /dev/null +++ b/api/infrastructure/postgres/_postgresproviderrepository.py @@ -0,0 +1,77 @@ +from sqlalchemy import insert +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from api.domain.model.entities import Metric +from api.domain.provider import ProviderRepository +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.sql.models import Provider as ProviderTable + + +class PostgresProviderRepository(ProviderRepository): + def __init__(self, postgres_session: AsyncSession): + self.postgres_session = postgres_session + + async def create_provider( + self, + router_id: int, + user_id: int, + provider_type: ProviderType, + url: str, + key: str | None, + timeout: int, + model_name: str, + model_hosting_zone: ProviderCarbonFootprintZone, + model_total_params: int, + model_active_params: int, + qos_metric: Metric | None, + qos_limit: float | None, + vector_size: int, + max_context_length: int, + ) -> Provider | ProviderAlreadyExistsError: + try: + qos_metric = qos_metric.value if qos_metric is not None else None + query = ( + insert(ProviderTable) + .values( + router_id=router_id, + user_id=user_id, + type=provider_type.value, + url=url, + key=key, + timeout=timeout, + model_name=model_name, + model_hosting_zone=model_hosting_zone, + model_total_params=model_total_params, + model_active_params=model_active_params, + qos_metric=qos_metric, + qos_limit=qos_limit, + max_context_length=max_context_length, + vector_size=vector_size, + ) + .returning(ProviderTable) + ) + result = await self.postgres_session.execute(query) + row = result.scalar_one() + return Provider( + router_id=row.router_id, + user_id=row.user_id, + type=row.type, + url=row.url, + key=row.key, + timeout=row.timeout, + model_name=row.model_name, + model_hosting_zone=row.model_hosting_zone, + model_total_params=row.model_total_params, + model_active_params=row.model_active_params, + qos_metric=row.qos_metric, + qos_limit=row.qos_limit, + max_context_length=row.max_context_length, + vector_size=row.vector_size, + id=row.id, + ) + except IntegrityError as e: + if "unique_provider_router_id_url_model_name" in str(e.orig): + return ProviderAlreadyExistsError(model_name=model_name, url=url, router_id=router_id) + raise diff --git a/api/infrastructure/postgres/_postgresrouterrepository.py b/api/infrastructure/postgres/_postgresrouterrepository.py index 50d166c76..4c3ab19c3 100644 --- a/api/infrastructure/postgres/_postgresrouterrepository.py +++ b/api/infrastructure/postgres/_postgresrouterrepository.py @@ -3,8 +3,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from api.domain.key.entities import MASTER_USER_ID +from api.domain.model import ModelType as RouterType from api.domain.router import RouterRepository -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.sql.models import Organization as OrganizationTable from api.sql.models import Provider as ProviderTable @@ -14,6 +15,57 @@ class PostgresRouterRepository(RouterRepository): + async def get_aliases_by_router_id(self, router_id: int) -> list[str]: + query = select(RouterAliasTable.value).where(RouterAliasTable.router_id == router_id) + result = await self.postgres_session.execute(query) + return [row[0] for row in result.all()] + + async def get_router_by_id(self, router_id: int) -> Router | None: + provider_count_subquery = ( + select(func.count(ProviderTable.id)).where(ProviderTable.router_id == RouterTable.id).correlate(RouterTable).scalar_subquery() + ) + query = ( + select( + RouterTable.id, + RouterTable.name, + RouterTable.user_id, + RouterTable.type, + RouterTable.load_balancing_strategy, + RouterTable.cost_prompt_tokens, + RouterTable.cost_completion_tokens, + ProviderTable.max_context_length, + ProviderTable.vector_size, + provider_count_subquery.label("providers"), + cast(func.extract("epoch", RouterTable.created), Integer).label("created"), + cast(func.extract("epoch", RouterTable.updated), Integer).label("updated"), + ) + .where(RouterTable.id == router_id) + .join(ProviderTable, ProviderTable.router_id == RouterTable.id, isouter=True) + .limit(1) + ) + + result = await self.postgres_session.execute(query) + row = result.one_or_none() + if row is None: + return None + user_id = MASTER_USER_ID if row.user_id is None else row.user_id + aliases = await self.get_aliases_by_router_id(router_id) + return Router( + id=row.id, + name=row.name, + user_id=user_id, + type=RouterType(row.type), + aliases=aliases, + load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy), + vector_size=row.vector_size, + max_context_length=row.max_context_length, + cost_prompt_tokens=row.cost_prompt_tokens or 0.0, + cost_completion_tokens=row.cost_completion_tokens or 0.0, + providers=row.providers, + created=row.created, + updated=row.updated, + ) + def __init__(self, postgres_session: AsyncSession, app_title: str): self.postgres_session = postgres_session self.app_title = app_title @@ -56,7 +108,7 @@ async def get_all_routers(self) -> list[Router]: result = await self.postgres_session.execute(query) router_results = [row._asdict() for row in result.all()] - aliases = await self.get_aliases_by_router_id() + aliases = await self.get_all_aliases_grouped_by_router() for row in router_results: user_id = MASTER_USER_ID if row["user_id"] is None else row["user_id"] @@ -65,7 +117,7 @@ async def get_all_routers(self) -> list[Router]: id=row["id"], name=row["name"], user_id=user_id, - type=ModelType(row["type"]), + type=RouterType(row["type"]), aliases=aliases.get(row["id"], []), load_balancing_strategy=RouterLoadBalancingStrategy(row["load_balancing_strategy"]), vector_size=row["vector_size"], @@ -79,7 +131,7 @@ async def get_all_routers(self) -> list[Router]: ) return routers - async def get_aliases_by_router_id(self) -> dict[str, list[str]]: + async def get_all_aliases_grouped_by_router(self) -> dict[str, list[str]]: aliases_query = select(RouterAliasTable.router_id.label("router_id"), RouterAliasTable.value) aliases_result = await self.postgres_session.execute(aliases_query) aliases = {} @@ -92,7 +144,7 @@ async def get_aliases_by_router_id(self) -> dict[str, list[str]]: async def create_router( self, name: str, - router_type: ModelType, + router_type: RouterType, load_balancing_strategy: RouterLoadBalancingStrategy, cost_prompt_tokens: float, cost_completion_tokens: float, @@ -148,7 +200,7 @@ async def create_router( id=row.id, name=row.name, user_id=user_id, - type=ModelType(row.type), + type=RouterType(row.type), aliases=aliases, load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy), vector_size=None, diff --git a/api/routers/registry.py b/api/routers/registry.py index a9fd85a2e..e637a15e8 100644 --- a/api/routers/registry.py +++ b/api/routers/registry.py @@ -12,7 +12,7 @@ class RouterDefinition: ROUTER_DEFINITIONS: tuple[RouterDefinition, ...] = ( # Admin routers RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.organizations"), - RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.providers"), + RouterDefinition(name=RouterName.ADMIN, module_path="api.infrastructure.fastapi.endpoints.admin.providers"), RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.roles"), RouterDefinition(name=RouterName.ADMIN, module_path="api.endpoints.admin.routers"), RouterDefinition(name=RouterName.ADMIN, module_path="api.infrastructure.fastapi.endpoints.admin_router"), diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integ/test_swagger.py b/api/tests/integ/test_swagger.py deleted file mode 100644 index 0857b6d98..000000000 --- a/api/tests/integ/test_swagger.py +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi.testclient import TestClient -import pytest - -from api.utils.configuration import configuration - - -@pytest.mark.usefixtures("client") -class TestSwagger: - def test_swagger(self, client: TestClient): - """Test the GET /swagger response status code.""" - response = client.get_without_permissions(url=configuration.settings.swagger_docs_url) - assert response.status_code == 200, response.text - - response = client.get_without_permissions(url=configuration.settings.swagger_openapi_url) - assert response.status_code == 200, response.text diff --git a/api/tests/integration/__init__.py b/api/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 8be3a0277..41a77f56f 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -1,25 +1,29 @@ from collections.abc import AsyncGenerator -from types import SimpleNamespace import asyncpg from httpx import ASGITransport, AsyncClient import pytest import pytest_asyncio -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.pool import NullPool from api.app import create_app from api.dependencies import get_postgres_session +from api.helpers.models import ModelRegistry +from api.schemas.core.configuration import Configuration, Dependencies, Settings from api.sql.models import Base from api.tests.integration import factories +from api.utils.dependencies import get_model_registry +from api.utils.dependencies import get_postgres_session as get_postgres_session_utils TEST_DATABASE_URL = "postgresql+asyncpg://postgres:changeme@localhost:5432/test_db" @pytest.fixture def test_configuration(): - return SimpleNamespace( - settings=SimpleNamespace( + configuration = Configuration.model_construct( + settings=Settings.model_construct( app_title="test", swagger_summary=None, swagger_version="0.0.0", @@ -35,8 +39,9 @@ def test_configuration(): hidden_routers=[], monitoring_prometheus_enabled=False, ), - dependencies=SimpleNamespace(sentry=None), + dependencies=Dependencies.model_construct(sentry=None), ) + return configuration @pytest_asyncio.fixture(scope="session") @@ -63,29 +68,89 @@ async def test_engine(): await engine.dispose() -@pytest_asyncio.fixture(scope="session") -async def test_session_factory(test_engine): - return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) +def _all_sql_factories(): + result = [] + stack = list(factories.BaseSQLFactory.__subclasses__()) + while stack: + cls = stack.pop() + result.append(cls) + stack.extend(cls.__subclasses__()) + return result + + +# @pytest_asyncio.fixture(scope="session") +# async def test_session_factory(test_engine): +# return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) + +# +# @pytest_asyncio.fixture(scope="function") +# async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: +# async with test_session_factory() as session: +# all_sql_factories = factories.BaseSQLFactory.__subclasses__() +# session.expire_on_commit = False +# try: +# async with session.begin_nested(): +# for factory in all_sql_factories: +# factory._meta.sqlalchemy_session = session +# yield session +# finally: +# if session.in_transaction(): +# await session.rollback() +# await session.close() + + +def pytest_addoption(parser): + parser.addoption( + "--commit-db", + action="store_true", + default=False, + help="Commit DB changes after each test (for debugging with psql).", + ) @pytest_asyncio.fixture(scope="function") -async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]: - async with test_session_factory() as session: - all_sql_factories = factories.BaseSQLFactory.__subclasses__() - session.expire_on_commit = False +async def db_session(test_engine, request) -> AsyncGenerator[AsyncSession]: + async with test_engine.connect() as connection: + transaction = await connection.begin() + + session = AsyncSession(bind=connection, expire_on_commit=False) + await session.begin_nested() + + all_sql_factories = _all_sql_factories() + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = session + + @event.listens_for(session.sync_session, "after_transaction_end") + def restart_savepoint(sess, trans): + if trans.nested and not trans._parent.nested: + sess.begin_nested() + try: - async with session.begin_nested(): - for factory in all_sql_factories: - factory._meta.sqlalchemy_session = session - yield session + yield session finally: - if session.in_transaction(): - await session.rollback() + event.remove(session.sync_session, "after_transaction_end", restart_savepoint) + for factory in all_sql_factories: + factory._meta.sqlalchemy_session = None await session.close() + if request.config.getoption("--commit-db"): + await transaction.commit() + else: + await transaction.rollback() + + +@pytest_asyncio.fixture(scope="session") +def model_registry(): + return ModelRegistry( + app_title="test", + queuing_enabled=False, + max_priority=0, + max_retries=0, + retry_countdown=0, + ) @pytest_asyncio.fixture(scope="function") -async def client(db_session, test_configuration) -> AsyncGenerator[AsyncClient, None]: +async def app(db_session, model_registry, test_configuration): app = create_app(test_configuration, skip_lifespan=True) async def override_get_postgres_session(): @@ -99,9 +164,16 @@ async def override_get_postgres_session(): raise app.dependency_overrides[get_postgres_session] = override_get_postgres_session + app.dependency_overrides[get_postgres_session_utils] = override_get_postgres_session + app.dependency_overrides[get_model_registry] = lambda: model_registry try: - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - yield ac + yield app finally: app.dependency_overrides.clear() + + +@pytest_asyncio.fixture(scope="function") +async def client(app) -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac diff --git a/api/tests/integration/endpoints/__init__.py b/api/tests/integration/endpoints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/endpoints/test_admin_providers.py b/api/tests/integration/endpoints/test_admin_providers.py new file mode 100644 index 000000000..691d4b3ff --- /dev/null +++ b/api/tests/integration/endpoints/test_admin_providers.py @@ -0,0 +1,149 @@ +from unittest.mock import AsyncMock + +import httpx +from httpx import AsyncClient +import pytest +import pytest_asyncio +import respx + +from api.dependencies import create_provider_use_case_factory +from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError +from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo.errors import InsufficientPermissionError +from api.schemas.models import ModelType +from api.tests.helpers import create_token +from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory +from api.utils.variables import EndpointRoute + +URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" + +DEFAULT_PROVIDER_URL = "http://my-test-provider/" + + +def _valid_body(router_id: int, **overrides) -> dict: + """Return a minimal valid provider creation body, with optional overrides.""" + body = { + "router": router_id, + "type": "albert", + "model_name": "my-model", + "url": DEFAULT_PROVIDER_URL, + } + body.update(overrides) + return body + + +def _mock_provider_reachable(respx_mock, base_url=DEFAULT_PROVIDER_URL, max_context_length=4096, vector_size=768): + """Mock GET /v1/models and POST /v1/embeddings for a reachable albert provider.""" + base_url = base_url.rstrip("/") + respx_mock.get(f"{base_url}/v1/models").mock( + return_value=httpx.Response( + 200, + json={ + "data": [{"id": "my-model", "aliases": [], "max_context_length": max_context_length}], + }, + ) + ) + embedding = [0.0] * vector_size if vector_size else [] + respx_mock.post(f"{base_url}/v1/embeddings").mock( + return_value=httpx.Response( + 200, + json={ + "data": [{"embedding": embedding}], + }, + ) + ) + + +@pytest.mark.asyncio(loop_scope="session") +class TestCreateProvider: + @pytest_asyncio.fixture(autouse=True) + async def setup(self, db_session): + self.admin_user = UserSQLFactory(admin_user=True) + self.token = await create_token(db_session, name="admin_token", user=self.admin_user) + + @pytest_asyncio.fixture(autouse=True) + async def cleanup_overrides(self, app): + yield + app.dependency_overrides.pop(create_provider_use_case_factory, None) + + @respx.mock + async def test_happy_path(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + _mock_provider_reachable(respx) + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router.id), + ) + + assert response.status_code == 201, response.text + assert isinstance(response.json()["id"], int) + + @pytest.mark.parametrize( + "use_case_result,expected_status,expected_detail", + [ + ( + RouterNotFoundError(router_id=1), + 404, + "Model router 1 not found.", + ), + ( + InvalidProviderTypeError(provider_type="tei", router_type="text-generation"), + 400, + "Invalid model provider type tei for text-generation router.", + ), + ( + ProviderNotReachableError(model_name="my-model"), + 424, + "Model provider my-model not reachable.", + ), + ( + ProviderAlreadyExistsError(model_name="my-model", url=DEFAULT_PROVIDER_URL, router_id=1), + 409, + f"Model provider my-model for url {DEFAULT_PROVIDER_URL} already exists for router 1.", + ), + ( + InconsistentModelMaxContextLengthError(expected_max_context_length=4096, actual_max_context_length=2048, router_name="my-router"), + 403, + "Inconsistent max context length for my-router. Expected: 4096. Actual: 2048", + ), + ( + InconsistentModelVectorSizeError(expected_vector_size=768, actual_vector_size=384, router_name="my-router"), + 403, + "Inconsistent vector size for my-router. Expected: 768. Actual: 384", + ), + ( + InsufficientPermissionError(), + 403, + "Insufficient rights.", + ), + ], + ) + async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail): + mock_use_case = AsyncMock() + mock_use_case.execute.return_value = use_case_result + app.dependency_overrides[create_provider_use_case_factory] = lambda: mock_use_case + + response = await client.post( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + json=_valid_body(router_id=1), + ) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail + + @pytest.mark.parametrize( + "headers,expected_status,expected_detail", + [ + ({}, 401, "Not authenticated"), + ({"Authorization": "Bearer invalid-token"}, 403, "Invalid API key."), + ], + ) + async def test_auth(self, client: AsyncClient, headers, expected_status, expected_detail): + response = await client.post(url=URL, headers=headers, json=_valid_body(router_id=1)) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail diff --git a/api/tests/integration/test_admin_router.py b/api/tests/integration/endpoints/test_admin_router.py similarity index 91% rename from api/tests/integration/test_admin_router.py rename to api/tests/integration/endpoints/test_admin_router.py index 4a356be37..8a71045e6 100644 --- a/api/tests/integration/test_admin_router.py +++ b/api/tests/integration/endpoints/test_admin_router.py @@ -2,12 +2,9 @@ import pytest from api.domain.router.entities import RouterLoadBalancingStrategy -from api.schemas.models import ModelType +from api.schemas.models import ModelType as RouterType from api.tests.helpers import create_token -from api.tests.integration.factories import ( - RouterSQLFactory, - UserSQLFactory, -) +from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory from api.utils.variables import EndpointRoute @@ -75,7 +72,7 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s RouterSQLFactory( user=admin_user, name=duplicate_name, - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, ) await db_session.flush() @@ -83,7 +80,7 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s router_data = { "name": duplicate_name, - "type": ModelType.TEXT_GENERATION, + "type": RouterType.TEXT_GENERATION, "aliases": [], "load_balancing_strategy": "shuffle", "cost_prompt_tokens": 0.001, @@ -99,20 +96,20 @@ async def test_create_router_with_duplicate_name(self, client: AsyncClient, db_s # Assert assert response.status_code in [400, 409], f"Expected 400 or 409, got {response.status_code}" - assert response.json().get("detail") == f"Router '{duplicate_name}' already exists." + assert response.json().get("detail") == f"Router {duplicate_name} already exists." async def test_create_router_with_duplicate_alias(self, client: AsyncClient, db_session): # Arrange admin_user = UserSQLFactory(admin_user=True) duplicate_alias = "duplicate-alias" - RouterSQLFactory(user=admin_user, name="existing-router", type=ModelType.TEXT_GENERATION, alias=[duplicate_alias]) + RouterSQLFactory(user=admin_user, name="existing-router", type=RouterType.TEXT_GENERATION, alias=[duplicate_alias]) await db_session.flush() token = await create_token(db_session, name="admin_token", user=admin_user) router_data = { "name": "new-router", - "type": ModelType.TEXT_GENERATION.value, + "type": RouterType.TEXT_GENERATION.value, "aliases": [duplicate_alias], "load_balancing_strategy": RouterLoadBalancingStrategy.SHUFFLE.value, "cost_prompt_tokens": 0.001, diff --git a/api/tests/integration/test_models.py b/api/tests/integration/endpoints/test_models.py similarity index 100% rename from api/tests/integration/test_models.py rename to api/tests/integration/endpoints/test_models.py diff --git a/api/tests/integration/postgres/__init__.py b/api/tests/integration/postgres/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration/postgres/test_postgresproviderrepository.py b/api/tests/integration/postgres/test_postgresproviderrepository.py new file mode 100644 index 000000000..a1bb8f877 --- /dev/null +++ b/api/tests/integration/postgres/test_postgresproviderrepository.py @@ -0,0 +1,86 @@ +import pytest + +from api.domain.model.entities import Metric +from api.domain.provider import Provider, ProviderAlreadyExistsError, ProviderCarbonFootprintZone, ProviderType +from api.domain.router.entities import ModelType +from api.infrastructure.postgres import PostgresProviderRepository +from api.tests.integration.factories import ( + ProviderSQLFactory, + RouterSQLFactory, + UserSQLFactory, +) + +_EXCLUDE = {"id", "created", "updated"} + + +def _create_provider_args(user, router, **overrides): + return { + "user_id": user.id, + "router_id": router.id, + "provider_type": ProviderType.ALBERT, + "url": "http://test.com/", + "key": "model-key", + "timeout": 60, + "model_name": "my-model", + "model_hosting_zone": ProviderCarbonFootprintZone.FRA, + "model_total_params": 1000, + "model_active_params": 2000, + "qos_metric": Metric.TTFT, + "qos_limit": 12, + "vector_size": 10, + "max_context_length": 20, + **overrides, + } + + +@pytest.fixture +def repository(db_session): + return PostgresProviderRepository(db_session) + + +@pytest.mark.asyncio(loop_scope="session") +class TestCreateProvider: + async def test_create_provider_should_return_created_provider(self, repository, db_session): + # Arrange + user = UserSQLFactory(admin_user=True) + router = RouterSQLFactory(user=user, type=ModelType.TEXT_GENERATION) + await db_session.flush() + + # Act + result = await repository.create_provider(**_create_provider_args(user, router)) + + # Assert + expected = Provider( + id=result.id, + router_id=router.id, + user_id=user.id, + type=ProviderType.ALBERT, + url="http://test.com/", + key="model-key", + timeout=60, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.FRA, + model_total_params=1000, + model_active_params=2000, + qos_metric=Metric.TTFT, + qos_limit=12, + vector_size=10, + max_context_length=20, + ) + assert result.model_dump(exclude=_EXCLUDE) == expected.model_dump(exclude=_EXCLUDE) + + async def test_create_provider_should_return_provider_already_exists_when_same_url_name_and_router_are_used(self, repository, db_session): + # Arrange + user = UserSQLFactory(admin_user=True) + router = RouterSQLFactory(user=user, type=ModelType.TEXT_GENERATION) + ProviderSQLFactory(type=ProviderType.ALBERT, url="http://test.com/", model_name="duplicate-provider", router=router) + await db_session.flush() + + # Act + result = await repository.create_provider(**_create_provider_args(user, router, model_name="duplicate-provider")) + + # Assert + assert isinstance(result, ProviderAlreadyExistsError) + assert result.router_id == router.id + assert result.url == "http://test.com/" + assert result.model_name == "duplicate-provider" diff --git a/api/tests/integration/test_postgresrouterrepository.py b/api/tests/integration/postgres/test_postgresrouterrepository.py similarity index 87% rename from api/tests/integration/test_postgresrouterrepository.py rename to api/tests/integration/postgres/test_postgresrouterrepository.py index 4b77ed797..c77c7d532 100644 --- a/api/tests/integration/test_postgresrouterrepository.py +++ b/api/tests/integration/postgres/test_postgresrouterrepository.py @@ -1,14 +1,11 @@ import pytest from api.domain.key.entities import MASTER_USER_ID -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.infrastructure.postgres import PostgresRouterRepository -from api.tests.integration.factories import ( - OrganizationSQLFactory, - RouterSQLFactory, - UserSQLFactory, -) +from api.tests.integration.factories import OrganizationSQLFactory, RouterSQLFactory, UserSQLFactory @pytest.fixture @@ -29,13 +26,13 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se user_2 = UserSQLFactory() router_1 = RouterSQLFactory( - user=user_1, name="router_1", type=ModelType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 + user=user_1, name="router_1", type=RouterType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 ) router_2 = RouterSQLFactory( - user=user_1, name="router_2", type=ModelType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 + user=user_1, name="router_2", type=RouterType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 ) router_3 = RouterSQLFactory( - user=user_2, name="router_3", type=ModelType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 + user=user_2, name="router_3", type=RouterType.TEXT_EMBEDDINGS_INFERENCE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, providers=1 ) # Act @@ -48,7 +45,7 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se assert router_names == {router_1.name, router_2.name, router_3.name} result_router_1 = result_routers[0] - assert result_router_1.type == ModelType.TEXT_GENERATION + assert result_router_1.type == RouterType.TEXT_GENERATION assert result_router_1.providers == 2 assert result_router_1.cost_prompt_tokens == 0.001 assert result_router_1.cost_completion_tokens == 0.002 @@ -58,7 +55,7 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se async def test_get_all_routers_should_return_routers_with_master_id_user(self, repository, db_session): # Arrange RouterSQLFactory( - user=None, name="router_1", type=ModelType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 + user=None, name="router_1", type=RouterType.TEXT_GENERATION, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, providers=2 ) # Act @@ -92,7 +89,7 @@ async def test_get_all_aliases_should_return_all_aliases(self, repository, db_se # Act await db_session.flush() - aliases = await repository.get_aliases_by_router_id() + aliases = await repository.get_all_aliases_grouped_by_router() # Assert assert aliases == { router_1.id: ["alias1_m1", "alias2_m1"], @@ -112,7 +109,7 @@ async def test_create_router_should_return_created_router_without_alias(self, re # Act result = await repository.create_router( name="test-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.001, cost_completion_tokens=0.002, @@ -122,7 +119,7 @@ async def test_create_router_should_return_created_router_without_alias(self, re # Assert assert isinstance(result, Router) assert result.name == "test-router" - assert result.type == ModelType.TEXT_GENERATION + assert result.type == RouterType.TEXT_GENERATION assert result.load_balancing_strategy == RouterLoadBalancingStrategy.SHUFFLE assert result.cost_prompt_tokens == 0.001 assert result.cost_completion_tokens == 0.002 @@ -142,7 +139,7 @@ async def test_create_router_should_return_router_name_already_exists_when_name_ # Act result = await repository.create_router( name="duplicate-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -160,7 +157,7 @@ async def test_create_router_with_master_user_id_should_set_db_user_id_to_null(s # Act result = await repository.create_router( name="master-router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -180,7 +177,7 @@ async def test_create_router_with_aliases_should_insert_aliases(self, repository # Act result = await repository.create_router( name="router-with-aliases", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -207,7 +204,7 @@ async def test_create_router_should_return_router_alias_already_exists_when_one_ # Act result = await repository.create_router( name="router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, aliases=[duplicate_alias], load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, @@ -229,7 +226,7 @@ async def test_create_router_should_return_router_alias_already_exists_when_seve # Act result = await repository.create_router( name="router", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, aliases=duplicate_aliases, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, diff --git a/api/tests/unit/test_helpers/test_modelregistry/test_providers.py b/api/tests/unit/test_helpers/test_modelregistry/test_providers.py index 2168e3a23..9aedf79ca 100644 --- a/api/tests/unit/test_helpers/test_modelregistry/test_providers.py +++ b/api/tests/unit/test_helpers/test_modelregistry/test_providers.py @@ -760,7 +760,7 @@ async def test_update_provider_change_router_invalid_type(postgres_session: Asyn ) # Mock get_providers to return provider (will be called with router_id=None, provider_id=1) - # Note: get_providers creates Provider with type from DB (string), which Pydantic converts to enum + # Note: get_providers creates provider with type from DB (string), which Pydantic converts to enum provider_result = _Result() provider_result.mappings = lambda: _MappingsResult( [ diff --git a/api/tests/unit/use_case/factories.py b/api/tests/unit/use_case/factories.py index d317b3350..f247f071f 100644 --- a/api/tests/unit/use_case/factories.py +++ b/api/tests/unit/use_case/factories.py @@ -4,8 +4,10 @@ import factory from factory import fuzzy +from api.domain.model import ModelType as RouterType +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType from api.domain.role.entities import Limit, LimitType, PermissionType, Role -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.user.entities import User from api.domain.userinfo.entities import UserInfo @@ -42,7 +44,7 @@ class Meta: id = factory.Sequence(lambda n: n + 1) name = factory.Faker("bothify", text="router_####") user_id = factory.Faker("random_int", min=1, max=1000) - type = factory.Faker("random_element", elements=list(ModelType)) + type = factory.Faker("random_element", elements=list(RouterType)) aliases = None load_balancing_strategy = factory.Faker("random_element", elements=list(RouterLoadBalancingStrategy)) vector_size = None @@ -62,7 +64,7 @@ class Params: ) embedding = factory.Trait( - type=ModelType.TEXT_EMBEDDINGS_INFERENCE, + type=RouterType.TEXT_EMBEDDINGS_INFERENCE, vector_size=factory.Faker("random_element", elements=[384, 768, 1536, 3072]), max_context_length=factory.Faker("random_element", elements=[512, 1024, 2048, 8192]), ) @@ -70,6 +72,27 @@ class Params: with_providers = factory.Trait(providers=factory.Faker("random_int", min=1, max=5)) +class ProviderFactory(factory.Factory): + class Meta: + model = Provider + + id = factory.Sequence(lambda n: n + 1) + router_id = factory.Faker("random_int", min=1, max=1000) + user_id = factory.Faker("random_int", min=1, max=1000) + type = factory.Faker("random_element", elements=list(ProviderType)) + url = factory.Faker("url") + key = None + timeout = 30 + model_name = factory.Faker("bothify", text="model-????") + model_hosting_zone = ProviderCarbonFootprintZone.WOR + model_total_params = 0 + model_active_params = 0 + qos_metric = None + qos_limit = None + created = factory.LazyFunction(lambda: int(datetime.now(UTC).timestamp())) + updated = factory.LazyFunction(lambda: int(datetime.now(UTC).timestamp())) + + class UserFactory(factory.Factory): class Meta: model = User diff --git a/api/tests/unit/use_case/test_createproviderusecase.py b/api/tests/unit/use_case/test_createproviderusecase.py new file mode 100644 index 000000000..dc0e484e7 --- /dev/null +++ b/api/tests/unit/use_case/test_createproviderusecase.py @@ -0,0 +1,314 @@ +from unittest.mock import AsyncMock + +import pytest + +from api.domain.model import ModelType as RouterType +from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError +from api.domain.provider import ProviderCapabilities +from api.domain.provider.entities import ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError +from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo.errors import InsufficientPermissionError +from api.tests.unit.use_case.factories import ProviderFactory, RouterFactory, UserInfoFactory +from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess + + +@pytest.fixture +def use_case(): + return CreateProviderUseCase( + router_repository=AsyncMock(), + provider_repository=AsyncMock(), + provider_gateway=AsyncMock(), + user_info_repository=AsyncMock(), + ) + + +@pytest.fixture +def sample_router(): + return RouterFactory( + id=1, + name="test-router", + type=RouterType.TEXT_GENERATION, + providers=0, + ) + + +@pytest.fixture +def sample_router_with_providers(): + return RouterFactory( + id=1, + name="test-router", + type=RouterType.TEXT_GENERATION, + providers=2, + max_context_length=4096, + vector_size=None, + ) + + +@pytest.fixture +def sample_embedding_router_with_providers(): + return RouterFactory( + id=1, + name="embedding-router", + type=RouterType.TEXT_EMBEDDINGS_INFERENCE, + providers=1, + max_context_length=512, + vector_size=768, + ) + + +@pytest.fixture +def sample_provider(): + return ProviderFactory( + id=1, + router_id=1, + user_id=1, + type=ProviderType.VLLM, + url="https://example.com/", + model_name="my-model", + ) + + +@pytest.fixture +def default_command(): + return CreateProviderCommand( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + ) + + +def with_provider_type(command: CreateProviderCommand, provider_type: ProviderType) -> CreateProviderCommand: + return CreateProviderCommand( + router_id=command.router_id, + user_id=command.user_id, + provider_type=provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + model_hosting_zone=command.model_hosting_zone, + model_total_params=command.model_total_params, + model_active_params=command.model_active_params, + qos_metric=command.qos_metric, + qos_limit=command.qos_limit, + ) + + +class TestCreateProviderUseCase: + @pytest.mark.asyncio + async def test_should_create_provider_when_router_exists_without_any_provider(self, use_case, sample_router, sample_provider, default_command): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.router_repository.get_router_by_id.assert_called_once_with(router_id=1) + use_case.provider_gateway.get_capabilities.assert_called_once_with( + router_type=RouterType.TEXT_GENERATION, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + ) + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=4096, + vector_size=None, + ) + + @pytest.mark.asyncio + async def test_should_create_provider_when_router_has_a_different_provider( + self, use_case, sample_router_with_providers, sample_provider, default_command + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.VLLM, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=4096, + vector_size=None, + ) + + @pytest.mark.asyncio + async def test_should_create_embedding_provider_when_vector_size_matches( + self, use_case, sample_embedding_router_with_providers, sample_provider, default_command + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=512, vector_size=768) + use_case.provider_repository.create_provider.return_value = sample_provider + + # Act + result = await use_case.execute(with_provider_type(default_command, ProviderType.TEI)) + + # Assert + assert isinstance(result, CreateProviderUseCaseSuccess) + assert result.provider == sample_provider + use_case.provider_repository.create_provider.assert_called_once_with( + router_id=1, + user_id=1, + provider_type=ProviderType.TEI, + url="https://example.com/", + key=None, + timeout=30, + model_name="my-model", + model_hosting_zone=ProviderCarbonFootprintZone.WOR, + model_total_params=0, + model_active_params=0, + qos_metric=None, + qos_limit=None, + max_context_length=512, + vector_size=768, + ) + + @pytest.mark.asyncio + async def test_should_return_router_not_found_error_when_router_does_not_exist(self, use_case, default_command): + # Arrange + use_case.router_repository.get_router_by_id.return_value = None + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, RouterNotFoundError) + assert result.router_id == 1 + use_case.provider_gateway.get_capabilities.assert_not_called() + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_invalid_provider_type_error_when_type_not_compatible(self, use_case, default_command): + # Arrange + router = RouterFactory(id=1, name="tei-router", type=RouterType.TEXT_CLASSIFICATION) + use_case.router_repository.get_router_by_id.return_value = router + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, InvalidProviderTypeError) + assert result.provider_type == ProviderType.VLLM.value + assert result.router_type == RouterType.TEXT_CLASSIFICATION + use_case.provider_gateway.get_capabilities.assert_not_called() + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_provider_not_reachable_error_when_gateway_fails(self, use_case, sample_router, default_command): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderNotReachableError(model_name="my-model") + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, ProviderNotReachableError) + assert result.model_name == "my-model" + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_inconsistent_max_context_length_error_when_mismatch(self, use_case, sample_router_with_providers, default_command): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=2048, vector_size=None) + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, InconsistentModelMaxContextLengthError) + assert result.actual_max_context_length == 2048 + assert result.expected_max_context_length == 4096 + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_inconsistent_vector_size_error_when_mismatch( + self, use_case, sample_embedding_router_with_providers, default_command + ): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_embedding_router_with_providers + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=512, vector_size=384) + + # Act + result = await use_case.execute(with_provider_type(default_command, ProviderType.TEI)) + + # Assert + assert isinstance(result, InconsistentModelVectorSizeError) + assert result.actual_vector_size == 384 + assert result.expected_vector_size == 768 + use_case.provider_repository.create_provider.assert_not_called() + + @pytest.mark.asyncio + async def test_should_return_provider_already_exists_error(self, use_case, sample_router, default_command): + # Arrange + use_case.router_repository.get_router_by_id.return_value = sample_router + use_case.provider_gateway.get_capabilities.return_value = ProviderCapabilities(max_context_length=4096, vector_size=None) + use_case.provider_repository.create_provider.return_value = ProviderAlreadyExistsError( + model_name="my-model", url="https://example.com/", router_id=1 + ) + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, ProviderAlreadyExistsError) + assert result.model_name == "my-model" + assert result.url == "https://example.com/" + assert result.router_id == 1 + + @pytest.mark.asyncio + async def test_should_return_insufficient_permission_error_when_user_not_admin(self, use_case, default_command): + # Arrange + use_case.user_info_repository.get_user_info.return_value = UserInfoFactory(id=1, without_permission=True, limits=[]) + + # Act + result = await use_case.execute(default_command) + + # Assert + assert isinstance(result, InsufficientPermissionError) diff --git a/api/tests/unit/use_case/test_createrouterusecase.py b/api/tests/unit/use_case/test_createrouterusecase.py index bf95c3cde..9ba32d21e 100644 --- a/api/tests/unit/use_case/test_createrouterusecase.py +++ b/api/tests/unit/use_case/test_createrouterusecase.py @@ -2,25 +2,17 @@ import pytest -from api.domain.router.entities import ModelType, RouterLoadBalancingStrategy +from api.domain.model import ModelType as RouterType +from api.domain.router.entities import RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo.errors import InsufficientPermissionError from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory -from api.use_cases.admin import ( - CreateRouterUseCase, -) +from api.use_cases.admin import CreateRouterCommand, CreateRouterUseCase @pytest.fixture -def router_repository(): - repo = AsyncMock() - return repo - - -@pytest.fixture -def user_info_repository(): - repo = AsyncMock() - return repo +def use_case(): + return CreateRouterUseCase(router_repository=AsyncMock(), user_info_repository=AsyncMock()) @pytest.fixture @@ -38,7 +30,7 @@ def sample_router_with_aliases(): return RouterFactory( id=1, name="test-model", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, aliases=["alias1", "alias2"], user_id=1, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, @@ -48,41 +40,33 @@ def sample_router_with_aliases(): ) -@pytest.fixture -def use_case(router_repository, user_info_repository): - return CreateRouterUseCase( - router_repository=router_repository, - user_info_repository=user_info_repository, - ) - - class TestCreateRouterUseCase: @pytest.mark.asyncio - async def test_should_create_router_with_aliases_when_aliases_are_given( - self, router_repository, user_info_repository, admin_user_info, sample_router_with_aliases, use_case - ): + async def test_should_create_router_with_aliases_when_aliases_are_given(self, use_case, admin_user_info, sample_router_with_aliases): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = sample_router_with_aliases + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = sample_router_with_aliases # Act result = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.01, - cost_completion_tokens=0.02, + command=CreateRouterCommand( + user_id=1, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.01, + cost_completion_tokens=0.02, + ) ) # Assert assert result.router == sample_router_with_aliases - user_info_repository.get_user_info.assert_called_once_with(user_id=1) - router_repository.create_router.assert_called_once_with( + use_case.user_info_repository.get_user_info.assert_called_once_with(user_id=1) + use_case.router_repository.create_router.assert_called_once_with( name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.01, cost_completion_tokens=0.02, @@ -91,36 +75,36 @@ async def test_should_create_router_with_aliases_when_aliases_are_given( ) @pytest.mark.asyncio - async def test_should_create_router_without_aliases_if_no_alias_is_given( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_create_router_without_aliases_if_no_alias_is_given(self, use_case, admin_user_info): # Arrange router_without_aliases = RouterFactory( id=2, name="model-no-alias", - type=ModelType.TEXT_GENERATION, + type=RouterType.TEXT_GENERATION, aliases=[], user_id=1, ) - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = router_without_aliases + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = router_without_aliases # Act result = await use_case.execute( - user_id=1, - name="model-no-alias", - router_type=ModelType.TEXT_GENERATION, - aliases=[], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=1, + name="model-no-alias", + router_type=RouterType.TEXT_GENERATION, + aliases=[], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert result.router == router_without_aliases - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( name="model-no-alias", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, @@ -129,85 +113,85 @@ async def test_should_create_router_without_aliases_if_no_alias_is_given( ) @pytest.mark.asyncio - async def test_should_return_insufficient_permission_error_if_user_not_admin( - self, router_repository, user_info_repository, non_admin_user_info, use_case - ): + async def test_should_return_insufficient_permission_error_if_user_not_admin(self, use_case, non_admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = non_admin_user_info + use_case.user_info_repository.get_user_info.return_value = non_admin_user_info # Act error = await use_case.execute( - user_id=2, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=[], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=2, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=[], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, InsufficientPermissionError) - router_repository.create_router.assert_not_called() + use_case.router_repository.create_router.assert_not_called() @pytest.mark.asyncio - async def test_should_return_router_alias_already_exists_when_alias_already_exists( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_return_router_alias_already_exists_when_alias_already_exists(self, use_case, admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = RouterAliasAlreadyExistsError(aliases=["alias1"]) + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = RouterAliasAlreadyExistsError(aliases=["alias1"]) # Act error = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=1, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, RouterAliasAlreadyExistsError) - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( aliases=["alias1", "alias2"], cost_completion_tokens=0.0, cost_prompt_tokens=0.0, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, user_id=1, ) @pytest.mark.asyncio - async def test_should_return_router_name_already_exists_when_name_already_exists( - self, router_repository, user_info_repository, admin_user_info, use_case - ): + async def test_should_return_router_name_already_exists_when_name_already_exists(self, use_case, admin_user_info): # Arrange - user_info_repository.get_user_info.return_value = admin_user_info - router_repository.create_router.return_value = RouterNameAlreadyExistsError(name="existing-router") + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.router_repository.create_router.return_value = RouterNameAlreadyExistsError(name="existing-router") # Act error = await use_case.execute( - user_id=1, - name="test-model", - router_type=ModelType.TEXT_GENERATION, - aliases=["alias1", "alias2"], - load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, - cost_prompt_tokens=0.0, - cost_completion_tokens=0.0, + command=CreateRouterCommand( + user_id=admin_user_info.id, + name="test-model", + router_type=RouterType.TEXT_GENERATION, + aliases=["alias1", "alias2"], + load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, + cost_prompt_tokens=0.0, + cost_completion_tokens=0.0, + ) ) # Assert assert isinstance(error, RouterNameAlreadyExistsError) - router_repository.create_router.assert_called_once_with( + use_case.router_repository.create_router.assert_called_once_with( + user_id=admin_user_info.id, name="test-model", - router_type=ModelType.TEXT_GENERATION, + router_type=RouterType.TEXT_GENERATION, load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE, cost_prompt_tokens=0.0, cost_completion_tokens=0.0, - user_id=admin_user_info.id, aliases=["alias1", "alias2"], ) diff --git a/api/tests/unit/use_case/test_getmodelsusecase.py b/api/tests/unit/use_case/test_getmodelsusecase.py index a46458382..7dccc4e2f 100644 --- a/api/tests/unit/use_case/test_getmodelsusecase.py +++ b/api/tests/unit/use_case/test_getmodelsusecase.py @@ -3,8 +3,8 @@ import pytest +from api.domain.model import ModelType from api.domain.role.entities import LimitType -from api.domain.router.entities import ModelType from api.domain.userinfo.entities import Limit from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory from api.use_cases.models import GetModelsUseCase diff --git a/api/use_cases/admin/__init__.py b/api/use_cases/admin/__init__.py index e9e497d35..7f55f81f9 100644 --- a/api/use_cases/admin/__init__.py +++ b/api/use_cases/admin/__init__.py @@ -1,6 +1,3 @@ -from ._createrouterusecase import CreateRouterUseCase, CreateRouterUseCaseSuccess +from ._createrouterusecase import CreateRouterCommand, CreateRouterUseCase, CreateRouterUseCaseSuccess -__all__ = [ - "CreateRouterUseCase", - "CreateRouterUseCaseSuccess", -] +__all__ = ["CreateRouterCommand", "CreateRouterUseCase", "CreateRouterUseCaseSuccess"] diff --git a/api/use_cases/admin/_createrouterusecase.py b/api/use_cases/admin/_createrouterusecase.py index a0e25b694..0afabba5d 100644 --- a/api/use_cases/admin/_createrouterusecase.py +++ b/api/use_cases/admin/_createrouterusecase.py @@ -1,12 +1,24 @@ from dataclasses import dataclass +from api.domain.model import ModelType as RouterType from api.domain.router import RouterRepository -from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.domain.userinfo import UserInfoRepository from api.domain.userinfo.errors import InsufficientPermissionError +@dataclass +class CreateRouterCommand: + user_id: int + name: str + router_type: RouterType + aliases: list[str] + load_balancing_strategy: RouterLoadBalancingStrategy + cost_prompt_tokens: float + cost_completion_tokens: float + + @dataclass class CreateRouterUseCaseSuccess: router: Router @@ -24,27 +36,21 @@ def __init__(self, router_repository: RouterRepository, user_info_repository: Us async def execute( self, - user_id: int, - name: str, - router_type: ModelType, - aliases: list[str], - load_balancing_strategy: RouterLoadBalancingStrategy, - cost_prompt_tokens: float, - cost_completion_tokens: float, + command: CreateRouterCommand, ) -> CreateRouterUseCaseResult: - user_info = await self.user_info_repository.get_user_info(user_id=user_id) + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) if not user_info.is_admin: return InsufficientPermissionError() result = await self.router_repository.create_router( - name=name, - router_type=router_type, - load_balancing_strategy=load_balancing_strategy, - cost_prompt_tokens=cost_prompt_tokens, - cost_completion_tokens=cost_completion_tokens, - user_id=user_id, - aliases=aliases, + name=command.name, + router_type=command.router_type, + load_balancing_strategy=command.load_balancing_strategy, + cost_prompt_tokens=command.cost_prompt_tokens, + cost_completion_tokens=command.cost_completion_tokens, + user_id=command.user_id, + aliases=command.aliases, ) match result: diff --git a/api/use_cases/admin/providers/__init__.py b/api/use_cases/admin/providers/__init__.py new file mode 100644 index 000000000..6ba58c5e0 --- /dev/null +++ b/api/use_cases/admin/providers/__init__.py @@ -0,0 +1,3 @@ +from ._createproviderusecase import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess + +__all__ = ["CreateProviderCommand", "CreateProviderUseCase", "CreateProviderUseCaseSuccess"] diff --git a/api/use_cases/admin/providers/_createproviderusecase.py b/api/use_cases/admin/providers/_createproviderusecase.py new file mode 100644 index 000000000..e4068430a --- /dev/null +++ b/api/use_cases/admin/providers/_createproviderusecase.py @@ -0,0 +1,126 @@ +from dataclasses import dataclass + +from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError, ModelType +from api.domain.provider import InvalidProviderTypeError, ProviderGateway, ProviderNotReachableError, ProviderRepository +from api.domain.provider.entities import COMPATIBLE_PROVIDER_TYPES, Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.errors import ProviderAlreadyExistsError +from api.domain.router import RouterRepository +from api.domain.router.errors import RouterNotFoundError +from api.domain.userinfo import UserInfoRepository +from api.domain.userinfo.errors import InsufficientPermissionError +from api.schemas.core.models import Metric + + +@dataclass +class CreateProviderCommand: + router_id: int + user_id: int + provider_type: ProviderType + url: str + key: str | None + timeout: int + model_name: str + model_hosting_zone: ProviderCarbonFootprintZone + model_total_params: int + model_active_params: int + qos_metric: Metric | None + qos_limit: float | None + + +@dataclass +class CreateProviderUseCaseSuccess: + provider: Provider + + +type CreateProviderUseCaseResult = ( + CreateProviderUseCaseSuccess + | InvalidProviderTypeError + | ProviderNotReachableError + | InconsistentModelMaxContextLengthError + | InconsistentModelVectorSizeError + | RouterNotFoundError + | ProviderAlreadyExistsError + | InsufficientPermissionError +) + + +class CreateProviderUseCase: + def __init__( + self, + router_repository: RouterRepository, + provider_repository: ProviderRepository, + provider_gateway: ProviderGateway, + user_info_repository: UserInfoRepository, + ): + self.router_repository = router_repository + self.provider_repository = provider_repository + self.provider_gateway = provider_gateway + self.user_info_repository = user_info_repository + + async def execute(self, command: CreateProviderCommand) -> CreateProviderUseCaseResult: + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) + + if not user_info.is_admin: + return InsufficientPermissionError() + + router = await self.router_repository.get_router_by_id(router_id=command.router_id) + if router is None: + return RouterNotFoundError(command.router_id) + + if command.provider_type.value not in COMPATIBLE_PROVIDER_TYPES[router.type]: + return InvalidProviderTypeError(provider_type=command.provider_type.value, router_type=router.type.value) + + result = await self.provider_gateway.get_capabilities( + router_type=router.type, + provider_type=command.provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + ) + # @ TODO: separate health check logic from get_capabilities + + match result: + case ProviderNotReachableError() as error: + return error + case provider_capabilities: + pass + + max_context_length = provider_capabilities.max_context_length + if router.type == ModelType.TEXT_EMBEDDINGS_INFERENCE: + vector_size = provider_capabilities.vector_size + else: + vector_size = None + + if router.providers > 0: + if router.vector_size != vector_size: + return InconsistentModelVectorSizeError( + actual_vector_size=vector_size, expected_vector_size=router.vector_size, router_name=router.name + ) + if router.max_context_length != max_context_length: + return InconsistentModelMaxContextLengthError( + actual_max_context_length=max_context_length, expected_max_context_length=router.max_context_length, router_name=router.name + ) + + result = await self.provider_repository.create_provider( + router_id=command.router_id, + user_id=command.user_id, + provider_type=command.provider_type, + url=command.url, + key=command.key, + timeout=command.timeout, + model_name=command.model_name, + model_hosting_zone=command.model_hosting_zone, + model_total_params=command.model_total_params, + model_active_params=command.model_active_params, + qos_metric=command.qos_metric, + qos_limit=command.qos_limit, + max_context_length=max_context_length, + vector_size=vector_size, + ) + + match result: + case Provider() as provider: + return CreateProviderUseCaseSuccess(provider) + case error: + return error diff --git a/api/use_cases/models/_getmodelsusecase.py b/api/use_cases/models/_getmodelsusecase.py index 69ed93888..40ca49ef1 100644 --- a/api/use_cases/models/_getmodelsusecase.py +++ b/api/use_cases/models/_getmodelsusecase.py @@ -1,7 +1,7 @@ from dataclasses import dataclass +from api.domain.model import Model, ModelCosts from api.domain.router import RouterRepository -from api.domain.router.entities import Model, ModelCosts from api.domain.userinfo import UserInfoRepository diff --git a/api/utils/variables.py b/api/utils/variables.py index 81617f416..e58482f6e 100644 --- a/api/utils/variables.py +++ b/api/utils/variables.py @@ -11,7 +11,7 @@ class RouterName(StrEnum): - ADMIN = ("admin", "api.endpoints.admin") + ADMIN = ("admin", "api.infrastructure.fastapi.endpoints.admin") AUDIO = ("audio", "api.endpoints.audio") AUTH = ("auth", "api.endpoints.auth") CHAT = ("chat", "api.endpoints.chat") diff --git a/playground/app/features/providers/models.py b/playground/app/features/providers/models.py index 4f243e9a3..7da297636 100644 --- a/playground/app/features/providers/models.py +++ b/playground/app/features/providers/models.py @@ -2,7 +2,7 @@ class Provider(Entity): - """Provider model.""" + """provider model.""" id: int | None = None router: str | None = None diff --git a/playground/app/features/providers/state.py b/playground/app/features/providers/state.py index b93757938..1a33cef4d 100644 --- a/playground/app/features/providers/state.py +++ b/playground/app/features/providers/state.py @@ -218,7 +218,7 @@ async def delete_entity(self): response.raise_for_status() self.handle_delete_entity_dialog_change(is_open=False) - yield rx.toast.success("Provider deleted successfully", position="bottom-right") + yield rx.toast.success("provider deleted successfully", position="bottom-right") async for _ in self.load_entities(): yield @@ -289,7 +289,7 @@ async def create_entity(self): ) response.raise_for_status() - yield rx.toast.success("Provider created successfully", position="bottom-right") + yield rx.toast.success("provider created successfully", position="bottom-right") async for _ in self.load_entities(): yield @@ -356,7 +356,7 @@ async def edit_entity(self): response.raise_for_status() self.handle_settings_entity_dialog_change(is_open=False) - yield rx.toast.success("Provider updated successfully", position="bottom-right") + yield rx.toast.success("provider updated successfully", position="bottom-right") async for _ in self.load_entities(): yield