Skip to content

Commit 7ab7788

Browse files
committed
refactor(BA-5713): address slice E review feedback
- repositories/scheduler/types/session.py: rename PendingSessionData.access_key -> main_access_key (and drop the outdated resolved-main_access_key comment).
- repositories/scheduler/db_source/db_source.py: update the PendingSessionData call site to use the main_access_key and owner_id keyword names, matching the dataclass.
- scheduler/drf.py: use existing_sess.user_uuid (SessionRow stores the owner UUID there, not in owner_id).
- scheduler/predicates.py: guard every _resolve_main_access_key consumer (check_concurrency, check_keypair_resource_limit, check_pending_session_count_limit, check_pending_session_resource_limit) with an early return when main_ak is None, so that users with a NULL main_access_key don't fall through to keypair policy lookups that would match against NULL.
1 parent 81cdac3 commit 7ab7788

4 files changed

Lines changed: 16 additions & 9 deletions

File tree

src/ai/backend/manager/repositories/scheduler/db_source/db_source.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ async def _fetch_pending_sessions(
351351
if session_id not in sessions_map:
352352
sessions_map[session_id] = PendingSessionData(
353353
id=session_id,
354-
access_key=row.access_key,
354+
main_access_key=row.access_key,
355355
requested_slots=row.requested_slots,
356-
user_uuid=row.user_uuid,
356+
owner_id=row.user_uuid,
357357
group_id=row.group_id,
358358
domain_name=row.domain_name,
359359
scaling_group_name=row.scaling_group_name,

src/ai/backend/manager/repositories/scheduler/types/session.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ class PendingSessionData:
4444
"""Pending session data for scheduling."""
4545

4646
id: SessionId
47-
# Resolved main_access_key of the owner; required for keypair-scoped concurrency tracking and resource policy lookups.
48-
access_key: AccessKey
47+
main_access_key: AccessKey
4948
requested_slots: ResourceSlot
5049
owner_id: UUID
5150
group_id: UUID
@@ -65,7 +64,7 @@ def to_session_workload(self) -> SessionWorkload:
6564
kernel_workloads = [k.to_kernel_workload() for k in self.kernels]
6665
return SessionWorkload(
6766
session_id=self.id,
68-
main_access_key=self.access_key,
67+
main_access_key=self.main_access_key,
6968
requested_slots=self.requested_slots,
7069
owner_id=self.owner_id,
7170
group_id=self.group_id,

src/ai/backend/manager/scheduler/drf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def pick_session(
6060
slot_share = Decimal(value) / slot_cap
6161
if dominant_share < slot_share:
6262
dominant_share = slot_share
63-
owner_id = existing_sess.owner_id
63+
owner_id = existing_sess.user_uuid
6464
if owner_id is not None:
6565
if self.per_user_dominant_share[owner_id] < dominant_share:
6666
self.per_user_dominant_share[owner_id] = dominant_share

src/ai/backend/manager/scheduler/predicates.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ async def check_concurrency(
6060
sess_ctx: SessionRow,
6161
) -> PredicateResult:
6262
main_ak = await _resolve_main_access_key(db_sess, sess_ctx)
63+
if main_ak is None:
64+
return PredicateResult(
65+
False,
66+
"Session owner has no main_access_key; cannot evaluate concurrency policy",
67+
)
6368

6469
async def _get_max_concurrent_sessions() -> int:
6570
resouce_policy_q = sa.select(KeyPairRow.resource_policy).where(
@@ -144,6 +149,8 @@ async def check_keypair_resource_limit(
144149
sess_ctx: SessionRow,
145150
) -> PredicateResult:
146151
main_ak = await _resolve_main_access_key(db_sess, sess_ctx)
152+
if main_ak is None:
153+
return PredicateResult(False, "Session owner has no main_access_key")
147154
resouce_policy_q = sa.select(KeyPairRow.resource_policy).where(KeyPairRow.access_key == main_ak)
148155
select_query = sa.select(KeyPairResourcePolicyRow).where(
149156
KeyPairResourcePolicyRow.name == resouce_policy_q.scalar_subquery()
@@ -162,9 +169,6 @@ async def check_keypair_resource_limit(
162169
total_keypair_allowed = ResourceSlot.from_policy(
163170
resource_policy_map, cast(Mapping[str, Any], sched_ctx.known_slot_types)
164171
)
165-
166-
if main_ak is None:
167-
return PredicateResult(False, "Session has no access key")
168172
key_occupied = await sched_ctx.registry.get_keypair_occupancy(
169173
AccessKey(main_ak), db_sess=db_sess
170174
)
@@ -308,6 +312,8 @@ async def check_pending_session_count_limit(
308312
failure_msgs = []
309313

310314
main_ak = await _resolve_main_access_key(db_sess, sess_ctx)
315+
if main_ak is None:
316+
return PredicateResult(False, "Session owner has no main_access_key")
311317
query = (
312318
sa.select(SessionRow)
313319
.where(
@@ -370,6 +376,8 @@ async def check_pending_session_resource_limit(
370376
failure_msgs = []
371377

372378
main_ak = await _resolve_main_access_key(db_sess, sess_ctx)
379+
if main_ak is None:
380+
return PredicateResult(False, "Session owner has no main_access_key")
373381
query = (
374382
sa.select(SessionRow)
375383
.where(

0 commit comments

Comments
 (0)