Skip to content

Commit 23d352c

Browse files
leoguillaumeBenjamin PILIA
authored andcommitted
feat(elasticsearch): simplify search method (#763)
* feat(elasticsearch): simplify search method * Update unit coverage badge --------- Co-authored-by: leoguillaume <leoguillaume@users.noreply.github.com>
1 parent d3dd1ce commit 23d352c

20 files changed

Lines changed: 363 additions & 49 deletions

File tree

.github/README_INTEGRATION_TESTS.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ To execute a specific test, you can use the following command:
7272
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/<path_to_test_file>::<TestClass>::<test_name> --config-file=pyproject.toml
7373
7474
# Example
75-
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/test_admin/test_admin_providers.py::TestAdminProviders::test_create_provider_with_text_generation_model --config-file=pyproject.toml
75+
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/test_admin/test_create_provider.py::TestAdminProviders::test_create_provider_with_text_generation_model --config-file=pyproject.toml
7676
```
7777

7878
To run a group of tests, you can use the following command:
@@ -81,7 +81,7 @@ To run a group of tests, you can use the following command:
8181
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/<path_to_test_file> --config-file=pyproject.toml
8282
8383
# Example
84-
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/test_admin/test_admin_providers.py --config-file=pyproject.toml
84+
CONFIG_FILE=api/tests/integ/config.test.yml PYTHONPATH=. pytest api/tests/integ/test_admin/test_create_provider.py --config-file=pyproject.toml
8585
```
8686

8787
## Run with VSCode

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"50.17%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"50.16%","color":"red"}

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,4 +222,5 @@ playground/.gitignore
222222
playground/requirements.txt
223223
run.sh
224224
.claude
225-
bruno
225+
bruno
226+
docs/.astro

api/dependencies.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from api.infrastructure.model import ModelProviderGateway
99
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
1010
from api.schemas.core.context import RequestContext
11-
from api.use_cases.admin.providers import CreateProviderUseCase
11+
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase
1212
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase
1313
from api.use_cases.models import GetModelsUseCase
1414
from api.utils.configuration import configuration
@@ -94,6 +94,13 @@ def delete_router_use_case_factory(postgres_session: AsyncSession = Depends(get_
9494
)
9595

9696

97+
def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteProviderUseCase:
98+
return DeleteProviderUseCase(
99+
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
100+
user_info_repository=_user_info_repository(postgres_session),
101+
)
102+
103+
97104
def get_key_repository(postgres_session: AsyncSession = Depends(get_postgres_session)) -> KeyRepository:
98105
return PostgresKeyRepository(postgres_session=postgres_session)
99106

api/domain/provider/_providerrepository.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,10 @@ async def create_provider(
2525
max_context_length: int,
2626
) -> Provider | ProviderAlreadyExistsError:
2727
pass
28+
29+
@abstractmethod
30+
async def delete_provider(
31+
self,
32+
provider_id: int,
33+
) -> Provider | None:
34+
pass

api/domain/provider/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,8 @@ class ProviderAlreadyExistsError:
1717
model_name: str
1818
url: str
1919
router_id: int
20+
21+
22+
@dataclass
23+
class ProviderNotFoundError:
24+
provider_id: int

api/endpoints/admin/providers.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,6 @@
1313
from api.utils.variables import EndpointRoute
1414

1515

16-
@router.delete(
17-
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}",
18-
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))],
19-
status_code=204,
20-
)
21-
async def delete_provider(
22-
request: Request,
23-
provider: int = Path(description="The ID of the provider to delete."),
24-
postgres_session: AsyncSession = Depends(get_postgres_session),
25-
model_registry: ModelRegistry = Depends(get_model_registry),
26-
) -> Response:
27-
"""
28-
Delete a router provider.
29-
"""
30-
await model_registry.delete_provider(provider_id=provider, postgres_session=postgres_session)
31-
32-
return Response(status_code=204)
33-
34-
3516
@router.patch(
3617
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}",
3718
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))],

