Skip to content

Commit 6d9c8a9

Browse files
jopemachineclaude
andcommitted
fix(BA-5650-G): align scheduler db_source and tests with renamed fields
Update remaining call sites in scheduler db_source and test fixtures to use the renamed dataclass fields: - ``main_access_key`` (previously ``access_key``) on ``ScheduledSessionData``, ``SessionDataForPull``, ``SessionDataForStart``, ``SessionWorkload``, ``SweptSessionInfo``, ``TerminatingSessionData``, ``PendingSessionData``. - ``owner_id`` (previously ``user_uuid``) on ``SessionEnqueueData``, ``KernelEnqueueData``, ``SessionDataForStart``, ``SessionWorkload``. - ``pending_sessions.owner_ids`` (previously ``user_uuids``) on ``PendingSessions``. Joined ``users`` for SweptSessionInfo and ScheduledSessionData queries so ``main_access_key`` is sourced from the user record rather than the session column being dropped in slice K. Drop unused ``target_main_access_key`` argument from ``EndpointRow.delegate_endpoint_ownership`` call (already unused by the ORM helper); keep the parameter on the repository facade for caller compatibility. Inject ``user_repository=MagicMock()`` into ``AgentRegistry`` test construction in ``test_reconcile_agent_resources``. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c24db41 commit 6d9c8a9

5 files changed

Lines changed: 85 additions & 64 deletions

File tree

