Skip to content

Commit 36ccf30

Browse files
committed
refactor(BA-5650-D): switch session repository to owner_id
``SessionRepository`` and the underlying ``SessionDBSource`` now take ``owner_id: UUID`` on every method that previously accepted ``owner_access_key: AccessKey``. Affects: - ``get_session_validated`` - ``match_sessions`` - ``update_session_name`` - ``find_dependency_sessions`` / ``_find_dependent_sessions`` - ``get_target_session_ids`` - ``get_session_with_group`` The matching ``dependency_graph`` helpers and ``creators`` are updated in lockstep. Service-layer callers still pass ``owner_access_key`` temporarily; they will be migrated in the next slice.
1 parent b3aa496 commit 36ccf30

5 files changed

Lines changed: 45 additions & 51 deletions

File tree

changes/BA-5650-D.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Collapse `SessionRepository` / `SessionDBSource` signatures to take `owner_id: UUID` instead of `owner_access_key: AccessKey`. No external behavior change; downstream service callers are updated in a later slice.

src/ai/backend/manager/repositories/session/creators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SessionRowCreatorSpec(CreatorSpec[SessionRow]):
1717
This spec is designed for retrofitting existing code that already builds
1818
SessionRow instances. It simply returns the provided row in build_row().
1919
20-
For scope information needed by RBACEntityCreator, use the row's user_uuid
20+
For scope information needed by RBACEntityCreator, use the row's owner_id
2121
field as the scope_id with ScopeType.USER.
2222
"""
2323

src/ai/backend/manager/repositories/session/db_source/db_source.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData:
6363
async def get_session_validated(
6464
self,
6565
session_name_or_id: str | SessionId,
66-
owner_access_key: AccessKey,
66+
owner_id: uuid.UUID,
6767
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY,
6868
allow_stale: bool = False,
6969
eager_loading_op: Sequence[_AbstractLoad] | None = None,
@@ -73,7 +73,7 @@ async def get_session_validated(
7373
return await SessionRow.get_session(
7474
db_sess,
7575
session_name_or_id,
76-
owner_access_key,
76+
owner_id=owner_id,
7777
kernel_loading_strategy=kernel_loading_strategy,
7878
allow_stale=allow_stale,
7979
eager_loading_op=list(eager_loading_op) if eager_loading_op else None,
@@ -82,13 +82,13 @@ async def get_session_validated(
8282
async def match_sessions(
8383
self,
8484
id_or_name_prefix: str,
85-
owner_access_key: AccessKey,
85+
owner_id: uuid.UUID,
8686
) -> list[SessionRow]:
8787
async with self._db.begin_readonly_session_read_committed() as db_sess:
8888
return await SessionRow.match_sessions(
8989
db_sess,
9090
id_or_name_prefix,
91-
owner_access_key,
91+
owner_id=owner_id,
9292
)
9393

9494
async def get_session_to_determine_status(
@@ -132,15 +132,15 @@ async def update_session_name(
132132
self,
133133
session_name_or_id: str | SessionId,
134134
new_name: str,
135-
owner_access_key: AccessKey,
135+
owner_id: uuid.UUID,
136136
) -> SessionRow:
137137
async def _update(db_session: AsyncSession) -> SessionRow:
138138
# Check if new name already exists for this owner
139139
try:
140140
await SessionRow.get_session(
141141
db_session,
142142
new_name,
143-
owner_access_key,
143+
owner_id=owner_id,
144144
kernel_loading_strategy=KernelLoadingStrategy.NONE,
145145
)
146146
raise SessionAlreadyExists(f"Session with name '{new_name}' already exists")
@@ -151,7 +151,7 @@ async def _update(db_session: AsyncSession) -> SessionRow:
151151
session_row = await SessionRow.get_session(
152152
db_session,
153153
session_name_or_id,
154-
owner_access_key,
154+
owner_id=owner_id,
155155
kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
156156
)
157157

@@ -305,13 +305,13 @@ async def modify_session(
305305
if session_row is None:
306306
raise SessionNotFound(f"Session not found (id:{session_id})")
307307

308-
if session_name and session_row.access_key is not None:
308+
if session_name and session_row.user_uuid is not None:
309309
# Check the owner of the target session has any session with the same name
310310
try:
311311
sess = await SessionRow.get_session(
312312
db_session,
313313
session_name,
314-
AccessKey(session_row.access_key),
314+
owner_id=session_row.user_uuid,
315315
)
316316
except SessionNotFound:
317317
pass
@@ -371,7 +371,7 @@ async def _find_dependent_sessions(
371371
self,
372372
db_sess: AsyncSession,
373373
root_session_name_or_id: str | uuid.UUID,
374-
access_key: AccessKey,
374+
owner_id: uuid.UUID,
375375
allow_stale: bool = False,
376376
) -> tuple[uuid.UUID, set[uuid.UUID]]:
377377
"""
@@ -401,7 +401,7 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]:
401401
root_session = await SessionRow.get_session(
402402
db_sess,
403403
root_session_name_or_id,
404-
access_key=access_key,
404+
owner_id=owner_id,
405405
allow_stale=allow_stale,
406406
)
407407
root_session_id = cast(uuid.UUID, root_session.id)
@@ -412,14 +412,14 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]:
412412
async def get_target_session_ids(
413413
self,
414414
session_name_or_id: str | uuid.UUID,
415-
access_key: AccessKey,
415+
owner_id: uuid.UUID,
416416
recursive: bool = False,
417417
) -> list[SessionId]:
418418
"""
419419
Get list of session IDs including dependent sessions if recursive.
420420
421421
:param session_name_or_id: Name or ID of the primary session
422-
:param access_key: Access key of the session owner
422+
:param owner_id: User UUID of the session owner
423423
:param recursive: If True, include dependent sessions
424424
:return: List of session IDs
425425
"""
@@ -430,7 +430,7 @@ async def get_target_session_ids(
430430
root_id, dependent_ids = await self._find_dependent_sessions(
431431
db_sess,
432432
session_name_or_id,
433-
access_key,
433+
owner_id,
434434
allow_stale=True,
435435
)
436436
# Return dependent sessions first, then root session
@@ -441,7 +441,7 @@ async def get_target_session_ids(
441441
session = await SessionRow.get_session(
442442
db_sess,
443443
session_name_or_id,
444-
access_key,
444+
owner_id=owner_id,
445445
kernel_loading_strategy=KernelLoadingStrategy.NONE,
446446
allow_stale=True,
447447
)
@@ -454,19 +454,19 @@ async def get_target_session_ids(
454454
async def find_dependency_sessions(
455455
self,
456456
session_name_or_id: uuid.UUID | str,
457-
access_key: AccessKey,
457+
owner_id: uuid.UUID,
458458
) -> dict[str, list[Any] | str]:
459459
async with self._db.begin_readonly_session_read_committed() as db_sess:
460460
return await find_dependency_sessions(
461461
session_name_or_id,
462462
db_sess,
463-
access_key,
463+
owner_id,
464464
)
465465

466466
async def get_session_with_group(
467467
self,
468468
session_name_or_id: str | SessionId,
469-
owner_access_key: AccessKey,
469+
owner_id: uuid.UUID,
470470
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY,
471471
allow_stale: bool = False,
472472
) -> SessionRow:
@@ -475,7 +475,7 @@ async def get_session_with_group(
475475
return await SessionRow.get_session(
476476
db_sess,
477477
session_name_or_id,
478-
owner_access_key,
478+
owner_id=owner_id,
479479
kernel_loading_strategy=kernel_loading_strategy,
480480
allow_stale=allow_stale,
481481
eager_loading_op=[selectinload(SessionRow.group)],
@@ -484,14 +484,14 @@ async def get_session_with_group(
484484
async def get_session_with_routing_minimal(
485485
self,
486486
session_name_or_id: str | SessionId,
487-
owner_access_key: AccessKey,
487+
owner_id: uuid.UUID,
488488
) -> SessionRow:
489489
"""Get session with minimal routing information"""
490490
async with self._db.begin_readonly_session_read_committed() as db_sess:
491491
return await SessionRow.get_session(
492492
db_sess,
493493
session_name_or_id,
494-
owner_access_key,
494+
owner_id=owner_id,
495495
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
496496
eager_loading_op=[
497497
selectinload(SessionRow.routing).options(noload("*")),

src/ai/backend/manager/repositories/session/dependency_graph.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import sqlalchemy as sa
1414
from sqlalchemy.ext.asyncio import AsyncSession as SASession
1515

16-
from ai.backend.common.types import AccessKey
1716
from ai.backend.manager.errors.kernel import InvalidSessionData, SessionNotFound
1817
from ai.backend.manager.models.kernel import kernels
1918
from ai.backend.manager.models.session import SessionDependencyRow, SessionRow
@@ -23,12 +22,12 @@
2322
async def _find_dependency_sessions(
2423
session_name_or_id: UUID | str,
2524
db_session: SASession,
26-
access_key: AccessKey,
25+
owner_id: UUID,
2726
) -> dict[str, list[Any] | str]:
2827
sessions = await SessionRow.match_sessions(
2928
db_session,
3029
session_name_or_id,
31-
access_key=access_key,
30+
owner_id=owner_id,
3231
)
3332

3433
if len(sessions) < 1:
@@ -66,7 +65,7 @@ async def _find_dependency_sessions(
6665
"status": str(kernel_query_result[0]),
6766
"status_changed": str(kernel_query_result[1]),
6867
"depends_on": [
69-
await _find_dependency_sessions(dependency_session_id, db_session, access_key)
68+
await _find_dependency_sessions(dependency_session_id, db_session, owner_id)
7069
for dependency_session_id in dependency_session_ids
7170
],
7271
}
@@ -77,15 +76,15 @@ async def _find_dependency_sessions(
7776
async def find_dependency_sessions(
7877
session_name_or_id: UUID | str,
7978
db_session: SASession,
80-
access_key: AccessKey,
79+
owner_id: UUID,
8180
) -> dict[str, list[Any] | str]:
82-
return await _find_dependency_sessions(session_name_or_id, db_session, access_key)
81+
return await _find_dependency_sessions(session_name_or_id, db_session, owner_id)
8382

8483

8584
async def find_dependent_sessions(
8685
root_session_name_or_id: str | UUID,
8786
db_session: SASession,
88-
access_key: AccessKey,
87+
owner_id: UUID,
8988
*,
9089
allow_stale: bool = False,
9190
) -> set[UUID]:
@@ -108,7 +107,7 @@ async def _find_dependent_sessions(session_id: UUID) -> set[UUID]:
108107
root_session = await SessionRow.get_session(
109108
db_session,
110109
root_session_name_or_id,
111-
access_key=access_key,
110+
owner_id=owner_id,
112111
allow_stale=allow_stale,
113112
)
114113
return await _find_dependent_sessions(cast(UUID, root_session.id))

src/ai/backend/manager/repositories/session/repository.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData:
5757
async def get_session_validated(
5858
self,
5959
session_name_or_id: str | SessionId,
60-
owner_access_key: AccessKey,
60+
owner_id: uuid.UUID,
6161
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY,
6262
allow_stale: bool = False,
6363
eager_loading_op: Sequence[_AbstractLoad] | None = None,
6464
) -> SessionRow:
6565
return await self._db_source.get_session_validated(
6666
session_name_or_id,
67-
owner_access_key,
67+
owner_id,
6868
kernel_loading_strategy,
6969
allow_stale,
7070
eager_loading_op,
@@ -74,9 +74,9 @@ async def get_session_validated(
7474
async def match_sessions(
7575
self,
7676
id_or_name_prefix: str,
77-
owner_access_key: AccessKey,
77+
owner_id: uuid.UUID,
7878
) -> list[SessionRow]:
79-
return await self._db_source.match_sessions(id_or_name_prefix, owner_access_key)
79+
return await self._db_source.match_sessions(id_or_name_prefix, owner_id)
8080

8181
@session_repository_resilience.apply()
8282
async def get_session_to_determine_status(
@@ -104,11 +104,9 @@ async def update_session_name(
104104
self,
105105
session_name_or_id: str | SessionId,
106106
new_name: str,
107-
owner_access_key: AccessKey,
107+
owner_id: uuid.UUID,
108108
) -> SessionRow:
109-
return await self._db_source.update_session_name(
110-
session_name_or_id, new_name, owner_access_key
111-
)
109+
return await self._db_source.update_session_name(session_name_or_id, new_name, owner_id)
112110

113111
@session_repository_resilience.apply()
114112
async def get_container_registry(
@@ -210,52 +208,48 @@ async def query_userinfo(
210208
async def get_target_session_ids(
211209
self,
212210
session_name_or_id: str | uuid.UUID,
213-
access_key: AccessKey,
211+
owner_id: uuid.UUID,
214212
recursive: bool = False,
215213
) -> list[SessionId]:
216214
"""
217215
Get list of session IDs including dependent sessions if recursive.
218216
219217
:param session_name_or_id: Name or ID of the primary session
220-
:param access_key: Access key of the session owner
218+
:param owner_id: User UUID of the session owner
221219
:param recursive: If True, include dependent sessions
222220
:return: List of session IDs
223221
"""
224-
return await self._db_source.get_target_session_ids(
225-
session_name_or_id, access_key, recursive
226-
)
222+
return await self._db_source.get_target_session_ids(session_name_or_id, owner_id, recursive)
227223

228224
@session_repository_resilience.apply()
229225
async def find_dependency_sessions(
230226
self,
231227
session_name_or_id: uuid.UUID | str,
232-
access_key: AccessKey,
228+
owner_id: uuid.UUID,
233229
) -> dict[str, list[Any] | str]:
234-
return await self._db_source.find_dependency_sessions(session_name_or_id, access_key)
230+
return await self._db_source.find_dependency_sessions(session_name_or_id, owner_id)
235231

236232
@session_repository_resilience.apply()
237233
async def get_session_with_group(
238234
self,
239235
session_name_or_id: str | SessionId,
240-
owner_access_key: AccessKey,
236+
owner_id: uuid.UUID,
241237
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY,
242238
allow_stale: bool = False,
243239
) -> SessionRow:
244240
"""Get session with group information eagerly loaded"""
245241
return await self._db_source.get_session_with_group(
246-
session_name_or_id, owner_access_key, kernel_loading_strategy, allow_stale
242+
session_name_or_id, owner_id, kernel_loading_strategy, allow_stale
247243
)
248244

249245
@session_repository_resilience.apply()
250246
async def get_session_with_routing_minimal(
251247
self,
252248
session_name_or_id: str | SessionId,
253-
owner_access_key: AccessKey,
249+
owner_id: uuid.UUID,
254250
) -> SessionRow:
255251
"""Get session with minimal routing information"""
256-
return await self._db_source.get_session_with_routing_minimal(
257-
session_name_or_id, owner_access_key
258-
)
252+
return await self._db_source.get_session_with_routing_minimal(session_name_or_id, owner_id)
259253

260254
@session_repository_resilience.apply()
261255
async def search(

0 commit comments

Comments
 (0)