|
1 | | -from fastapi import APIRouter, Depends, Path, Request, Security |
| 1 | +from fastapi import APIRouter, Depends, Path, Security |
2 | 2 | from fastapi.responses import JSONResponse |
3 | 3 |
|
4 | | -from api.dependencies import get_models_use_case_factory |
| 4 | +from api.dependencies import get_model_use_case_factory, get_models_use_case_factory, get_request_context |
5 | 5 | from api.domain.model.errors import ModelNotFoundError |
6 | 6 | from api.domain.user.errors import UserExpiredError |
7 | 7 | from api.infrastructure.fastapi.access import get_current_key |
| 8 | +from api.infrastructure.fastapi.context import RequestContext |
8 | 9 | from api.infrastructure.fastapi.documentation import get_documentation_responses |
9 | | -from api.infrastructure.fastapi.endpoints.exceptions import ( |
10 | | - AccountExpiredHTTPException, |
11 | | - ModelNotFoundHTTPException, |
12 | | -) |
13 | | -from api.infrastructure.fastapi.schemas.models import ModelResponse, ModelsResponse |
14 | | -from api.use_cases.models import GetModelsUseCase, GetModelUseCaseSucess |
| 10 | +from api.infrastructure.fastapi.endpoints.exceptions import AccountExpiredHTTPException, ModelNotFoundHTTPException |
| 11 | +from api.infrastructure.fastapi.schemas.models import Model, ModelsResponse |
| 12 | +from api.use_cases.models import GetModelCommand, GetModelsCommand, GetModelsUseCase, GetModelsUseCaseSucess, GetModelUseCase, GetModelUseCaseSucess |
15 | 13 | from api.utils.variables import EndpointRoute, RouterName |
16 | 14 |
|
17 | 15 | router = APIRouter(prefix="/v1", tags=[RouterName.MODELS.title()]) |
18 | 16 |
|
19 | 17 |
|
20 | 18 | @router.get( |
21 | | - path=EndpointRoute.MODELS + "/{model:path}", |
| 19 | + path=EndpointRoute.MODELS, |
22 | 20 | dependencies=[Security(dependency=get_current_key)], |
23 | 21 | status_code=200, |
24 | | - response_model=ModelResponse, |
25 | | - responses=get_documentation_responses([ModelNotFoundHTTPException]), |
| 22 | + response_model=ModelsResponse, |
| 23 | + responses=get_documentation_responses([]), |
26 | 24 | ) |
27 | | -async def get_model( |
28 | | - request: Request, |
29 | | - model: str = Path(description="The name of the model to get."), |
| 25 | +async def get_models( |
30 | 26 | get_models_use_case: GetModelsUseCase = Depends(get_models_use_case_factory), |
| 27 | + request_context: RequestContext = Depends(get_request_context), |
31 | 28 | ) -> ModelNotFoundHTTPException | JSONResponse: |
32 | 29 | """ |
33 | | - Get a model by name and provide basic information. |
| 30 | + Lists the currently available models and provides basic information. |
34 | 31 | """ |
35 | | - result = await get_models_use_case.execute(name=model) |
| 32 | + command = GetModelsCommand(user_id=request_context.get().user_id) |
| 33 | + result = await get_models_use_case.execute(command=command) |
36 | 34 |
|
37 | 35 | match result: |
38 | | - case GetModelUseCaseSucess(models): |
39 | | - models = [ModelResponse(**model.model_dump()) for model in models] |
40 | | - model = models[0] |
41 | | - return JSONResponse(content=model.model_dump(), status_code=200) |
42 | | - case ModelNotFoundError(): |
43 | | - raise ModelNotFoundHTTPException() |
| 36 | + case GetModelsUseCaseSucess(models): |
| 37 | + return JSONResponse(content=ModelsResponse.model_validate({"data": models}, from_attributes=True).model_dump(), status_code=200) |
44 | 38 | case UserExpiredError(): |
45 | 39 | raise AccountExpiredHTTPException() |
46 | 40 |
|
47 | 41 |
|
48 | 42 | @router.get( |
49 | | - path=EndpointRoute.MODELS, |
| 43 | + path=EndpointRoute.MODELS + "/{model:path}", |
50 | 44 | dependencies=[Security(dependency=get_current_key)], |
51 | 45 | status_code=200, |
52 | | - response_model=ModelsResponse, |
53 | | - responses=get_documentation_responses([]), |
| 46 | + response_model=Model, |
| 47 | + responses=get_documentation_responses([ModelNotFoundHTTPException]), |
54 | 48 | ) |
55 | | -async def get_models( |
56 | | - request: Request, |
57 | | - get_models_use_case: GetModelsUseCase = Depends(get_models_use_case_factory), |
| 49 | +async def get_model( |
| 50 | + model: str = Path(description="The name of the model to get."), |
| 51 | + get_model_use_case: GetModelUseCase = Depends(get_model_use_case_factory), |
| 52 | + request_context: RequestContext = Depends(get_request_context), |
58 | 53 | ) -> JSONResponse: |
59 | 54 | """ |
60 | | - Lists the currently available models and provides basic information. |
| 55 | + Get a model by name and provide basic information. |
61 | 56 | """ |
62 | | - result = await get_models_use_case.execute(name=None) |
| 57 | + command = GetModelCommand(user_id=request_context.get().user_id, name=model) |
| 58 | + result = await get_model_use_case.execute(command=command) |
| 59 | + |
63 | 60 | match result: |
64 | | - case GetModelUseCaseSucess(models): |
65 | | - models = [ModelResponse(**model.model_dump()) for model in models] |
66 | | - return JSONResponse(content=ModelsResponse(data=models).model_dump(), status_code=200) |
| 61 | + case GetModelUseCaseSucess(model): |
| 62 | + return JSONResponse(content=Model.model_validate(model, from_attributes=True).model_dump(), status_code=200) |
| 63 | + case ModelNotFoundError(name): |
| 64 | + raise ModelNotFoundHTTPException(name=name) |
67 | 65 | case UserExpiredError(): |
68 | 66 | raise AccountExpiredHTTPException() |
0 commit comments