Skip to content

Commit b2578a6

Browse files
authored
refacto(models): split getmodelsusecase in two use cases (#890)
* split get model use cases * fix bootstrap syntax * add get_router_by_name_or_alias method * Update unit coverage badge * last fix --------- Co-authored-by: leoguillaume <leoguillaume@users.noreply.github.com>
1 parent 9981a1c commit b2578a6

37 files changed

Lines changed: 776 additions & 572 deletions

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"56.00%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"56.24%","color":"red"}

api/dependencies.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
3737
from api.use_cases.admin.users import CreateUserUseCase
3838
from api.use_cases.health import GetHealthModelsUseCase
39-
from api.use_cases.models import GetModelsUseCase
39+
from api.use_cases.models import GetModelsUseCase, GetModelUseCase
4040
from api.utils.configuration import configuration
4141
from api.utils.context import global_context, request_context
4242

@@ -139,13 +139,16 @@ def get_health_models_use_case_factory(
139139

140140

141141
# models use cases
142-
def get_models_use_case_factory(
143-
postgres_session: AsyncSession = Depends(get_postgres_session),
144-
request_context: RequestContext = Depends(get_request_context),
145-
) -> GetModelsUseCase:
142+
def get_models_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetModelsUseCase:
146143
return GetModelsUseCase(
147144
router_repository=_router_repository(postgres_session),
148-
user_id=request_context.get().user_id,
145+
user_with_role_query=_user_with_role_query(postgres_session),
146+
)
147+
148+
149+
def get_model_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetModelUseCase:
150+
return GetModelUseCase(
151+
router_repository=_router_repository(postgres_session),
149152
user_with_role_query=_user_with_role_query(postgres_session),
150153
)
151154

api/domain/model/errors.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from dataclasses import dataclass
22

3-
from api.domain.provider.entities import ProviderType
4-
from api.utils.variables import EndpointRoute
5-
63

74
@dataclass
85
class InconsistentModelVectorSizeError:
@@ -13,18 +10,11 @@ class InconsistentModelVectorSizeError:
1310

1411
@dataclass
1512
class ModelNotFoundError:
16-
pass
13+
name: str
1714

1815

1916
@dataclass
2017
class InconsistentModelMaxContextLengthError:
2118
expected_max_context_length: int
2219
actual_max_context_length: int
2320
router_name: str
24-
25-
26-
# @TODO: move to provider.errors.py
27-
@dataclass
28-
class UnsupportedEndpointError:
29-
endpoint: EndpointRoute
30-
provider_type: ProviderType | None = None

api/domain/provider/_providergateway.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33

44
from api.domain.model.entities import ModelType as RouterType
5+
from api.domain.model.errors import ModelNotFoundError
56
from api.domain.provider.entities import ProviderType
67
from api.domain.provider.errors import ProviderNotReachableError
78

@@ -22,5 +23,5 @@ async def get_capabilities(
2223
key: str | None,
2324
timeout: int,
2425
model_name: str,
25-
) -> ProviderCapabilities | ProviderNotReachableError:
26+
) -> ProviderCapabilities | ModelNotFoundError | ProviderNotReachableError:
2627
pass

api/domain/provider/errors.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from dataclasses import dataclass
22

3+
from api.domain.provider.entities import ProviderType
4+
from api.utils.variables import EndpointRoute
5+
6+
7+
@dataclass
8+
class UnsupportedEndpointError:
9+
endpoint: EndpointRoute
10+
provider_type: ProviderType | None = None
11+
312

413
@dataclass
514
class InvalidProviderTypeError:
@@ -14,11 +23,6 @@ class ProviderNotReachableError:
1423
detail: str
1524

1625

17-
@dataclass
18-
class ModelProviderNotFoundError:
19-
model_name: str
20-
21-
2226
@dataclass
2327
class ProviderAlreadyExistsError:
2428
model_name: str

api/domain/router/_routerrepository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def get_router_by_id(self, router_id: int) -> Router | RouterNotFoundError
3030
pass
3131

3232
@abstractmethod
33-
async def get_aliases_by_router_id(self, router_id: int) -> list[str]:
33+
async def get_router_by_name_or_alias(self, name_or_alias: str) -> Router | RouterNotFoundError:
3434
pass
3535

3636
@abstractmethod

api/domain/router/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ class RouterNameAlreadyExistsError:
1313

1414
@dataclass
1515
class RouterNotFoundError:
16-
id: int
16+
id: int | None = None
17+
name: str | None = None

api/infrastructure/fastapi/endpoints/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ def __init__(self, input_vector_size: int, model_vector_size: int, model_name: s
6969
# 404
7070
class ModelNotFoundHTTPException(HTTPException):
7171
status_code = 404
72-
detail = "Model not found."
72+
detail = "Model {name} not found."
7373

74-
def __init__(self) -> None:
75-
super().__init__(status_code=self.status_code, detail=self.detail)
74+
def __init__(self, name: str) -> None:
75+
super().__init__(status_code=self.status_code, detail=f"Model {name} not found.")
7676

7777

7878
class RoleNotFoundHTTPException(HTTPException):

api/infrastructure/fastapi/endpoints/health.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import APIRouter, Depends, Request, Security
1+
from fastapi import APIRouter, Depends, Security
22
from fastapi.responses import JSONResponse
33

44
from api.dependencies import get_health_models_use_case_factory, get_request_context
@@ -16,6 +16,9 @@
1616

1717
@router.get(path=EndpointRoute.HEALTH, status_code=200)
1818
def get_health() -> JSONResponse:
19+
"""
20+
Get the health of the API.
21+
"""
1922
return JSONResponse(content={"status": "ok"}, status_code=200)
2023

2124

@@ -26,20 +29,20 @@ def get_health() -> JSONResponse:
2629
responses=get_documentation_responses([]),
2730
)
2831
async def get_health_models(
29-
request: Request,
3032
get_health_models_use_case: GetHealthModelsUseCase = Depends(get_health_models_use_case_factory),
3133
request_context: RequestContext = Depends(get_request_context),
3234
) -> JSONResponse:
3335
"""
34-
Get a model by name and provide basic information.
36+
Get the health of the models.
3537
"""
3638

3739
result = await get_health_models_use_case.execute(command=GetHealthModelsCommand(user_id=request_context.get().user_id))
3840

3941
match result:
40-
case GetHealthModelsUseCaseSuccess():
41-
return ModelsHealthResponse(
42-
data=[ModelHealthStatus.model_validate(model, from_attributes=True) for model in result.models],
42+
case GetHealthModelsUseCaseSuccess(models):
43+
return JSONResponse(
44+
content=ModelsHealthResponse(data=[ModelHealthStatus.model_validate(model, from_attributes=True) for model in models]).model_dump(),
45+
status_code=200,
4346
)
4447
case UserExpiredError():
4548
raise AccountExpiredHTTPException()
Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,66 @@
1-
from fastapi import APIRouter, Depends, Path, Request, Security
1+
from fastapi import APIRouter, Depends, Path, Security
22
from fastapi.responses import JSONResponse
33

4-
from api.dependencies import get_models_use_case_factory
4+
from api.dependencies import get_model_use_case_factory, get_models_use_case_factory, get_request_context
55
from api.domain.model.errors import ModelNotFoundError
66
from api.domain.user.errors import UserExpiredError
77
from api.infrastructure.fastapi.access import get_current_key
8+
from api.infrastructure.fastapi.context import RequestContext
89
from api.infrastructure.fastapi.documentation import get_documentation_responses
9-
from api.infrastructure.fastapi.endpoints.exceptions import (
10-
AccountExpiredHTTPException,
11-
ModelNotFoundHTTPException,
12-
)
13-
from api.infrastructure.fastapi.schemas.models import ModelResponse, ModelsResponse
14-
from api.use_cases.models import GetModelsUseCase, GetModelUseCaseSucess
10+
from api.infrastructure.fastapi.endpoints.exceptions import AccountExpiredHTTPException, ModelNotFoundHTTPException
11+
from api.infrastructure.fastapi.schemas.models import Model, ModelsResponse
12+
from api.use_cases.models import GetModelCommand, GetModelsCommand, GetModelsUseCase, GetModelsUseCaseSucess, GetModelUseCase, GetModelUseCaseSucess
1513
from api.utils.variables import EndpointRoute, RouterName
1614

1715
router = APIRouter(prefix="/v1", tags=[RouterName.MODELS.title()])
1816

1917

2018
@router.get(
21-
path=EndpointRoute.MODELS + "/{model:path}",
19+
path=EndpointRoute.MODELS,
2220
dependencies=[Security(dependency=get_current_key)],
2321
status_code=200,
24-
response_model=ModelResponse,
25-
responses=get_documentation_responses([ModelNotFoundHTTPException]),
22+
response_model=ModelsResponse,
23+
responses=get_documentation_responses([]),
2624
)
27-
async def get_model(
28-
request: Request,
29-
model: str = Path(description="The name of the model to get."),
25+
async def get_models(
3026
get_models_use_case: GetModelsUseCase = Depends(get_models_use_case_factory),
27+
request_context: RequestContext = Depends(get_request_context),
3128
) -> ModelNotFoundHTTPException | JSONResponse:
3229
"""
33-
Get a model by name and provide basic information.
30+
Lists the currently available models and provides basic information.
3431
"""
35-
result = await get_models_use_case.execute(name=model)
32+
command = GetModelsCommand(user_id=request_context.get().user_id)
33+
result = await get_models_use_case.execute(command=command)
3634

3735
match result:
38-
case GetModelUseCaseSucess(models):
39-
models = [ModelResponse(**model.model_dump()) for model in models]
40-
model = models[0]
41-
return JSONResponse(content=model.model_dump(), status_code=200)
42-
case ModelNotFoundError():
43-
raise ModelNotFoundHTTPException()
36+
case GetModelsUseCaseSucess(models):
37+
return JSONResponse(content=ModelsResponse.model_validate({"data": models}, from_attributes=True).model_dump(), status_code=200)
4438
case UserExpiredError():
4539
raise AccountExpiredHTTPException()
4640

4741

4842
@router.get(
49-
path=EndpointRoute.MODELS,
43+
path=EndpointRoute.MODELS + "/{model:path}",
5044
dependencies=[Security(dependency=get_current_key)],
5145
status_code=200,
52-
response_model=ModelsResponse,
53-
responses=get_documentation_responses([]),
46+
response_model=Model,
47+
responses=get_documentation_responses([ModelNotFoundHTTPException]),
5448
)
55-
async def get_models(
56-
request: Request,
57-
get_models_use_case: GetModelsUseCase = Depends(get_models_use_case_factory),
49+
async def get_model(
50+
model: str = Path(description="The name of the model to get."),
51+
get_model_use_case: GetModelUseCase = Depends(get_model_use_case_factory),
52+
request_context: RequestContext = Depends(get_request_context),
5853
) -> JSONResponse:
5954
"""
60-
Lists the currently available models and provides basic information.
55+
Get a model by name and provide basic information.
6156
"""
62-
result = await get_models_use_case.execute(name=None)
57+
command = GetModelCommand(user_id=request_context.get().user_id, name=model)
58+
result = await get_model_use_case.execute(command=command)
59+
6360
match result:
64-
case GetModelUseCaseSucess(models):
65-
models = [ModelResponse(**model.model_dump()) for model in models]
66-
return JSONResponse(content=ModelsResponse(data=models).model_dump(), status_code=200)
61+
case GetModelUseCaseSucess(model):
62+
return JSONResponse(content=Model.model_validate(model, from_attributes=True).model_dump(), status_code=200)
63+
case ModelNotFoundError(name):
64+
raise ModelNotFoundHTTPException(name=name)
6765
case UserExpiredError():
6866
raise AccountExpiredHTTPException()

0 commit comments

Comments
 (0)