diff --git a/changes/11045.enhance.md b/changes/11045.enhance.md deleted file mode 100644 index 157161bf881..00000000000 --- a/changes/11045.enhance.md +++ /dev/null @@ -1 +0,0 @@ -Rename `SessionData.user_uuid` / `SessionMetadata.user_uuid` to `owner_id` and drop the redundant `access_key` snapshot fields from those data types. `ComputeSessionNode.access_key` is now sourced from the owner's `main_access_key`, kept in step with the underlying user record. diff --git a/changes/11048.enhance.md b/changes/11048.enhance.md new file mode 100644 index 00000000000..044a4354293 --- /dev/null +++ b/changes/11048.enhance.md @@ -0,0 +1 @@ +Collapse scheduler and sokovan signatures to use `owner_id` and `main_access_key`, propagate rename into sokovan handlers, coordinator, and sequencers. diff --git a/src/ai/backend/manager/repositories/events/db_source/db_source.py b/src/ai/backend/manager/repositories/events/db_source/db_source.py index c163a8ce0c7..ec66a8f0432 100644 --- a/src/ai/backend/manager/repositories/events/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/events/db_source/db_source.py @@ -6,6 +6,7 @@ from ai.backend.manager.errors.resource import ProjectNotFound from ai.backend.manager.models.group import groups from ai.backend.manager.models.session import SessionRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -21,8 +22,13 @@ async def match_sessions_by_name( access_key: AccessKey, ) -> list[SessionRow]: async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess: + owner_id = await db_sess.scalar( + sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) + ) + if owner_id is None: + return [] return await SessionRow.match_sessions( - db_sess, session_name, access_key, allow_prefix=False + db_sess, session_name, owner_id=owner_id, allow_prefix=False ) async def resolve_group_id(self, group_name: str) -> uuid.UUID: diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index 0af02fa9c20..22eb9a7277d 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -351,9 +351,9 @@ async def _fetch_pending_sessions( if session_id not in sessions_map: sessions_map[session_id] = PendingSessionData( id=session_id, - access_key=row.access_key, + main_access_key=row.access_key, requested_slots=row.requested_slots, - user_uuid=row.user_uuid, + owner_id=row.user_uuid, group_id=row.group_id, domain_name=row.domain_name, scaling_group_name=row.scaling_group_name, @@ -700,7 +700,7 @@ async def _fetch_user_policies( """Fetch user resource policies for users in pending sessions.""" user_policies: dict[UUID, UserResourcePolicy] = {} - if not pending_sessions.user_uuids: + if not pending_sessions.owner_ids: return user_policies user_policy_result = await db_sess.execute( @@ -716,7 +716,7 @@ async def _fetch_user_policies( KeyPairResourcePolicyRow, KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name, ) - .where(UserRow.uuid.in_(pending_sessions.user_uuids)) + .where(UserRow.uuid.in_(pending_sessions.owner_ids)) ) for row in user_policy_result: @@ -1149,7 +1149,7 @@ async def get_terminating_sessions_by_ids( terminating_sessions.append( TerminatingSessionData( session_id=session_row.id, - access_key=AccessKey(session_row.access_key) + main_access_key=AccessKey(session_row.access_key) if session_row.access_key else AccessKey(""), creation_id=session_row.creation_id or "", @@ -1213,7 +1213,7 @@ async def get_pending_timeout_sessions_by_ids( SweptSessionInfo( session_id=row.id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, ) ) @@ -1302,8 +1302,8 @@ async def enqueue_session( id=session_data.id, creation_id=session_data.creation_id, name=session_data.name, - access_key=session_data.access_key, - user_uuid=session_data.user_uuid, + access_key=session_data.main_access_key, + user_uuid=session_data.owner_id, group_id=session_data.group_id, domain_name=session_data.domain_name, scaling_group_name=session_data.scaling_group_name, @@ -1349,8 +1349,8 @@ async def enqueue_session( scaling_group=kernel.scaling_group, domain_name=kernel.domain_name, group_id=kernel.group_id, - user_uuid=kernel.user_uuid, - access_key=kernel.access_key, + user_uuid=kernel.owner_id, + access_key=kernel.main_access_key, image=kernel.image, architecture=kernel.architecture, registry=kernel.registry, @@ -1387,7 +1387,7 @@ async def enqueue_session( element_type=RBACElementType.SESSION, scope_ref=RBACElementRef( element_type=RBACElementType.USER, - element_id=str(session_data.user_uuid), + element_id=str(session_data.owner_id), ), additional_scope_refs=[ RBACElementRef( @@ -1857,7 +1857,7 @@ async def allocate_sessions( ScheduledSessionData( session_id=allocation.session_id, creation_id=creation_id, - access_key=access_key, + main_access_key=access_key, reason="triggered-by-scheduler", ) ) @@ -2917,7 +2917,9 @@ async def _get_sessions_by_statuses( scheduled_session = ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), + main_access_key=AccessKey(session.access_key) + if session.access_key + else AccessKey(""), reason="triggered-by-scheduler", ) scheduled_sessions.append(scheduled_session) @@ -2962,7 +2964,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) + main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), reason="triggered-by-scheduler", @@ -3102,7 +3104,7 @@ async def _get_sessions_for_pull( sessions_map[session_id] = SessionDataForPull( session_id=session_id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -3294,13 +3296,13 @@ async def _get_sessions_for_start( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=kernel_bindings, environ=session_info.get("environ", {}), - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4074,7 +4076,7 @@ async def _fetch_sessions_for_pull_by_ids( sessions_map[session_id] = SessionDataForPull( session_id=session_id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -4293,13 +4295,13 @@ async def _fetch_sessions_for_start_by_ids( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=kernel_bindings, environ=session_info.get("environ", {}), - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4369,7 +4371,7 @@ async def search_sessions_with_kernels( sessions_map[row.id] = SessionDataForPull( session_id=row.id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -4625,13 +4627,13 @@ async def search_sessions_with_kernels_and_user( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=session_info["kernels"], environ=session_info.get("environ") or {}, - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4774,6 +4776,26 @@ async def get_db_now(self) -> datetime: result = await conn.execute(sa.select(sa.func.now())) return result.scalar_one() + async def resolve_main_access_keys( + self, session_ids: Sequence[SessionId] + ) -> dict[SessionId, AccessKey]: + """Resolve the main access key for each session's owner. + + Joins ``sessions`` → ``users`` to look up the owner's + ``main_access_key``. Sessions whose owner has no configured + main access key are omitted from the returned mapping. + """ + if not session_ids: + return {} + async with self._db.begin_readonly_session() as db_sess: + stmt = ( + sa.select(SessionRow.id, UserRow.main_access_key) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) + .where(SessionRow.id.in_([sid for sid in session_ids])) + ) + rows = (await db_sess.execute(stmt)).all() + return {SessionId(row[0]): AccessKey(row[1]) for row in rows if row[1] is not None} + async def _get_db_now_in_session(self, db_sess: SASession) -> datetime: """Get the current timestamp from the database within an existing session. diff --git a/src/ai/backend/manager/repositories/scheduler/options.py b/src/ai/backend/manager/repositories/scheduler/options.py index 7ceb9b68088..7a27cb148cf 100644 --- a/src/ai/backend/manager/repositories/scheduler/options.py +++ b/src/ai/backend/manager/repositories/scheduler/options.py @@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner - @staticmethod - def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") - else: - condition = SessionRow.access_key.like(f"%{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() - else: - condition = SessionRow.access_key == spec.value - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") - else: - condition = SessionRow.access_key.like(f"{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") - else: - condition = SessionRow.access_key.like(f"%{spec.value}") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: - """Factory for user UUID equality filter.""" + def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: + """Factory for owner_id equality filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: @@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: - """Factory for user UUID IN filter.""" + def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: + """Factory for owner_id IN filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: diff --git a/src/ai/backend/manager/repositories/scheduler/repository.py b/src/ai/backend/manager/repositories/scheduler/repository.py index 2921a662814..0e38d59b6a9 100644 --- a/src/ai/backend/manager/repositories/scheduler/repository.py +++ b/src/ai/backend/manager/repositories/scheduler/repository.py @@ -959,3 +959,9 @@ async def get_db_now(self) -> datetime: Current database timestamp with timezone """ return await self._db_source.get_db_now() + + async def resolve_main_access_keys( + self, session_ids: Sequence[SessionId] + ) -> dict[SessionId, AccessKey]: + """Resolve the main access key for each session's owner.""" + return await self._db_source.resolve_main_access_keys(session_ids) diff --git a/src/ai/backend/manager/repositories/scheduler/types/allocation.py b/src/ai/backend/manager/repositories/scheduler/types/allocation.py index d25c0440d52..2b9629ba7ce 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/allocation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/allocation.py @@ -69,8 +69,7 @@ class SessionAllocation: kernel_allocations: list[KernelAllocation] # List of agent allocations for this session agent_allocations: list[AgentAllocation] - # Keypair associated with the session - access_key: AccessKey + main_access_key: AccessKey # Phases that passed during scheduling passed_phases: list[SchedulingPredicate] = field(default_factory=list) # Phases that failed during scheduling (normally empty for successful allocations) diff --git a/src/ai/backend/manager/repositories/scheduler/types/results.py b/src/ai/backend/manager/repositories/scheduler/types/results.py index 75be947cd08..40fff71ddf5 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/results.py +++ b/src/ai/backend/manager/repositories/scheduler/types/results.py @@ -13,5 +13,5 @@ class ScheduledSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey reason: str diff --git a/src/ai/backend/manager/repositories/scheduler/types/session.py b/src/ai/backend/manager/repositories/scheduler/types/session.py index fe975b3d4e8..8da9962fcac 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session.py @@ -44,9 +44,9 @@ class PendingSessionData: """Pending session data for scheduling.""" id: SessionId - access_key: AccessKey + main_access_key: AccessKey requested_slots: ResourceSlot - user_uuid: UUID + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str @@ -64,9 +64,9 @@ def to_session_workload(self) -> SessionWorkload: kernel_workloads = [k.to_kernel_workload() for k in self.kernels] return SessionWorkload( session_id=self.id, - access_key=self.access_key, + main_access_key=self.main_access_key, requested_slots=self.requested_slots, - user_uuid=self.user_uuid, + owner_id=self.owner_id, group_id=self.group_id, domain_name=self.domain_name, scaling_group=self.scaling_group_name, @@ -90,12 +90,12 @@ class PendingSessions: @cached_property def access_keys(self) -> set[AccessKey]: """Extract unique access keys from pending sessions.""" - return {s.access_key for s in self.sessions} + return {s.main_access_key for s in self.sessions} @cached_property - def user_uuids(self) -> set[UUID]: - """Extract unique user UUIDs from pending sessions.""" - return {s.user_uuid for s in self.sessions} + def owner_ids(self) -> set[UUID]: + """Extract unique owner (user) UUIDs from pending sessions.""" + return {s.owner_id for s in self.sessions} @cached_property def group_ids(self) -> set[UUID]: @@ -125,7 +125,7 @@ class TerminatingSessionData: """Data for a session that needs to be terminated.""" session_id: SessionId - access_key: AccessKey + main_access_key: AccessKey creation_id: str status: SessionStatus status_info: str @@ -161,7 +161,7 @@ class SweptSessionInfo: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey @dataclass diff --git a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py index 3dfcab2a6ae..a2187381587 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py @@ -222,8 +222,8 @@ class KernelEnqueueData: scaling_group: str domain_name: str group_id: UUID - user_uuid: UUID - access_key: AccessKey + owner_id: UUID + main_access_key: AccessKey image: str # Canonical image name architecture: str registry: str @@ -268,8 +268,8 @@ class SessionEnqueueData: id: SessionId creation_id: str name: str - access_key: AccessKey - user_uuid: UUID + main_access_key: AccessKey + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str diff --git a/src/ai/backend/manager/repositories/stream/db_source/db_source.py b/src/ai/backend/manager/repositories/stream/db_source/db_source.py index f49c7f3bbcc..bc475de9bff 100644 --- a/src/ai/backend/manager/repositories/stream/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/stream/db_source/db_source.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from ai.backend.common.types import AccessKey -from ai.backend.manager.errors.kernel import SessionNotFound +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -23,7 +23,7 @@ async def get_streaming_session( sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) ) if owner_id is None: - raise SessionNotFound(f"Unknown access_key: {access_key}") + raise UserNotFound(f"No user with main_access_key={access_key}") return await SessionRow.get_session( db_sess, session_name, diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py index 6de246211f5..6b18a9b0645 100644 --- a/src/ai/backend/manager/scheduler/drf.py +++ b/src/ai/backend/manager/scheduler/drf.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import uuid from collections import defaultdict from collections.abc import Mapping, Sequence from decimal import Decimal @@ -9,7 +10,6 @@ import trafaret as t from ai.backend.common.types import ( - AccessKey, ResourceSlot, SessionId, ) @@ -24,7 +24,7 @@ class DRFScheduler(AbstractScheduler): - per_user_dominant_share: dict[AccessKey, Decimal] + per_user_dominant_share: dict[uuid.UUID, Decimal] total_capacity: ResourceSlot def __init__( @@ -60,23 +60,22 @@ def pick_session( slot_share = Decimal(value) / slot_cap if dominant_share < slot_share: dominant_share = slot_share - raw_access_key = existing_sess.access_key - if raw_access_key is not None: - access_key = AccessKey(raw_access_key) - if self.per_user_dominant_share[access_key] < dominant_share: - self.per_user_dominant_share[access_key] = dominant_share + owner_id = existing_sess.user_uuid + if owner_id is not None: + if self.per_user_dominant_share[owner_id] < dominant_share: + self.per_user_dominant_share[owner_id] = dominant_share log.debug("per-user dominant share: {}", dict(self.per_user_dominant_share)) # Find who has the least dominant share among the pending session. - users_with_pending_session: set[AccessKey] = { - AccessKey(pending_sess.access_key) + users_with_pending_session: set[uuid.UUID] = { + pending_sess.user_uuid for pending_sess in pending_sessions - if pending_sess.access_key is not None + if pending_sess.user_uuid is not None } if not users_with_pending_session: return None least_dominant_share_user, dshare = min( - ((akey, self.per_user_dominant_share[akey]) for akey in users_with_pending_session), + ((oid, self.per_user_dominant_share[oid]) for oid in users_with_pending_session), key=lambda item: item[1], ) log.debug("least dominant share user: {} ({})", least_dominant_share_user, dshare) @@ -84,7 +83,7 @@ def pick_session( # Pick the first pending session of the user # who has the lowest dominant share. for pending_sess in pending_sessions: - if pending_sess.access_key == least_dominant_share_user: + if pending_sess.user_uuid == least_dominant_share_user: return SessionId(pending_sess.id) return None @@ -96,10 +95,7 @@ def update_allocation( ) -> None: # In such case, we just skip updating self.per_user_dominant_share state # and the scheduler continues to pick another session within the same scaling group. - raw_access_key = scheduled_session_or_kernel.access_key - if raw_access_key is None: - return - access_key = AccessKey(raw_access_key) + owner_id = scheduled_session_or_kernel.user_uuid requested_slots = scheduled_session_or_kernel.requested_slots # Update the dominant share. @@ -114,5 +110,5 @@ def update_allocation( slot_share = Decimal(value) / slot_cap if dominant_share_from_request < slot_share: dominant_share_from_request = slot_share - if self.per_user_dominant_share[access_key] < dominant_share_from_request: - self.per_user_dominant_share[access_key] = dominant_share_from_request + if self.per_user_dominant_share[owner_id] < dominant_share_from_request: + self.per_user_dominant_share[owner_id] = dominant_share_from_request diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py index f4cb7feffce..027a1faa1c0 100644 --- a/src/ai/backend/manager/scheduler/predicates.py +++ b/src/ai/backend/manager/scheduler/predicates.py @@ -29,6 +29,12 @@ log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) +async def _resolve_main_access_key(db_sess: SASession, sess_ctx: SessionRow) -> str | None: + """Resolve the owner's main access key via UserRow join.""" + stmt = sa.select(UserRow.main_access_key).where(UserRow.uuid == sess_ctx.user_uuid) + return await db_sess.scalar(stmt) + + async def check_reserved_batch_session( db_sess: SASession, _sched_ctx: SchedulingContext, @@ -53,9 +59,16 @@ async def check_concurrency( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult( + False, + "Session owner has no main_access_key; cannot evaluate concurrency policy", + ) + async def _get_max_concurrent_sessions() -> int: resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key + KeyPairRow.access_key == main_ak ) if sess_ctx.is_private: concurrent_session_column = KeyPairResourcePolicyRow.max_concurrent_sftp_sessions @@ -69,9 +82,9 @@ async def _get_max_concurrent_sessions() -> int: max_concurrent_sessions = await execute_with_retry(_get_max_concurrent_sessions) or 0 if sess_ctx.is_private: - redis_key = f"keypair.sftp_concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.sftp_concurrency_used.{main_ak}" else: - redis_key = f"keypair.concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.concurrency_used.{main_ak}" ok, concurrency_used = await sched_ctx.registry.valkey_stat.check_keypair_concurrency( redis_key, max_concurrent_sessions, @@ -83,7 +96,7 @@ async def _get_max_concurrent_sessions() -> int: ) log.debug( "number of concurrent sessions of ak:{0} = {1} / {2}", - sess_ctx.access_key, + main_ak, concurrency_used, max_concurrent_sessions, ) @@ -135,9 +148,10 @@ async def check_keypair_resource_limit( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: - resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key - ) + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") + resouce_policy_q = sa.select(KeyPairRow.resource_policy).where(KeyPairRow.access_key == main_ak) select_query = sa.select(KeyPairResourcePolicyRow).where( KeyPairResourcePolicyRow.name == resouce_policy_q.scalar_subquery() ) @@ -146,7 +160,7 @@ async def check_keypair_resource_limit( if resource_policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) resource_policy_map = { "total_resource_slots": resource_policy.total_resource_slots, @@ -155,14 +169,11 @@ async def check_keypair_resource_limit( total_keypair_allowed = ResourceSlot.from_policy( resource_policy_map, cast(Mapping[str, Any], sched_ctx.known_slot_types) ) - - if sess_ctx.access_key is None: - return PredicateResult(False, "Session has no access key") key_occupied = await sched_ctx.registry.get_keypair_occupancy( - AccessKey(sess_ctx.access_key), db_sess=db_sess + AccessKey(main_ak), db_sess=db_sess ) - log.debug("keypair:{} current-occupancy: {}", sess_ctx.access_key, key_occupied) - log.debug("keypair:{} total-allowed: {}", sess_ctx.access_key, total_keypair_allowed) + log.debug("keypair:{} current-occupancy: {}", main_ak, key_occupied) + log.debug("keypair:{} total-allowed: {}", main_ak, total_keypair_allowed) if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed): return PredicateResult( False, @@ -300,10 +311,13 @@ async def check_pending_session_count_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -319,7 +333,7 @@ async def check_pending_session_count_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -331,7 +345,7 @@ async def check_pending_session_count_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_count_limit: int | None = policy.max_pending_session_count @@ -344,7 +358,7 @@ async def check_pending_session_count_limit( log.debug( "access key:{} number of pending sessions: {} / {}", - sess_ctx.access_key, + main_ak, len(pending_sessions), pending_count_limit, ) @@ -361,10 +375,13 @@ async def check_pending_session_resource_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -380,7 +397,7 @@ async def check_pending_session_resource_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -392,7 +409,7 @@ async def check_pending_session_resource_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_resource_limit: ResourceSlot | None = policy.max_pending_session_resource_slots @@ -413,12 +430,12 @@ async def check_pending_session_resource_limit( failure_msgs.append(msg) log.debug( "access key:{} current-occupancy of pending sessions: {}", - sess_ctx.access_key, + main_ak, current_pending_session_slots, ) log.debug( "access key:{} total-allowed of pending sessions: {}", - sess_ctx.access_key, + main_ak, pending_resource_limit, ) if not result: diff --git a/src/ai/backend/manager/sokovan/data/allocation.py b/src/ai/backend/manager/sokovan/data/allocation.py index 72334089fe9..c3d3e3d477e 100644 --- a/src/ai/backend/manager/sokovan/data/allocation.py +++ b/src/ai/backend/manager/sokovan/data/allocation.py @@ -80,8 +80,8 @@ class SessionAllocation: kernel_allocations: list[KernelAllocation] # List of agent allocations for this session agent_allocations: list[AgentAllocation] - # Keypair associated with the session - access_key: AccessKey + # Owner's resolved main_access_key; required for keypair-scoped concurrency tracking and resource policy lookups. + main_access_key: AccessKey # Phases that passed during scheduling passed_phases: list[SchedulingPredicate] = field(default_factory=list) # Phases that failed during scheduling (normally empty for successful allocations) @@ -141,7 +141,7 @@ def from_agent_selections( scaling_group=scaling_group, kernel_allocations=kernel_allocations, agent_allocations=agent_allocations, - access_key=session_workload.access_key, + main_access_key=session_workload.main_access_key, ) def unique_agent_ids(self) -> list[AgentId]: diff --git a/src/ai/backend/manager/sokovan/data/lifecycle.py b/src/ai/backend/manager/sokovan/data/lifecycle.py index d332814968e..e4bf41b6a33 100644 --- a/src/ai/backend/manager/sokovan/data/lifecycle.py +++ b/src/ai/backend/manager/sokovan/data/lifecycle.py @@ -64,7 +64,7 @@ class SessionDataForPull: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey kernels: list[KernelBindingData] @@ -74,12 +74,12 @@ class SessionDataForStart: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str cluster_mode: ClusterMode kernels: list[KernelBindingData] - user_uuid: UUID + owner_id: UUID user_email: str user_name: str environ: dict[str, str] @@ -93,13 +93,13 @@ class ScheduledSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str kernels: list[KernelBindingData] # Additional fields for PREPARED sessions cluster_mode: ClusterMode | None = None - user_uuid: UUID | None = None + owner_id: UUID | None = None user_email: str | None = None user_name: str | None = None network_type: NetworkType | None = None @@ -157,12 +157,12 @@ class PreparedSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str cluster_mode: ClusterMode kernels: list[KernelStartData] - user_uuid: UUID + owner_id: UUID user_email: str user_name: str network_type: NetworkType | None = None diff --git a/src/ai/backend/manager/sokovan/data/workload.py b/src/ai/backend/manager/sokovan/data/workload.py index c953818621e..ab5d5cafcec 100644 --- a/src/ai/backend/manager/sokovan/data/workload.py +++ b/src/ai/backend/manager/sokovan/data/workload.py @@ -77,12 +77,12 @@ class SessionWorkload: # Session identifier session_id: SessionId - # User identification for fairness calculation - access_key: AccessKey + # Owner's resolved main_access_key; required for keypair-scoped concurrency tracking and resource policy lookups. + main_access_key: AccessKey # Resource requirements requested_slots: ResourceSlot - # User UUID for user resource limit checks - user_uuid: UUID + # Owner (user) UUID for user resource limit checks + owner_id: UUID # Group ID for group resource limit checks group_id: UUID # Domain name for domain resource limit checks diff --git a/src/ai/backend/manager/sokovan/scheduler/coordinator.py b/src/ai/backend/manager/sokovan/scheduler/coordinator.py index abdf61d4ac0..a35a1da9eec 100644 --- a/src/ai/backend/manager/sokovan/scheduler/coordinator.py +++ b/src/ai/backend/manager/sokovan/scheduler/coordinator.py @@ -25,7 +25,7 @@ ) from ai.backend.common.events.types import AbstractBroadcastEvent from ai.backend.common.leader.tasks import EventTaskSpec -from ai.backend.common.types import AccessKey, AgentId, SessionId +from ai.backend.common.types import AgentId, SessionId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.data.kernel.types import KernelStatus @@ -831,6 +831,8 @@ async def _process_promotion_scaling_group( "check_kernel_status", success_detail=f"All kernels ready for {spec.success_status.value}", ): + # BA-5609: resolve main_access_key for cache invalidation consumer. + access_key_by_id = await self._repository.resolve_main_access_keys(session_ids) result = SessionExecutionResult() for session_info in session_infos: result.successes.append( @@ -839,7 +841,7 @@ async def _process_promotion_scaling_group( from_status=session_info.lifecycle.status, reason=spec.reason, creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py index 3f30adcb754..8db44812541 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -116,6 +115,11 @@ async def execute( sessions_for_pull_data.image_configs, ) + # BA-5609: source resolved main_access_key from SessionDataForPull. + access_key_by_id = { + s.session_id: s.main_access_key for s in sessions_for_pull_data.sessions + } + # Mark all sessions as success for status transition for session in sessions: session_info = session.session_info @@ -125,7 +129,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="passed-preconditions", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py index 6169cda4d8c..55d68d15e97 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from ai.backend.common.defs.session import SESSION_PRIORITY_MIN -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -112,6 +111,9 @@ async def execute( scaling_group, ) + # BA-5609: resolve main_access_key for cache invalidation consumer. + access_key_by_id = await self._repository.resolve_main_access_keys(session_ids) + # Mark all sessions as success for status transition to PENDING for session in sessions: session_info = session.session_info @@ -121,7 +123,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="deprioritized-for-rescheduling", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py index 012fd7ff1a6..ab57109618a 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -128,7 +127,10 @@ async def execute( from_status=session.session_info.lifecycle.status, reason="no-scheduling-data", creation_id=session.session_info.identity.creation_id, - access_key=AccessKey(session.session_info.metadata.access_key), + # BA-5609: skipped sessions are only recorded to + # scheduling history; access_key is not used by that + # consumer, so leaving it None is safe. + access_key=None, ) ) return result @@ -157,7 +159,7 @@ async def execute( from_status=from_status, reason=event_data.reason, creation_id=event_data.creation_id, - access_key=event_data.access_key, + access_key=event_data.main_access_key, ) ) @@ -171,7 +173,10 @@ async def execute( from_status=session.session_info.lifecycle.status, reason="not-scheduled-this-cycle", creation_id=session.session_info.identity.creation_id, - access_key=AccessKey(session.session_info.metadata.access_key), + # BA-5609: skipped sessions are only recorded to + # scheduling history; access_key is not used by that + # consumer, so leaving it None is safe. + access_key=None, ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py index d003f790d1f..f077fa6dba5 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -123,6 +122,9 @@ async def execute( sessions_data.image_configs, ) + # BA-5609: source resolved main_access_key from SessionDataForStart. + access_key_by_id = {s.session_id: s.main_access_key for s in sessions_data.sessions} + # Mark all sessions as success for status transition for session in sessions: session_info = session.session_info @@ -132,7 +134,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="triggered-by-scheduler", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py index 3a31cf55ac8..39d1cd266ed 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py @@ -118,7 +118,7 @@ async def execute( from_status=session_data.session_info.lifecycle.status, reason="PENDING_TIMEOUT_EXCEEDED", creation_id=timed_out.creation_id, - access_key=timed_out.access_key, + access_key=timed_out.main_access_key, ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py b/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py index a7f88de58a7..4c7b8538f8b 100644 --- a/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py +++ b/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py @@ -229,7 +229,7 @@ async def _start_single_session( session.session_id, session.session_type, session.name, - session.access_key, + session.main_access_key, session.cluster_mode, ) log.debug(log_fmt + "try-starting", *log_args) @@ -262,7 +262,7 @@ async def _start_single_session( } environ: dict[str, str] = { **session.environ, - "BACKENDAI_USER_UUID": str(session.user_uuid), + "BACKENDAI_USER_UUID": str(session.owner_id), "BACKENDAI_USER_EMAIL": session.user_email, "BACKENDAI_USER_NAME": session.user_name, "BACKENDAI_SESSION_ID": str(session.session_id), @@ -273,7 +273,7 @@ async def _start_single_session( k.cluster_hostname or f"{k.cluster_role}{k.cluster_idx}" for k in session.kernels ), - "BACKENDAI_ACCESS_KEY": session.access_key, + "BACKENDAI_ACCESS_KEY": session.main_access_key, # BACKENDAI_SERVICE_PORTS are set as per-kernel env-vars. "BACKENDAI_PREOPEN_PORTS": ( ",".join(str(port) for port in session.kernels[0].preopen_ports) @@ -335,7 +335,7 @@ async def create_kernels_on_agent( "image": kernel_image_config, "kernel_id": kernel_id_str, "session_id": str(session.session_id), - "owner_user_id": str(session.user_uuid), + "owner_user_id": str(session.owner_id), "owner_project_id": None, # TODO: Implement project-owned sessions "network_id": str(session.session_id), "session_type": session.session_type, diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py index f689f0d7cab..9a7d11fc32a 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py @@ -374,7 +374,7 @@ def _update_system_snapshot( # 1. Update resource occupancy - add the session's allocated slots # Update keypair occupancy - current_keypair = snapshot.resource_occupancy.by_keypair.get(workload.access_key) + current_keypair = snapshot.resource_occupancy.by_keypair.get(workload.main_access_key) if current_keypair is None: current_keypair = KeypairOccupancy( occupied_slots=[], session_count=0, sftp_session_count=0 @@ -389,11 +389,11 @@ def _update_system_snapshot( else: current_keypair.session_count += 1 - snapshot.resource_occupancy.by_keypair[workload.access_key] = current_keypair + snapshot.resource_occupancy.by_keypair[workload.main_access_key] = current_keypair # Update user occupancy - current_user = snapshot.resource_occupancy.by_user.get(workload.user_uuid, []) - snapshot.resource_occupancy.by_user[workload.user_uuid] = add_quantities( + current_user = snapshot.resource_occupancy.by_user.get(workload.owner_id, []) + snapshot.resource_occupancy.by_user[workload.owner_id] = add_quantities( current_user, total_quantities ) @@ -412,12 +412,20 @@ def _update_system_snapshot( # 2. Update concurrency counts if workload.is_private: # Increment SFTP session count - current_sftp = snapshot.concurrency.sftp_sessions_by_keypair.get(workload.access_key, 0) - snapshot.concurrency.sftp_sessions_by_keypair[workload.access_key] = current_sftp + 1 + current_sftp = snapshot.concurrency.sftp_sessions_by_keypair.get( + workload.main_access_key, 0 + ) + snapshot.concurrency.sftp_sessions_by_keypair[workload.main_access_key] = ( + current_sftp + 1 + ) else: # Increment regular session count - current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.access_key, 0) - snapshot.concurrency.sessions_by_keypair[workload.access_key] = current_sessions + 1 + current_sessions = snapshot.concurrency.sessions_by_keypair.get( + workload.main_access_key, 0 + ) + snapshot.concurrency.sessions_by_keypair[workload.main_access_key] = ( + current_sessions + 1 + ) async def _allocate_workload( self, diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py index 22243c2ad2a..b4ba256d656 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py @@ -62,7 +62,7 @@ async def sequence( # Sort workloads by dominant share (ascending order - lower share gets higher priority) # For users with the same dominant share, maintain original order - return sorted(workloads, key=lambda w: user_dominant_shares[w.access_key]) + return sorted(workloads, key=lambda w: user_dominant_shares[w.main_access_key]) def _calculate_dominant_share( self, resource_slots: ResourceSlot, total_capacity: ResourceSlot diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py index 1b9d462cf88..462deee1bd3 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py @@ -80,7 +80,7 @@ async def sequence( # If a user doesn't have recorded factors, use default (lowest priority) return sorted( workloads, - key=lambda w: self._get_sort_key(w.user_uuid, user_factors), + key=lambda w: self._get_sort_key(w.owner_id, user_factors), ) async def _load_factors( @@ -92,7 +92,7 @@ async def _load_factors( # Group user_ids by project_id project_users: dict[UUID, set[UUID]] = defaultdict(set) for w in workloads: - project_users[w.group_id].add(w.user_uuid) + project_users[w.group_id].add(w.owner_id) # Build ProjectUserIds list project_user_ids = [ diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py index 18f926f4579..fe702ddbbce 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py @@ -22,15 +22,15 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return # Get current session count - current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.access_key, 0) + current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.main_access_key, 0) current_sftp_sessions = snapshot.concurrency.sftp_sessions_by_keypair.get( - workload.access_key, 0 + workload.main_access_key, 0 ) # Check the appropriate limit based on session type diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py index 88e8c4dc723..20174bf517b 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py @@ -26,13 +26,13 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return # Get current keypair occupancy (occupied_slots is list[SlotQuantity]) - key_occupancy = snapshot.resource_occupancy.by_keypair.get(workload.access_key) + key_occupancy = snapshot.resource_occupancy.by_keypair.get(workload.main_access_key) if key_occupancy: key_occupied = ResourceSlot({ sq.slot_name: sq.quantity for sq in key_occupancy.occupied_slots diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py index a4ae0a3c51e..c14e64a861c 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py @@ -22,7 +22,7 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return @@ -34,7 +34,7 @@ def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: return # Get current pending sessions for this keypair - pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.access_key, []) + pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.main_access_key, []) current_pending_count = len(pending_sessions) # Check if creating this session would exceed the limit diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py index 758b5c577fc..0ea514479b3 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py @@ -23,7 +23,7 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return @@ -35,7 +35,7 @@ def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: return # Calculate current pending session resource usage - pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.access_key, []) + pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.main_access_key, []) current_pending_slots = ResourceSlot() for session in pending_sessions: current_pending_slots += session.requested_slots diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py index 69d8c6cb88f..fe366c99237 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py @@ -23,13 +23,13 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the user's resource policy - policy = snapshot.resource_policy.user_policies.get(workload.user_uuid) + policy = snapshot.resource_policy.user_policies.get(workload.owner_id) if not policy: # If no user-specific policy, skip validation (no limits apply) return # Get current user occupancy (list[SlotQuantity]) and convert to ResourceSlot - user_occupied_quantities = snapshot.resource_occupancy.by_user.get(workload.user_uuid, []) + user_occupied_quantities = snapshot.resource_occupancy.by_user.get(workload.owner_id, []) user_occupied = ResourceSlot({sq.slot_name: sq.quantity for sq in user_occupied_quantities}) # Check if adding this workload would exceed the limit diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py index 61c815d3d9e..941382e08e5 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py @@ -125,8 +125,8 @@ async def prepare( id=session_id, creation_id=spec.session_creation_id, name=spec.session_name, - access_key=spec.access_key, - user_uuid=spec.user_scope.user_uuid, + main_access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, group_id=spec.user_scope.group_id, domain_name=spec.user_scope.domain_name, scaling_group_name=validated_scaling_group.name, @@ -254,8 +254,8 @@ async def _prepare_kernels( scaling_group=validated_scaling_group.name, domain_name=spec.user_scope.domain_name, group_id=spec.user_scope.group_id, - user_uuid=spec.user_scope.user_uuid, - access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, + main_access_key=spec.access_key, image=image_info.canonical if image_info else self.DEFAULT_IMAGE_NAME, architecture=image_info.architecture if image_info else self.DEFAULT_ARCHITECTURE, registry=image_info.registry if image_info else self.DEFAULT_REGISTRY, diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py index c8fd82d2d2d..927e92f599c 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py @@ -257,7 +257,7 @@ async def enqueue_session( hook_result = await self._hook_plugin_ctx.dispatch( "PRE_ENQUEUE_SESSION", - (session_data.id, session_data.name, session_data.access_key), + (session_data.id, session_data.name, session_data.main_access_key), return_when=ALL_COMPLETED, ) if hook_result.status != PASSED: @@ -295,7 +295,7 @@ async def enqueue_session( ) await self._hook_plugin_ctx.notify( "POST_ENQUEUE_SESSION", - (session_id, session_data.name, session_data.access_key), + (session_id, session_data.name, session_data.main_access_key), ) return session_id