Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/badges/coverage.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"schemaVersion":1,"label":"coverage","message":"50.85%","color":"red"}
{"schemaVersion":1,"label":"coverage","message":"49.87%","color":"red"}
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,4 @@ playground/.gitignore
playground/requirements.txt
run.sh
.claude

bruno
9 changes: 6 additions & 3 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def _register_routers(app: FastAPI, configuration: Configuration) -> None:
include_in_schema = enabled_router not in hidden_routers
app.include_router(router=router, include_in_schema=include_in_schema)

# @TODO: legacy import before total clean archi migration
# @TODO: create admin folder in infrastructure.fastapi.endpoints with router declaration in __init__.py
# @TODO: legacy import, remove after total clean archi migration
if RouterName.ADMIN not in disabled_routers:
module = import_module("api.infrastructure.fastapi.endpoints.admin_router")
module = import_module("api.endpoints.admin.routers")
app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers)

if RouterName.ADMIN not in disabled_routers:
module = import_module("api.endpoints.admin.providers")
app.include_router(router=module.router, include_in_schema=RouterName.ADMIN not in hidden_routers)


Expand Down
1 change: 1 addition & 0 deletions api/clients/model/_albertmodelprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def get_max_context_length(self) -> int | None:
response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
response.raise_for_status()
except Exception as e:
# TODO: remove exc_info=True and return error instead of exception
logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
raise AssertionError(f"Model is not reachable ({e}).")

Expand Down
2 changes: 1 addition & 1 deletion api/clients/model/_basemodelprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def import_module(type: ProviderType) -> "type[BaseModelProvider]":
return getattr(module, f"{type.capitalize()}ModelProvider")

@staticmethod
async def get_max_context_length(self) -> int | None:
async def get_max_context_length() -> int | None:
"""
Get the max context length of the model provider to store in the database. Useful
to check provider consistency.
Expand Down
1 change: 1 addition & 0 deletions api/clients/model/_mistralmodelprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def get_max_context_length(self) -> int | None:
async with httpx.AsyncClient() as client:
response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
response.raise_for_status()

except Exception as e:
logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
raise AssertionError(f"Model is not reachable ({e}).")
Expand Down
29 changes: 14 additions & 15 deletions api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,17 @@
from sqlalchemy.ext.asyncio import AsyncSession

from api.domain.key import KeyRepository
from api.infrastructure.postgres import PostgresKeyRepository, PostgresRouterRepository, PostgresUserInfoRepository
from api.infrastructure.model import ModelProviderGateway
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
from api.schemas.core.context import RequestContext
from api.use_cases.admin import CreateRouterUseCase
from api.use_cases.admin.providers import CreateProviderUseCase
from api.use_cases.models import GetModelsUseCase
from api.utils.configuration import configuration
from api.utils.context import global_context, request_context


async def get_postgres_session() -> AsyncGenerator[AsyncSession]:
"""
Get a PostgreSQL postgres_session from the global context.

Returns:
AsyncSession: A PostgreSQL postgres_session instance.
"""

session_factory = global_context.postgres_session_factory
async with session_factory() as postgres_session:
try:
Expand All @@ -34,13 +29,6 @@ async def get_postgres_session() -> AsyncGenerator[AsyncSession]:


def get_request_context() -> ContextVar[RequestContext]:
"""
Get the RequestContext ContextVar from the global context.