src/ai/backend/manager/repositories/scheduler/db_source/db_source.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ async def _fetch_user_policies(
700700
"""Fetch user resource policies for users in pending sessions."""
701701
user_policies: dict[UUID, UserResourcePolicy] = {}
702702

703-
if not pending_sessions.user_uuids:
703+
if not pending_sessions.owner_ids:
704704
return user_policies
705705

706706
user_policy_result = await db_sess.execute(
@@ -716,7 +716,7 @@ async def _fetch_user_policies(
716716
KeyPairResourcePolicyRow,
717717
KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name,
718718
)
719-
.where(UserRow.uuid.in_(pending_sessions.user_uuids))
719+
.where(UserRow.uuid.in_(pending_sessions.owner_ids))
720720
)
721721

722722
for row in user_policy_result:
@@ -1146,11 +1146,14 @@ async def get_terminating_sessions_by_ids(
11461146
for kernel in session_row.kernels
11471147
]
11481148

1149+
owner_main_ak = (
1150+
session_row.user.main_access_key if session_row.user else None
1151+
)
11491152
terminating_sessions.append(
11501153
TerminatingSessionData(
11511154
session_id=session_row.id,
1152-
access_key=AccessKey(session_row.access_key)
1153-
if session_row.access_key
1155+
main_access_key=AccessKey(owner_main_ak)
1156+
if owner_main_ak
11541157
else AccessKey(""),
11551158
creation_id=session_row.creation_id or "",
11561159
status=session_row.status,
@@ -1183,11 +1186,12 @@ async def get_pending_timeout_sessions_by_ids(
11831186
sa.select(
11841187
SessionRow.id,
11851188
SessionRow.creation_id,
1186-
SessionRow.access_key,
1189+
UserRow.main_access_key,
11871190
SessionRow.created_at,
11881191
ScalingGroupRow.scheduler_opts,
11891192
)
11901193
.select_from(SessionRow)
1194+
.join(UserRow, SessionRow.user_uuid == UserRow.uuid)
11911195
.join(ScalingGroupRow, SessionRow.scaling_group_name == ScalingGroupRow.name)
11921196
.where(
11931197
SessionRow.id.in_(session_ids),
@@ -1213,7 +1217,7 @@ async def get_pending_timeout_sessions_by_ids(
12131217
SweptSessionInfo(
12141218
session_id=row.id,
12151219
creation_id=row.creation_id,
1216-
access_key=row.access_key,
1220+
main_access_key=row.main_access_key,
12171221
)
12181222
)
12191223

@@ -1302,8 +1306,8 @@ async def enqueue_session(
13021306
id=session_data.id,
13031307
creation_id=session_data.creation_id,
13041308
name=session_data.name,
1305-
access_key=session_data.access_key,
1306-
user_uuid=session_data.user_uuid,
1309+
access_key=session_data.main_access_key,
1310+
user_uuid=session_data.owner_id,
13071311
group_id=session_data.group_id,
13081312
domain_name=session_data.domain_name,
13091313
scaling_group_name=session_data.scaling_group_name,
@@ -1349,8 +1353,8 @@ async def enqueue_session(
13491353
scaling_group=kernel.scaling_group,
13501354
domain_name=kernel.domain_name,
13511355
group_id=kernel.group_id,
1352-
user_uuid=kernel.user_uuid,
1353-
access_key=kernel.access_key,
1356+
user_uuid=kernel.owner_id,
1357+
access_key=kernel.main_access_key,
13541358
image=kernel.image,
13551359
architecture=kernel.architecture,
13561360
registry=kernel.registry,
@@ -1387,7 +1391,7 @@ async def enqueue_session(
13871391
element_type=RBACElementType.SESSION,
13881392
scope_ref=RBACElementRef(
13891393
element_type=RBACElementType.USER,
1390-
element_id=str(session_data.user_uuid),
1394+
element_id=str(session_data.owner_id),
13911395
),
13921396
additional_scope_refs=[
13931397
RBACElementRef(
@@ -1843,21 +1847,28 @@ async def allocate_sessions(
18431847
# First, fetch session data to get creation_id and access_key
18441848
session_ids = {alloc.session_id for alloc in allocation_batch.allocations}
18451849
if session_ids:
1846-
query = sa.select(
1847-
SessionRow.id, SessionRow.creation_id, SessionRow.access_key
1848-
).where(SessionRow.id.in_(session_ids))
1850+
query = (
1851+
sa.select(
1852+
SessionRow.id, SessionRow.creation_id, UserRow.main_access_key
1853+
)
1854+
.select_from(SessionRow)
1855+
.join(UserRow, SessionRow.user_uuid == UserRow.uuid)
1856+
.where(SessionRow.id.in_(session_ids))
1857+
)
18491858
result = await db_sess.execute(query)
1850-
session_data_map = {row.id: (row.creation_id, row.access_key) for row in result}
1859+
session_data_map = {
1860+
row.id: (row.creation_id, row.main_access_key) for row in result
1861+
}
18511862

18521863
# Create SessionEventData for each allocated session
18531864
for allocation in allocation_batch.allocations:
18541865
if session_data := session_data_map.get(allocation.session_id):
1855-
creation_id, access_key = session_data
1866+
creation_id, main_access_key = session_data
18561867
scheduled_sessions.append(
18571868
ScheduledSessionData(
18581869
session_id=allocation.session_id,
18591870
creation_id=creation_id,
1860-
access_key=access_key,
1871+
main_access_key=main_access_key,
18611872
reason="triggered-by-scheduler",
18621873
)
18631874
)
@@ -2917,7 +2928,9 @@ async def _get_sessions_by_statuses(
29172928
scheduled_session = ScheduledSessionData(
29182929
session_id=session.id,
29192930
creation_id=session.creation_id or "",
2920-
access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""),
2931+
main_access_key=AccessKey(session.access_key)
2932+
if session.access_key
2933+
else AccessKey(""),
29212934
reason="triggered-by-scheduler",
29222935
)
29232936
scheduled_sessions.append(scheduled_session)
@@ -2962,7 +2975,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes
29622975
ScheduledSessionData(
29632976
session_id=session.id,
29642977
creation_id=session.creation_id or "",
2965-
access_key=AccessKey(session.access_key)
2978+
main_access_key=AccessKey(session.access_key)
29662979
if session.access_key
29672980
else AccessKey(""),
29682981
reason="triggered-by-scheduler",
@@ -3102,7 +3115,7 @@ async def _get_sessions_for_pull(
31023115
sessions_map[session_id] = SessionDataForPull(
31033116
session_id=session_id,
31043117
creation_id=row.creation_id,
3105-
access_key=row.access_key,
3118+
main_access_key=row.access_key,
31063119
kernels=[],
31073120
)
31083121

@@ -3294,13 +3307,13 @@ async def _get_sessions_for_start(
32943307
SessionDataForStart(
32953308
session_id=session_info["id"],
32963309
creation_id=session_info["creation_id"],
3297-
access_key=session_info["access_key"],
3310+
main_access_key=session_info["access_key"],
32983311
session_type=session_info["session_type"],
32993312
name=session_info["name"],
33003313
cluster_mode=session_info["cluster_mode"],
33013314
kernels=kernel_bindings,
33023315
environ=session_info.get("environ", {}),
3303-
user_uuid=session_info["user_uuid"],
3316+
owner_id=session_info["user_uuid"],
33043317
user_email=user_info.email,
33053318
user_name=user_info.username,
33063319
)
@@ -4074,7 +4087,7 @@ async def _fetch_sessions_for_pull_by_ids(
40744087
sessions_map[session_id] = SessionDataForPull(
40754088
session_id=session_id,
40764089
creation_id=row.creation_id,
4077-
access_key=row.access_key,
4090+
main_access_key=row.access_key,
40784091
kernels=[],
40794092
)
40804093

@@ -4293,13 +4306,13 @@ async def _fetch_sessions_for_start_by_ids(
42934306
SessionDataForStart(
42944307
session_id=session_info["id"],
42954308
creation_id=session_info["creation_id"],
4296-
access_key=session_info["access_key"],
4309+
main_access_key=session_info["access_key"],
42974310
session_type=session_info["session_type"],
42984311
name=session_info["name"],
42994312
cluster_mode=session_info["cluster_mode"],
43004313
kernels=kernel_bindings,
43014314
environ=session_info.get("environ", {}),
4302-
user_uuid=session_info["user_uuid"],
4315+
owner_id=session_info["user_uuid"],
43034316
user_email=user_info.email,
43044317
user_name=user_info.username,
43054318
)
@@ -4369,7 +4382,7 @@ async def search_sessions_with_kernels(
43694382
sessions_map[row.id] = SessionDataForPull(
43704383
session_id=row.id,
43714384
creation_id=row.creation_id,
4372-
access_key=row.access_key,
4385+
main_access_key=row.access_key,
43734386
kernels=[],
43744387
)
43754388

@@ -4625,13 +4638,13 @@ async def search_sessions_with_kernels_and_user(
46254638
SessionDataForStart(
46264639
session_id=session_info["id"],
46274640
creation_id=session_info["creation_id"],
4628-
access_key=session_info["access_key"],
4641+
main_access_key=session_info["access_key"],
46294642
session_type=session_info["session_type"],
46304643
name=session_info["name"],
46314644
cluster_mode=session_info["cluster_mode"],
46324645
kernels=session_info["kernels"],
46334646
environ=session_info.get("environ") or {},
4634-
user_uuid=session_info["user_uuid"],
4647+
owner_id=session_info["user_uuid"],
46354648
user_email=user_info.email,
46364649
user_name=user_info.username,
46374650
)

src/ai/backend/manager/repositories/user/db_source/db_source.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,16 @@ async def delegate_endpoint_ownership(
674674
target_user_uuid: UUID,
675675
target_main_access_key: AccessKey,
676676
) -> None:
677-
"""Delegate endpoint ownership to another user."""
677+
"""Delegate endpoint ownership to another user.
678+
679+
``target_main_access_key`` is accepted for caller compatibility but is
680+
no longer needed by ``EndpointRow.delegate_endpoint_ownership`` — the
681+
downstream ``delegate_ownership`` calls only need ``target_user_uuid``.
682+
"""
683+
del target_main_access_key # unused, kept for caller signature compatibility
678684
async with self._db.begin_session() as session:
679685
await EndpointRow.delegate_endpoint_ownership(
680-
session, user_uuid, target_user_uuid, target_main_access_key
686+
session, user_uuid, target_user_uuid
681687
)
682688

683689
async def delete_endpoints(

src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def prepare(
125125
id=session_id,
126126
creation_id=spec.session_creation_id,
127127
name=spec.session_name,
128-
main_access_key=spec.main_access_key,
128+
main_access_key=spec.access_key,
129129
owner_id=spec.user_scope.user_uuid,
130130
group_id=spec.user_scope.group_id,
131131
domain_name=spec.user_scope.domain_name,
@@ -255,7 +255,7 @@ async def _prepare_kernels(
255255
domain_name=spec.user_scope.domain_name,
256256
group_id=spec.user_scope.group_id,
257257
owner_id=spec.user_scope.user_uuid,
258-
main_access_key=spec.main_access_key,
258+
main_access_key=spec.access_key,
259259
image=image_info.canonical if image_info else self.DEFAULT_IMAGE_NAME,
260260
architecture=image_info.architecture if image_info else self.DEFAULT_ARCHITECTURE,
261261
registry=image_info.registry if image_info else self.DEFAULT_REGISTRY,

0 commit comments

Comments
 (0)