Skip to content

Commit 593cf2e

Browse files
author
Benjamin PILIA
committed
Refacto router_repository et fixes
1 parent 02f620c commit 593cf2e

8 files changed

Lines changed: 87 additions & 30 deletions

File tree

api/dependencies.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
1010
from api.schemas.core.context import RequestContext
1111
from api.use_cases.admin.providers import CreateProviderUseCase
12-
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase
13-
from api.use_cases.admin.routers._updaterouterusecase import UpdateRouterUseCase
12+
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
1413
from api.use_cases.models import GetModelsUseCase
1514
from api.utils.configuration import configuration
1615
from api.utils.context import global_context, request_context

api/domain/router/entities.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ def with_type(self, router_type: RouterType) -> "Router":
5252
def with_load_balancing_strategy(self, strategy: RouterLoadBalancingStrategy) -> "Router":
5353
return self.model_copy(update={"load_balancing_strategy": strategy})
5454

55-
def with_costs(self, prompt_tokens: float, completion_tokens: float) -> "Router":
56-
return self.model_copy(update={"cost_prompt_tokens": prompt_tokens, "cost_completion_tokens": completion_tokens})
55+
def with_cost_prompt_tokens(self, prompt_tokens: float) -> "Router":
56+
return self.model_copy(update={"cost_prompt_tokens": prompt_tokens})
57+
58+
def with_cost_completion_tokens(self, completion_tokens: float) -> "Router":
59+
return self.model_copy(update={"cost_completion_tokens": completion_tokens})
5760

5861
def with_aliases(self, aliases: list[str]) -> "Router":
5962
return self.model_copy(update={"aliases": aliases})

api/infrastructure/fastapi/endpoints/admin/routers.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@
3939
GetRoutersCommand,
4040
GetRoutersUseCase,
4141
GetRoutersUseCaseSuccess,
42+
UpdateRouterCommand,
43+
UpdateRouterUseCase,
44+
UpdateRouterUseCaseSuccess,
4245
)
43-
from api.use_cases.admin.routers._updaterouterusecase import UpdateRouterCommand, UpdateRouterUseCase, UpdateRouterUseCaseSuccess
4446
from api.utils.variables import EndpointRoute
4547