Returns:
ContextVar[RequestContext]: The RequestContext ContextVar instance.
"""

return request_context


Expand All @@ -55,6 +43,17 @@ def get_models_use_case(
)


def create_provider_use_case_factory(
postgres_session: AsyncSession = Depends(get_postgres_session),
) -> CreateProviderUseCase:
return CreateProviderUseCase(
router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title),
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
provider_gateway=ModelProviderGateway(),
user_info_repository=PostgresUserInfoRepository(postgres_session=postgres_session),
)


def create_router_use_case(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateRouterUseCase:
return CreateRouterUseCase(
router_repository=PostgresRouterRepository(postgres_session=postgres_session, app_title=configuration.settings.app_title),
Expand Down
4 changes: 4 additions & 0 deletions api/domain/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .entities import Model, ModelCosts, ModelType
from .errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError

__all__ = ["ModelType", "Model", "ModelCosts", "InconsistentModelMaxContextLengthError", "InconsistentModelVectorSizeError"]
34 changes: 34 additions & 0 deletions api/domain/model/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum

from pydantic import BaseModel, Field


class Metric(str, Enum):
TTFT = "ttft" # time to first token
LATENCY = "latency" # requests latency
INFLIGHT = "inflight" # requests concurrency
PERFORMANCE = "performance" # custom performance metric


class ModelCosts(BaseModel):
prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)")
completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)")


class ModelType(str, Enum):
AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition"
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
IMAGE_TO_TEXT = "image-to-text"
TEXT_EMBEDDINGS_INFERENCE = "text-embeddings-inference"
TEXT_GENERATION = "text-generation"
TEXT_CLASSIFICATION = "text-classification"


class Model(BaseModel):
id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.")
type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off
aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
created: int = Field(..., description="Time of creation, as Unix timestamp.")
owned_by: str = Field(..., description="The organization that owns the model.")
max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off
costs: ModelCosts = Field(..., description="Costs of the model.")
15 changes: 15 additions & 0 deletions api/domain/model/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass


@dataclass
class InconsistentModelVectorSizeError:
expected_vector_size: int
actual_vector_size: int
router_name: str


@dataclass
class InconsistentModelMaxContextLengthError:
expected_max_context_length: int
actual_max_context_length: int
router_name: str
17 changes: 17 additions & 0 deletions api/domain/provider/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from api.domain.provider._providergateway import ProviderCapabilities, ProviderGateway
from api.domain.provider._providerrepository import ProviderRepository

from .entities import Provider, ProviderCarbonFootprintZone, ProviderType
from .errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError

__all__ = [
"InvalidProviderTypeError",
"ProviderNotReachableError",
"ProviderRepository",
"ProviderGateway",
"ProviderCapabilities",
"ProviderType",
"Provider",
"ProviderCarbonFootprintZone",
"ProviderAlreadyExistsError",
]
26 changes: 26 additions & 0 deletions api/domain/provider/_providergateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from api.domain.model import ModelType as RouterType
from api.domain.provider.entities import ProviderType
from api.domain.provider.errors import ProviderNotReachableError


@dataclass
class ProviderCapabilities:
max_context_length: int | None
vector_size: int | None


class ProviderGateway(ABC):
@abstractmethod
async def get_capabilities(
self,
router_type: RouterType,
provider_type: ProviderType,
url: str,
key: str | None,
timeout: int,
model_name: str,
) -> ProviderCapabilities | ProviderNotReachableError:
pass
27 changes: 27 additions & 0 deletions api/domain/provider/_providerrepository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod

from api.domain.model.entities import Metric
from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType
from api.domain.provider.errors import ProviderAlreadyExistsError


class ProviderRepository(ABC):
@abstractmethod
async def create_provider(
self,
router_id: int,
user_id: int,
provider_type: ProviderType,
url: str,
key: str | None,
timeout: int,
model_name: str,
model_hosting_zone: ProviderCarbonFootprintZone,
model_total_params: int,
model_active_params: int,
qos_metric: Metric | None,
qos_limit: float | None,
vector_size: int,
max_context_length: int,
) -> Provider | ProviderAlreadyExistsError:
pass
76 changes: 76 additions & 0 deletions api/domain/provider/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from enum import Enum
from typing import Literal

import pycountry
from pydantic import Field, constr

from api.domain.model.entities import ModelType
from api.schemas import BaseModel
from api.schemas.core.models import Metric

# Add world as a country code, default value of the carbon footprint computation framework
country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"]
country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))}
ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str)


class ProviderType(str, Enum):
ALBERT = "albert"
OPENAI = "openai"
MISTRAL = "mistral"
TEI = "tei"
VLLM = "vllm"


COMPATIBLE_PROVIDER_TYPES: dict[ModelType, list[str]] = {
ModelType.AUTOMATIC_SPEECH_RECOGNITION: [
ProviderType.ALBERT.value,
ProviderType.MISTRAL.value,
ProviderType.OPENAI.value,
ProviderType.VLLM.value,
],
ModelType.IMAGE_TEXT_TO_TEXT: [
ProviderType.ALBERT.value,
ProviderType.MISTRAL.value,
ProviderType.OPENAI.value,
ProviderType.VLLM.value,
],
ModelType.TEXT_EMBEDDINGS_INFERENCE: [
ProviderType.ALBERT.value,
ProviderType.OPENAI.value,
ProviderType.TEI.value,
ProviderType.VLLM.value,
],
ModelType.TEXT_GENERATION: [
ProviderType.ALBERT.value,
ProviderType.MISTRAL.value,
ProviderType.OPENAI.value,
ProviderType.VLLM.value,
],
ModelType.TEXT_CLASSIFICATION: [
ProviderType.ALBERT.value,
ProviderType.TEI.value,
],
ModelType.IMAGE_TO_TEXT: [
ProviderType.MISTRAL.value,
],
}


class Provider(BaseModel):
object: Literal["provider"] = "provider"
id: int = Field(..., description="Provider ID.") # fmt: off
router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off
user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off
type: ProviderType = Field(..., description="Provider type.") # fmt: off
url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off
key: str | None = Field(description="Provider API key.") # fmt: off
timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off
model_name: str = Field(..., description="Model name from the model provider.") # fmt: off
model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off
model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off
model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off
qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off
qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off
created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off
updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off
19 changes: 19 additions & 0 deletions api/domain/provider/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass


@dataclass
class InvalidProviderTypeError:
provider_type: str
router_type: str


@dataclass
class ProviderNotReachableError:
model_name: str


@dataclass
class ProviderAlreadyExistsError:
model_name: str
url: str
router_id: int
13 changes: 11 additions & 2 deletions api/domain/router/_routerrepository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from api.domain.router.entities import ModelType, Router, RouterLoadBalancingStrategy
from api.domain.model import ModelType as RouterType
from api.domain.router.entities import Router, RouterLoadBalancingStrategy
from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError


Expand All @@ -13,11 +14,19 @@ async def get_organization_name(self, user_id) -> str:
async def get_all_routers(self) -> list[Router]:
pass

@abstractmethod
async def get_router_by_id(self, router_id: int) -> Router | None:
pass

@abstractmethod
async def get_aliases_by_router_id(self, router_id: int) -> Router | None:
pass

@abstractmethod
async def create_router(
self,
name: str,
router_type: ModelType,
router_type: RouterType,
load_balancing_strategy: RouterLoadBalancingStrategy,
cost_prompt_tokens: float,
cost_completion_tokens: float,
Expand Down
Loading