Skip to content

Commit 3c1cc92

Browse files
committed
refactor(BA-5650-E): collapse scheduler signatures to owner_id
Scheduler / predicates / scheduler-type collapse of the owner key:

- ``scheduler/predicates.py``: predicates now take ``SessionRow`` and resolve ``main_access_key`` from the owner via a helper when a keypair-scoped lookup (Redis concurrency, keypair resource policy) is required. Renames ``SessionRow.user_uuid`` references to ``owner_id`` throughout.
- ``scheduler/drf.py``: per-user fairness tracking keyed by the ``owner_id``/``main_access_key`` pair.
- ``repositories/scheduler/options.py``: drop the duplicated ``by_access_key_*`` factories — session filters go through ``SessionConditions`` helpers instead.
- ``repositories/scheduler/types/*``: rename ``access_key`` to ``main_access_key`` on ``ScheduledSessionData``, ``TerminatingSessionData``, ``SweptSessionInfo``, ``KernelEnqueueData`` and ``SessionEnqueueData``.
- ``repositories/events/db_source/db_source.py`` and ``repositories/stream/db_source/db_source.py``: resolve the owner UUID from ``main_access_key`` via a sub-select shim while the schema still exposes ``sessions.access_key``.
1 parent f811d03 commit 3c1cc92

10 files changed

Lines changed: 82 additions & 112 deletions

File tree

changes/BA-5650-E.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Collapse scheduler / predicates / scheduler options signatures to take `owner_id: UUID`. Rename the `access_key` field to `main_access_key` on the scheduler types `ScheduledSessionData`, `TerminatingSessionData`, `SweptSessionInfo`, `KernelEnqueueData`, and `SessionEnqueueData`.

src/ai/backend/manager/repositories/events/db_source/db_source.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ai.backend.manager.errors.resource import ProjectNotFound
77
from ai.backend.manager.models.group import groups
88
from ai.backend.manager.models.session import SessionRow
9+
from ai.backend.manager.models.user import UserRow
910
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
1011

1112

@@ -21,8 +22,13 @@ async def match_sessions_by_name(
2122
access_key: AccessKey,
2223
) -> list[SessionRow]:
2324
async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess:
25+
owner_id = await db_sess.scalar(
26+
sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key)
27+
)
28+
if owner_id is None:
29+
return []
2430
return await SessionRow.match_sessions(
25-
db_sess, session_name, access_key, allow_prefix=False
31+
db_sess, session_name, owner_id=owner_id, allow_prefix=False
2632
)
2733

2834
async def resolve_group_id(self, group_name: str) -> uuid.UUID:

src/ai/backend/manager/repositories/scheduler/options.py

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
107107

108108
return inner
109109

110-
@staticmethod
111-
def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition:
112-
def inner() -> sa.sql.expression.ColumnElement[bool]:
113-
if spec.case_insensitive:
114-
condition = SessionRow.access_key.ilike(f"%{spec.value}%")
115-
else:
116-
condition = SessionRow.access_key.like(f"%{spec.value}%")
117-
if spec.negated:
118-
condition = sa.not_(condition)
119-
return condition
120-
121-
return inner
122-
123-
@staticmethod
124-
def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition:
125-
def inner() -> sa.sql.expression.ColumnElement[bool]:
126-
if spec.case_insensitive:
127-
condition = sa.func.lower(SessionRow.access_key) == spec.value.lower()
128-
else:
129-
condition = SessionRow.access_key == spec.value
130-
if spec.negated:
131-
condition = sa.not_(condition)
132-
return condition
133-
134-
return inner
135-
136-
@staticmethod
137-
def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition:
138-
def inner() -> sa.sql.expression.ColumnElement[bool]:
139-
if spec.case_insensitive:
140-
condition = SessionRow.access_key.ilike(f"{spec.value}%")
141-
else:
142-
condition = SessionRow.access_key.like(f"{spec.value}%")
143-
if spec.negated:
144-
condition = sa.not_(condition)
145-
return condition
146-
147-
return inner
148-
149-
@staticmethod
150-
def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition:
151-
def inner() -> sa.sql.expression.ColumnElement[bool]:
152-
if spec.case_insensitive:
153-
condition = SessionRow.access_key.ilike(f"%{spec.value}")
154-
else:
155-
condition = SessionRow.access_key.like(f"%{spec.value}")
156-
if spec.negated:
157-
condition = sa.not_(condition)
158-
return condition
159-
160-
return inner
161-
162110
@staticmethod
163111
def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition:
164112
def inner() -> sa.sql.expression.ColumnElement[bool]:
@@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
413361
return inner
414362

415363
@staticmethod
416-
def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
417-
"""Factory for user UUID equality filter."""
364+
def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
365+
"""Factory for owner_id equality filter."""
418366

419367
def inner() -> sa.sql.expression.ColumnElement[bool]:
420368
if spec.negated:
@@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
424372
return inner
425373

426374
@staticmethod
427-
def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
428-
"""Factory for user UUID IN filter."""
375+
def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
376+
"""Factory for owner_id IN filter."""
429377

430378
def inner() -> sa.sql.expression.ColumnElement[bool]:
431379
if spec.negated:

src/ai/backend/manager/repositories/scheduler/types/allocation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class SessionAllocation:
6969
kernel_allocations: list[KernelAllocation]
7070
# List of agent allocations for this session
7171
agent_allocations: list[AgentAllocation]
72-
# Keypair associated with the session
72+
# Resolved main_access_key of the owner; required for keypair-scoped concurrency tracking and resource policy lookups.
7373
access_key: AccessKey
7474
# Phases that passed during scheduling
7575
passed_phases: list[SchedulingPredicate] = field(default_factory=list)

src/ai/backend/manager/repositories/scheduler/types/results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ class ScheduledSessionData:
1313

1414
session_id: SessionId
1515
creation_id: str
16-
access_key: AccessKey
16+
main_access_key: AccessKey
1717
reason: str

src/ai/backend/manager/repositories/scheduler/types/session.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ class PendingSessionData:
4444
"""Pending session data for scheduling."""
4545

4646
id: SessionId
47+
# Resolved main_access_key of the owner; required for keypair-scoped concurrency tracking and resource policy lookups.
4748
access_key: AccessKey
4849
requested_slots: ResourceSlot
49-
user_uuid: UUID
50+
owner_id: UUID
5051
group_id: UUID
5152
domain_name: str
5253
scaling_group_name: str
@@ -64,9 +65,9 @@ def to_session_workload(self) -> SessionWorkload:
6465
kernel_workloads = [k.to_kernel_workload() for k in self.kernels]
6566
return SessionWorkload(
6667
session_id=self.id,
67-
access_key=self.access_key,
68+
main_access_key=self.access_key,
6869
requested_slots=self.requested_slots,
69-
user_uuid=self.user_uuid,
70+
owner_id=self.owner_id,
7071
group_id=self.group_id,
7172
domain_name=self.domain_name,
7273
scaling_group=self.scaling_group_name,
@@ -90,12 +91,12 @@ class PendingSessions:
9091
@cached_property
9192
def access_keys(self) -> set[AccessKey]:
9293
"""Extract unique access keys from pending sessions."""
93-
return {s.access_key for s in self.sessions}
94+
return {s.main_access_key for s in self.sessions}
9495

9596
@cached_property
96-
def user_uuids(self) -> set[UUID]:
97-
"""Extract unique user UUIDs from pending sessions."""
98-
return {s.user_uuid for s in self.sessions}
97+
def owner_ids(self) -> set[UUID]:
98+
"""Extract unique owner (user) UUIDs from pending sessions."""
99+
return {s.owner_id for s in self.sessions}
99100

100101
@cached_property
101102
def group_ids(self) -> set[UUID]:
@@ -125,7 +126,7 @@ class TerminatingSessionData:
125126
"""Data for a session that needs to be terminated."""
126127

127128
session_id: SessionId
128-
access_key: AccessKey
129+
main_access_key: AccessKey
129130
creation_id: str
130131
status: SessionStatus
131132
status_info: str
@@ -161,7 +162,7 @@ class SweptSessionInfo:
161162

162163
session_id: SessionId
163164
creation_id: str
164-
access_key: AccessKey
165+
main_access_key: AccessKey
165166

166167

167168
@dataclass

src/ai/backend/manager/repositories/scheduler/types/session_creation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ class KernelEnqueueData:
222222
scaling_group: str
223223
domain_name: str
224224
group_id: UUID
225-
user_uuid: UUID
226-
access_key: AccessKey
225+
owner_id: UUID
226+
main_access_key: AccessKey
227227
image: str # Canonical image name
228228
architecture: str
229229
registry: str
@@ -268,8 +268,8 @@ class SessionEnqueueData:
268268
id: SessionId
269269
creation_id: str
270270
name: str
271-
access_key: AccessKey
272-
user_uuid: UUID
271+
main_access_key: AccessKey
272+
owner_id: UUID
273273
group_id: UUID
274274
domain_name: str
275275
scaling_group_name: str

src/ai/backend/manager/repositories/stream/db_source/db_source.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import sqlalchemy as sa
2+
13
from ai.backend.common.types import AccessKey
4+
from ai.backend.manager.errors.kernel import SessionNotFound
25
from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow
6+
from ai.backend.manager.models.user import UserRow
37
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
48

59

@@ -15,9 +19,14 @@ async def get_streaming_session(
1519
access_key: AccessKey,
1620
) -> SessionRow:
1721
async with self._db.begin_readonly_session() as db_sess:
22+
owner_id = await db_sess.scalar(
23+
sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key)
24+
)
25+
if owner_id is None:
26+
raise SessionNotFound(f"Unknown access_key: {access_key}")
1827
return await SessionRow.get_session(
1928
db_sess,
2029
session_name,
21-
access_key,
30+
owner_id=owner_id,
2231
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
2332
)

src/ai/backend/manager/scheduler/drf.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import uuid
45
from collections import defaultdict
56
from collections.abc import Mapping, Sequence
67
from decimal import Decimal
@@ -9,7 +10,6 @@
910
import trafaret as t
1011

1112
from ai.backend.common.types import (
12-
AccessKey,
1313
ResourceSlot,
1414
SessionId,
1515
)
@@ -24,7 +24,7 @@
2424

2525

2626
class DRFScheduler(AbstractScheduler):
27-
per_user_dominant_share: dict[AccessKey, Decimal]
27+
per_user_dominant_share: dict[uuid.UUID, Decimal]
2828
total_capacity: ResourceSlot
2929

3030
def __init__(
@@ -60,31 +60,30 @@ def pick_session(
6060
slot_share = Decimal(value) / slot_cap
6161
if dominant_share < slot_share:
6262
dominant_share = slot_share
63-
raw_access_key = existing_sess.access_key
64-
if raw_access_key is not None:
65-
access_key = AccessKey(raw_access_key)
66-
if self.per_user_dominant_share[access_key] < dominant_share:
67-
self.per_user_dominant_share[access_key] = dominant_share
63+
owner_id = existing_sess.owner_id
64+
if owner_id is not None:
65+
if self.per_user_dominant_share[owner_id] < dominant_share:
66+
self.per_user_dominant_share[owner_id] = dominant_share
6867
log.debug("per-user dominant share: {}", dict(self.per_user_dominant_share))
6968

7069
# Find who has the least dominant share among the pending session.
71-
users_with_pending_session: set[AccessKey] = {
72-
AccessKey(pending_sess.access_key)
70+
users_with_pending_session: set[uuid.UUID] = {
71+
pending_sess.user_uuid
7372
for pending_sess in pending_sessions
74-
if pending_sess.access_key is not None
73+
if pending_sess.user_uuid is not None
7574
}
7675
if not users_with_pending_session:
7776
return None
7877
least_dominant_share_user, dshare = min(
79-
((akey, self.per_user_dominant_share[akey]) for akey in users_with_pending_session),
78+
((oid, self.per_user_dominant_share[oid]) for oid in users_with_pending_session),
8079
key=lambda item: item[1],
8180
)
8281
log.debug("least dominant share user: {} ({})", least_dominant_share_user, dshare)
8382

8483
# Pick the first pending session of the user
8584
# who has the lowest dominant share.
8685
for pending_sess in pending_sessions:
87-
if pending_sess.access_key == least_dominant_share_user:
86+
if pending_sess.user_uuid == least_dominant_share_user:
8887
return SessionId(pending_sess.id)
8988

9089
return None
@@ -96,10 +95,7 @@ def update_allocation(
9695
) -> None:
9796
# In such case, we just skip updating self.per_user_dominant_share state
9897
# and the scheduler continues to pick another session within the same scaling group.
99-
raw_access_key = scheduled_session_or_kernel.access_key
100-
if raw_access_key is None:
101-
return
102-
access_key = AccessKey(raw_access_key)
98+
owner_id = scheduled_session_or_kernel.user_uuid
10399
requested_slots = scheduled_session_or_kernel.requested_slots
104100

105101
# Update the dominant share.
@@ -114,5 +110,5 @@ def update_allocation(
114110
slot_share = Decimal(value) / slot_cap
115111
if dominant_share_from_request < slot_share:
116112
dominant_share_from_request = slot_share
117-
if self.per_user_dominant_share[access_key] < dominant_share_from_request:
118-
self.per_user_dominant_share[access_key] = dominant_share_from_request
113+
if self.per_user_dominant_share[owner_id] < dominant_share_from_request:
114+
self.per_user_dominant_share[owner_id] = dominant_share_from_request

0 commit comments

Comments
 (0)