4648
logger = logging.getLogger(__name__)
@@ -216,7 +218,9 @@ async def delete_router(
216218
@router.patch(
217219
path=EndpointRoute.ADMIN_ROUTERS + "/{router_id}",
218220
dependencies=[Security(dependency=get_current_key)],
219-
responses=get_documentation_responses([RouterNotFoundHTTPException, NotAdminUserHTTPException]),
221+
responses=get_documentation_responses(
222+
[RouterNotFoundHTTPException, NotAdminUserHTTPException, RouterAliasAlreadyExistsHTTPException, RouterAlreadyExistsHTTPException]
223+
),
220224
status_code=200,
221225
)
222226
async def update_router(
@@ -248,13 +252,13 @@ async def update_router(
248252
)
249253
raise InternalServerHTTPException()
250254
match result:
251-
case UpdateRouterUseCaseSuccess(updated_router):
255+
case UpdateRouterUseCaseSuccess(router=updated_router):
252256
return Router.model_validate(updated_router, from_attributes=True)
253257
case RouterNotFoundError(router_id=not_found_id):
254258
raise RouterNotFoundHTTPException(not_found_id)
255259
case UserIsNotAdminError():
256260
raise NotAdminUserHTTPException()
257-
case RouterAliasAlreadyExistsError(name):
258-
raise RouterAliasAlreadyExistsHTTPException(name)
261+
case RouterAliasAlreadyExistsError(aliases):
262+
raise RouterAliasAlreadyExistsHTTPException(aliases)
259263
case RouterNameAlreadyExistsError(name):
260264
raise RouterAlreadyExistsHTTPException(name)

api/infrastructure/postgres/_postgresrouterrepository.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,33 +204,60 @@ async def delete_router(self, router_id: int) -> Router | None:
204204
await self.postgres_session.execute(delete(RouterTable).where(RouterTable.id == router_id))
205205
return router
206206

207-
async def update_router(self, router: Router) -> Router | RouterNameAlreadyExistsError:
208-
db_user_id = None if router.user_id == MASTER_USER_ID else router.user_id
207+
async def update_router(self, router_to_update: Router) -> Router | RouterNameAlreadyExistsError:
208+
db_user_id = None if router_to_update.user_id == MASTER_USER_ID else router_to_update.user_id
209209

210210
try:
211-
await self.postgres_session.execute(
211+
update_query = (
212212
update(RouterTable)
213-
.where(RouterTable.id == router.id)
213+
.where(RouterTable.id == router_to_update.id)
214214
.values(
215215
user_id=db_user_id,
216-
name=router.name,
217-
type=router.type.value,
218-
load_balancing_strategy=router.load_balancing_strategy.value,
219-
cost_prompt_tokens=router.cost_prompt_tokens,
220-
cost_completion_tokens=router.cost_completion_tokens,
216+
name=router_to_update.name,
217+
type=router_to_update.type.value,
218+
load_balancing_strategy=router_to_update.load_balancing_strategy.value,
219+
cost_prompt_tokens=router_to_update.cost_prompt_tokens,
220+
cost_completion_tokens=router_to_update.cost_completion_tokens,
221+
)
222+
.returning(
223+
RouterTable.id,
224+
RouterTable.name,
225+
RouterTable.user_id,
226+
RouterTable.type,
227+
RouterTable.load_balancing_strategy,
228+
RouterTable.cost_prompt_tokens,
229+
RouterTable.cost_completion_tokens,
230+
cast(func.extract("epoch", RouterTable.created), Integer).label("created"),
231+
cast(func.extract("epoch", RouterTable.updated), Integer).label("updated"),
221232
)
222233
)
234+
result = await self.postgres_session.execute(update_query)
235+
row = result.one()
223236

224-
if router.aliases is not None:
225-
await self.postgres_session.execute(delete(RouterAliasTable).where(RouterAliasTable.router_id == router.id))
226-
if router.aliases:
237+
if router_to_update.aliases is not None:
238+
await self.postgres_session.execute(delete(RouterAliasTable).where(RouterAliasTable.router_id == router_to_update.id))
239+
if router_to_update.aliases:
227240
await self.postgres_session.execute(
228241
insert(RouterAliasTable),
229-
[{"value": alias, "router_id": router.id} for alias in router.aliases],
242+
[{"value": alias, "router_id": router_to_update.id} for alias in router_to_update.aliases],
230243
)
231244
except IntegrityError as e:
232245
if "router_name_key" in str(e.orig):
233-
return RouterNameAlreadyExistsError(name=router.name)
246+
return RouterNameAlreadyExistsError(name=router_to_update.name)
234247
raise
235248

236-
return router
249+
return Router(
250+
id=row.id,
251+
name=row.name,
252+
user_id=router_to_update.user_id,
253+
type=RouterType(row.type),
254+
aliases=router_to_update.aliases,
255+
load_balancing_strategy=RouterLoadBalancingStrategy(row.load_balancing_strategy),
256+
vector_size=router_to_update.vector_size,
257+
max_context_length=router_to_update.max_context_length,
258+
cost_prompt_tokens=row.cost_prompt_tokens or 0.0,
259+
cost_completion_tokens=row.cost_completion_tokens or 0.0,
260+
providers=router_to_update.providers,
261+
created=row.created,
262+
updated=row.updated,
263+
)

api/tests/integration/postgres/test_postgresrouterrepository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ async def test_update_router_should_return_router_with_updated_costs(self, repos
551551
await db_session.flush()
552552

553553
# Act
554-
result = await repository.update_router(to_router_domain(router).with_costs(0.010, 0.020))
554+
result = await repository.update_router(to_router_domain(router).with_cost_prompt_tokens(0.010).with_cost_completion_tokens(0.020))
555555

556556
# Assert
557557
assert isinstance(result, Router)

api/tests/unit/use_case/admin/routers/test_updaterouterusecase.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from api.domain.model import ModelType as RouterType
6-
from api.domain.router.entities import Router, RouterLoadBalancingStrategy
6+
from api.domain.router.entities import RouterLoadBalancingStrategy
77
from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError, RouterNotFoundError
88
from api.domain.userinfo.errors import UserIsNotAdminError
99
from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory
@@ -36,7 +36,7 @@ def unauthorized_user_info():
3636

3737

3838
@pytest.fixture
39-
def sample_router() -> Router:
39+
def sample_router():
4040
return RouterFactory(id=42, user_id=1, name="original-name", aliases=[])
4141

4242

@@ -50,7 +50,8 @@ async def test_should_return_updated_router_when_user_is_admin_and_router_exists
5050
sample_router.with_name("new-name")
5151
.with_type(RouterType.TEXT_EMBEDDINGS_INFERENCE)
5252
.with_load_balancing_strategy(RouterLoadBalancingStrategy.LEAST_BUSY)
53-
.with_costs(0.005, 0.010)
53+
.with_cost_prompt_tokens(0.005)
54+
.with_cost_completion_tokens(0.010)
5455
.with_aliases(["alias-a", "alias-b"])
5556
)
5657
use_case.user_info_repository.get_user_info.return_value = admin_user_info
@@ -169,6 +170,23 @@ async def test_should_propagate_router_name_already_exists_error_from_repository
169170
assert isinstance(result, RouterNameAlreadyExistsError)
170171
assert result.name == "taken-name"
171172

173+
@pytest.mark.asyncio
174+
async def test_should_add_aliases_when_router_has_no_aliases_and_command_updates_aliases(self, use_case, router_repository, admin_user_info):
175+
# Arrange
176+
router = RouterFactory(id=42, user_id=1, aliases=None)
177+
updated_router = router.with_aliases(["new-alias"])
178+
use_case.user_info_repository.get_user_info.return_value = admin_user_info
179+
use_case.router_repository.get_router_by_id.return_value = router
180+
use_case.router_repository.get_aliases.return_value = []
181+
use_case.router_repository.update_router.return_value = updated_router
182+
183+
# Act
184+
result = await use_case.execute(command=UpdateRouterCommand(user_id=admin_user_info.id, router_id=42, aliases=["new-alias"]))
185+
186+
# Assert
187+
assert isinstance(result, UpdateRouterUseCaseSuccess)
188+
router_repository.update_router.assert_called_once_with(router=router.with_aliases(["new-alias"]))
189+
172190
@pytest.mark.asyncio
173191
async def test_should_not_call_update_router_when_router_should_not_be_updated(self, use_case, router_repository, admin_user_info, sample_router):
174192
# Arrange

api/use_cases/admin/routers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._deleterouterusecase import DeleteRouterCommand, DeleteRouterUseCase, DeleteRouterUseCaseSuccess
33
from ._getonerouterusecase import GetOneRouterCommand, GetOneRouterUseCase, GetOneRouterUseCaseSuccess
44
from ._getroutersusecase import GetRoutersCommand, GetRoutersUseCase, GetRoutersUseCaseSuccess
5+
from ._updaterouterusecase import UpdateRouterCommand, UpdateRouterUseCase, UpdateRouterUseCaseSuccess
56

67
__all__ = [
78
"CreateRouterCommand",
@@ -16,4 +17,7 @@
1617
"DeleteRouterCommand",
1718
"DeleteRouterUseCase",
1819
"DeleteRouterUseCaseSuccess",
20+
"UpdateRouterCommand",
21+
"UpdateRouterUseCase",
22+
"UpdateRouterUseCaseSuccess",
1923
]

api/use_cases/admin/routers/_updaterouterusecase.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ async def execute(
5050

5151
if command.aliases:
5252
existing_aliases = await self.router_repository.get_aliases()
53-
conflicting_aliases = set(command.aliases) & (set(existing_aliases) - set(router.aliases))
53+
conflicting_aliases = set(command.aliases) & (set(existing_aliases) - set(router.aliases or []))
5454
if conflicting_aliases:
5555
return RouterAliasAlreadyExistsError(aliases=list(conflicting_aliases))
5656

@@ -62,7 +62,9 @@ async def execute(
6262
if command.load_balancing_strategy is not None:
6363
updated_router = updated_router.with_load_balancing_strategy(command.load_balancing_strategy)
6464
if command.cost_prompt_tokens is not None:
65-
updated_router = updated_router.with_costs(command.cost_prompt_tokens, command.cost_completion_tokens)
65+
updated_router = updated_router.with_cost_prompt_tokens(command.cost_prompt_tokens)
66+
if command.cost_completion_tokens is not None:
67+
updated_router = updated_router.with_cost_completion_tokens(command.cost_completion_tokens)
6668
if command.aliases is not None:
6769
updated_router = updated_router.with_aliases(command.aliases)
6870

0 commit comments

Comments
 (0)