Skip to content

Commit bae5714

Browse files
authored
feat(health): call /metrics to set health of models (#911)
* add basic auth for metrics endpoint and change health/models logic * add tests * add docs * fix tests * Update unit coverage badge --------- Co-authored-by: leoguillaume <leoguillaume@users.noreply.github.com>
1 parent b489342 commit bae5714

43 files changed

Lines changed: 1128 additions & 143 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"56.19%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"56.67%","color":"red"}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""add basic auth to provider definition
2+
3+
Revision ID: 0daf52aadaf0
4+
Revises: 7498a3a48271
5+
Create Date: 2026-06-10 18:28:42.580559
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = '0daf52aadaf0'
16+
down_revision: Union[str, None] = '7498a3a48271'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
"""Upgrade schema."""
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.add_column('provider', sa.Column('basic_auth', sa.JSON(), nullable=True))
25+
# ### end Alembic commands ###
26+
27+
28+
def downgrade() -> None:
29+
"""Downgrade schema."""
30+
# ### commands auto generated by Alembic - please adjust! ###
31+
op.drop_column('provider', 'basic_auth')
32+
# ### end Alembic commands ###

api/dependencies.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,13 @@ def _router_rate_limiter() -> RouterRateLimiter:
162162
# health use cases
163163
def get_health_models_use_case_factory(
164164
postgres_session: AsyncSession = Depends(get_postgres_session),
165+
provider_adapter_builder: ProviderAdapterBuilder = Depends(_provider_adapter_builder),
166+
provider_client: ProviderClient = Depends(_provider_client),
165167
redis_client: Redis = Depends(get_redis_client),
166168
) -> GetHealthModelsUseCase:
167169
return GetHealthModelsUseCase(
170+
provider_adapter_builder=provider_adapter_builder,
171+
provider_client=provider_client,
168172
provider_metrics_logger=_provider_metrics_logger(redis_client),
169173
router_repository=_router_repository(postgres_session),
170174
provider_repository=_provider_repository(postgres_session),

api/domain/provider/_providerrepository.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from api.domain import SortOrder
4-
from api.domain.provider.entities import HostingZone, Metric, Provider, ProviderPage, ProviderSortField, ProviderType
4+
from api.domain.provider.entities import BasicAuth, HostingZone, Metric, Provider, ProviderPage, ProviderSortField, ProviderType
55
from api.domain.provider.errors import ProviderAlreadyExistsError, ProviderNotFoundError
66

77

@@ -14,6 +14,7 @@ async def create_provider(
1414
provider_type: ProviderType,
1515
url: str,
1616
key: str | None,
17+
basic_auth: BasicAuth | None,
1718
timeout: int,
1819
model_name: str,
1920
model_hosting_zone: HostingZone,

api/domain/provider/entities.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from enum import StrEnum
22
from http import HTTPMethod
3-
from typing import Annotated
3+
from typing import Annotated, Literal
44

55
import pycountry
6-
from pydantic import BaseModel, Field
6+
from pydantic import Field
77

8-
from api.domain import EntitiesPage
8+
from api.domain import BaseModel, EntitiesPage
99
from api.domain.embeddings.entities import CreateEmbeddingsBody, Embeddings
1010
from api.domain.model.entities import Models, ModelType
1111
from api.domain.rerank.entities import CreateRerankBody, Rerank
@@ -33,6 +33,11 @@ class QoSMetric(StrEnum):
3333
PERFORMANCE = "performance" # custom performance metric
3434

3535

36+
class BasicAuth(BaseModel):
37+
username: str
38+
password: str
39+
40+
3641
class ProviderType(StrEnum):
3742
ALBERT = "albert"
3843
OPENAI = "openai"
@@ -96,6 +101,7 @@ class Provider(BaseModel):
96101
type: ProviderType
97102
url: str
98103
key: str | None = None
104+
basic_auth: BasicAuth | None = None
99105
timeout: int
100106
model_name: str
101107
model_hosting_zone: HostingZone = HostingZone.WOR
@@ -146,24 +152,31 @@ class ProviderOriginalRequest(BaseModel):
146152
files: Annotated[dict | None, Field(default=None, description="The files to use for the request.")]
147153

148154

155+
class ResponseMetrics(BaseModel):
156+
latency: Annotated[int, Field(default=0, description="The latency of the response.")]
157+
ttft: Annotated[int | None, Field(default=None, description="The time to first byte of the response.")]
158+
159+
149160
class ProviderFormattedRequest(BaseModel):
150161
method: Annotated[HTTPMethod, Field(description="The HTTP method to build the request.")]
151162
url: Annotated[str, Field(description="The model API URL to build the request.")]
163+
auth: Annotated[BasicAuth | None, Field(default=None, description="The authentication to use for the request.")]
152164
body: Annotated[dict, Field(default={}, description="The JSON body to use for the request.")]
153165
form: Annotated[dict, Field(default={}, description="The form-encoded data to use for the request.")]
154166
files: Annotated[dict, Field(default={}, description="The files to use for the request.")]
155167

156168

157-
class ResponseMetrics(BaseModel):
158-
latency: Annotated[int, Field(default=0, description="The latency of the response.")]
159-
ttft: Annotated[int | None, Field(default=None, description="The time to first byte of the response.")]
160-
161-
162169
class ProviderOriginalResponse(BaseModel):
163-
data: Annotated[dict | list, Field(default={}, description="The JSON data to use for the response.")]
170+
data: Annotated[dict | list | None, Field(default=None, description="The JSON data to use for the response.")]
164171
text: Annotated[str | None, Field(default=None, description="The text data to use for the response.")]
165172

166173

174+
class ProviderMetrics(BaseModel):
175+
object: Literal["providerMetrics"] = "providerMetrics"
176+
waiting_requests: float
177+
running_requests: float
178+
179+
167180
class ProviderFormattedResponse(BaseModel):
168-
data: Annotated[AudioTranscription | ChatCompletion | ChatCompletionChunk | Embeddings | Models | OCR | Rerank | None, Field(default=None, description="The JSON data to use for the response.")] # fmt: off
181+
data: Annotated[AudioTranscription | ChatCompletion | ChatCompletionChunk | Embeddings | Models | OCR | ProviderMetrics | Rerank | None, Field(default=None, description="The JSON data to use for the response.")] # fmt: off
169182
text: Annotated[str | None, Field(default=None, description="The text data to use for the response.")]

api/infrastructure/fastapi/endpoints/admin/providers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def create_provider(
9090
provider_type=body.type,
9191
url=body.url,
9292
key=body.key,
93+
basic_auth=body.basic_auth,
9394
timeout=body.timeout,
9495
model_name=body.model_name,
9596
model_hosting_zone=body.model_hosting_zone,

api/infrastructure/fastapi/endpoints/health.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from api.use_cases.health import GetHealthModelsCommand, GetHealthModelsUseCase, GetHealthModelsUseCaseSuccess
1212
from api.utils.variables import EndpointRoute, RouterName
1313

14-
router = APIRouter(tags=[RouterName.MONITORING.title()])
14+
router = APIRouter(tags=[RouterName.HEALTH.title()])
1515

1616

1717
@router.get(path=EndpointRoute.HEALTH, status_code=200)

api/infrastructure/fastapi/schemas/providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pydantic import Field, StringConstraints, model_validator
44

5-
from api.domain.provider.entities import HostingZone, ProviderType, QoSMetric
5+
from api.domain.provider.entities import BasicAuth, HostingZone, ProviderType, QoSMetric
66
from api.infrastructure.fastapi.schemas import BaseModel
77
from api.schemas.core.configuration import ModelProvider
88

@@ -18,6 +18,7 @@ class CreateProviderResponse(BaseModel):
1818
type: Annotated[ProviderType, Field(..., description="Provider type.")]
1919
url: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, to_lower=True), Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).")] # fmt: off
2020
key: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1), Field(default=None, description="Provider API key.")] # fmt: off
21+
basic_auth: Annotated[BasicAuth | None, Field(default=None, description="Provider basic authentication.")]
2122
timeout: Annotated[int, Field(..., ge=1, le=3600, description="Timeout for the provider requests, after user receive an 500 error (model is too busy).")] # fmt: off
2223
model_name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(..., description="Model name from the model provider.")] # fmt: off
2324
model_hosting_zone: Annotated[HostingZone, Field(default=HostingZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"])] # fmt: off
@@ -53,7 +54,6 @@ class ProviderResponse(BaseModel):
5354
user_id: Annotated[int, Field(description="ID of the user that owns the provider.")] # fmt: off
5455
provider_type: Annotated[ProviderType, Field(alias="type", description="Provider type.")] # fmt: off
5556
url: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, to_lower=True), Field(default=None, description="provider API url. The url must only contain the domain name (without `/v1` suffix for example).")] # fmt: off
56-
key: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1), Field(default=None, description="provider API key.")]
5757
timeout: Annotated[int, Field(description="Timeout for the provider requests, after user receive an 500 error (model is too busy).")]
5858
model_name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(description="Model name from the model provider.")]
5959
model_hosting_zone: Annotated[HostingZone, Field(default=HostingZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"])] # fmt: off

api/infrastructure/http/_httpprovideradapterbuilder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from api.infrastructure.http.adapters.embeddings.openai import OpenaiEmbeddingsAdapter
1616
from api.infrastructure.http.adapters.embeddings.tei import TeiEmbeddingsAdapter
1717
from api.infrastructure.http.adapters.embeddings.vllm import VllmEmbeddingsAdapter
18+
from api.infrastructure.http.adapters.metrics.mistral import MistralMetricsAdapter
19+
from api.infrastructure.http.adapters.metrics.vllm import VllmMetricsAdapter
1820
from api.infrastructure.http.adapters.models.albert import AlbertModelsAdapter
1921
from api.infrastructure.http.adapters.models.mistral import MistralModelsAdapter
2022
from api.infrastructure.http.adapters.models.openai import OpenaiModelsAdapter
@@ -56,6 +58,10 @@ class HttpProviderAdapterBuilder(ProviderAdapterBuilder):
5658
ProviderType.TEI: TeiModelsAdapter,
5759
ProviderType.VLLM: VllmModelsAdapter,
5860
},
61+
EndpointRoute.METRICS: {
62+
ProviderType.MISTRAL: MistralMetricsAdapter,
63+
ProviderType.VLLM: VllmMetricsAdapter,
64+
},
5965
EndpointRoute.OCR: {
6066
ProviderType.ALBERT: AlbertOcrAdapter,
6167
ProviderType.MISTRAL: MistralOcrAdapter,

api/infrastructure/http/_httpproviderclient.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44

55
import httpx
6+
from httpx import BasicAuth
67

78
from api.domain.model.errors import StatusCodeModelError, TooBusyModelError, UnknownModelError
89
from api.domain.provider import ProviderClient, ProviderClientResponse
@@ -13,10 +14,14 @@
1314

1415
class HttpProviderClient(ProviderClient):
1516
async def forward_request(self, provider: Provider, formatted_request: ProviderFormattedRequest) -> ProviderClientResponse:
17+
# TEMPORARY PATCH FOR MISTRAL METRICS ENDPOINT
18+
auth = BasicAuth(username=formatted_request.auth.username, password=formatted_request.auth.password) if formatted_request.auth else None
19+
1620
async with httpx.AsyncClient(timeout=provider.timeout) as async_client:
1721
try:
1822
response = await async_client.request(
1923
headers={"Authorization": f"Bearer {provider.key}"} if provider.key else {},
24+
auth=auth,
2025
method=formatted_request.method,
2126
url=formatted_request.url,
2227
json=formatted_request.body,

0 commit comments

Comments
 (0)