api/helpers/_elasticsearchvectorstore.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ async def upsert(self, client: AsyncElasticsearch, chunks: list[ElasticsearchChu
155155

156156
@staticmethod
157157
def _escape_query_string_value(value: str) -> str:
158-
# Escape reserved Lucene query_string characters so user-provided text
159-
# is treated as a literal token sequence in metadata filters.
158+
# Escape reserved Lucene query_string characters so user-provided text is treated as a literal token sequence in metadata filters.
160159
return re.sub(r'([+\-=&|><!(){}\[\]^"~*?:\\/ ])', r"\\\1", value)
161160

162161
@staticmethod
@@ -259,17 +258,10 @@ async def _lexical_search(
259258
"size": limit,
260259
"from": offset,
261260
"_source": {"excludes": ["embedding"]},
261+
"sort": [{"_score": {"order": "desc"}}],
262262
}
263263
results = await client.search(index=self.index_name, body=body)
264-
searches = [
265-
Search(
266-
method=SearchMethod.LEXICAL.value,
267-
score=hit["_score"],
268-
chunk=Chunk(**hit["_source"]),
269-
)
270-
for hit in results["hits"]["hits"]
271-
]
272-
searches = sorted(searches, key=lambda x: x.score, reverse=True)[:limit]
264+
searches = [Search(method=SearchMethod.LEXICAL.value, score=hit["_score"], chunk=Chunk(**hit["_source"])) for hit in results["hits"]["hits"]]
273265

274266
return searches
275267

api/infrastructure/fastapi/endpoints/admin/providers.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from fastapi.responses import JSONResponse, Response
77
from sqlalchemy.ext.asyncio import AsyncSession
88

9-
from api.dependencies import create_provider_use_case_factory, get_request_context
9+
from api.dependencies import create_provider_use_case_factory, delete_provider_use_case_factory, get_request_context
1010
from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError
1111
from api.domain.provider import InvalidProviderTypeError, ProviderNotReachableError
12-
from api.domain.provider.errors import ProviderAlreadyExistsError
12+
from api.domain.provider.errors import ProviderAlreadyExistsError, ProviderNotFoundError
1313
from api.domain.router.errors import RouterNotFoundError
1414
from api.domain.userinfo.errors import UserIsNotAdminError
1515
from api.helpers.models import ModelRegistry
@@ -24,11 +24,19 @@
2424
InvalidProviderTypeHTTPException,
2525
NotAdminUserHTTPException,
2626
ProviderAlreadyExistsHTTPException,
27+
ProviderNotFoundHTTPException,
2728
ProviderNotReachableHTTPException,
2829
RouterNotFoundHTTPException,
2930
)
3031
from api.infrastructure.fastapi.schemas.providers import CreateProvider, CreateProviderResponse, Provider, Providers, UpdateProvider
31-
from api.use_cases.admin.providers import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess
32+
from api.use_cases.admin.providers import (
33+
CreateProviderCommand,
34+
CreateProviderUseCase,
35+
CreateProviderUseCaseSuccess,
36+
DeleteProviderCommand,
37+
DeleteProviderUseCase,
38+
DeleteProviderUseCaseSuccess,
39+
)
3240
from api.utils.dependencies import get_model_registry, get_postgres_session
3341
from api.utils.variables import EndpointRoute
3442

@@ -107,19 +115,39 @@ async def create_provider(
107115

108116

109117
@router.delete(
110-
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}",
118+
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider_id}",
111119
dependencies=[Security(dependency=get_current_key)],
112-
status_code=204,
120+
status_code=200,
113121
)
114122
async def delete_provider(
115-
request: Request,
116-
provider: int = Path(description="The ID of the provider to delete."),
117-
postgres_session: AsyncSession = Depends(get_postgres_session),
118-
model_registry: ModelRegistry = Depends(get_model_registry),
119-
) -> Response:
120-
await model_registry.delete_provider(provider_id=provider, postgres_session=postgres_session)
123+
provider_id: int = Path(description="The ID of the provider to delete."),
124+
delete_provider_use_case: DeleteProviderUseCase = Depends(delete_provider_use_case_factory),
125+
request_context: ContextVar[RequestContext] = Depends(get_request_context),
126+
) -> Provider:
127+
command = DeleteProviderCommand(
128+
user_id=request_context.get().user_id,
129+
provider_id=provider_id,
130+
)
131+
try:
132+
result = await delete_provider_use_case.execute(command)
133+
except Exception as e:
134+
logger.exception(
135+
"Unexpected error while executing delete_provider use case",
136+
extra={
137+
"user_id": command.user_id,
138+
"provider_id": command.provider_id,
139+
"error_type": type(e).__name__,
140+
},
141+
)
142+
raise InternalServerHTTPException()
121143

122-
return Response(status_code=204)
144+
match result:
145+
case DeleteProviderUseCaseSuccess(deleted_router):
146+
return Provider.model_validate(deleted_router, from_attributes=True)
147+
case ProviderNotFoundError(provider_id=not_found_id):
148+
raise ProviderNotFoundHTTPException(not_found_id)
149+
case UserIsNotAdminError():
150+
raise NotAdminUserHTTPException()
123151

124152

125153
@router.patch(

api/infrastructure/fastapi/endpoints/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def __init__(self, router_id: int) -> None:
8383
super().__init__(status_code=self.status_code, detail=f"Model router {router_id} not found.")
8484

8585

86+
class ProviderNotFoundHTTPException(HTTPException):
87+
status_code = 404
88+
detail = "Model provider {provider_id} not found."
89+
90+
def __init__(self, provider_id: int) -> None:
91+
super().__init__(status_code=self.status_code, detail=f"Model provider {provider_id} not found.")
92+
93+
8694
# 409
8795
class RouterAliasAlreadyExistsHTTPException(HTTPException):
8896
status_code = 409

0 commit comments

Comments
 (0)