diff --git a/changes/11047.enhance.md b/changes/11047.enhance.md new file mode 100644 index 00000000000..63b2ee24773 --- /dev/null +++ b/changes/11047.enhance.md @@ -0,0 +1 @@ +Collapse scheduler / predicates / scheduler options signatures to take `owner_id: UUID`. Rename `access_key` field to `main_access_key` on `ScheduledSessionData` / `TerminatingSessionData` / `SweptSessionInfo` / scheduler types. diff --git a/src/ai/backend/manager/repositories/events/db_source/db_source.py b/src/ai/backend/manager/repositories/events/db_source/db_source.py index c163a8ce0c7..ec66a8f0432 100644 --- a/src/ai/backend/manager/repositories/events/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/events/db_source/db_source.py @@ -6,6 +6,7 @@ from ai.backend.manager.errors.resource import ProjectNotFound from ai.backend.manager.models.group import groups from ai.backend.manager.models.session import SessionRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -21,8 +22,13 @@ async def match_sessions_by_name( access_key: AccessKey, ) -> list[SessionRow]: async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess: + owner_id = await db_sess.scalar( + sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) + ) + if owner_id is None: + return [] return await SessionRow.match_sessions( - db_sess, session_name, access_key, allow_prefix=False + db_sess, session_name, owner_id=owner_id, allow_prefix=False ) async def resolve_group_id(self, group_name: str) -> uuid.UUID: diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index 0af02fa9c20..793ee215926 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -351,9 +351,9 @@ async def _fetch_pending_sessions( if session_id not in sessions_map: sessions_map[session_id] = PendingSessionData( id=session_id, - access_key=row.access_key, + main_access_key=row.access_key, requested_slots=row.requested_slots, - user_uuid=row.user_uuid, + owner_id=row.user_uuid, group_id=row.group_id, domain_name=row.domain_name, scaling_group_name=row.scaling_group_name, @@ -700,7 +700,7 @@ async def _fetch_user_policies( """Fetch user resource policies for users in pending sessions.""" user_policies: dict[UUID, UserResourcePolicy] = {} - if not pending_sessions.user_uuids: + if not pending_sessions.owner_ids: return user_policies user_policy_result = await db_sess.execute( @@ -716,7 +716,7 @@ async def _fetch_user_policies( KeyPairResourcePolicyRow, KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name, ) - .where(UserRow.uuid.in_(pending_sessions.user_uuids)) + .where(UserRow.uuid.in_(pending_sessions.owner_ids)) ) for row in user_policy_result: @@ -1149,7 +1149,7 @@ async def get_terminating_sessions_by_ids( terminating_sessions.append( TerminatingSessionData( session_id=session_row.id, - access_key=AccessKey(session_row.access_key) + main_access_key=AccessKey(session_row.access_key) if session_row.access_key else AccessKey(""), creation_id=session_row.creation_id or "", @@ -1213,7 +1213,7 @@ async def get_pending_timeout_sessions_by_ids( SweptSessionInfo( session_id=row.id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, ) ) @@ -1302,8 +1302,8 @@ async def enqueue_session( id=session_data.id, creation_id=session_data.creation_id, name=session_data.name, - access_key=session_data.access_key, - user_uuid=session_data.user_uuid, + access_key=session_data.main_access_key, + user_uuid=session_data.owner_id, group_id=session_data.group_id, domain_name=session_data.domain_name, scaling_group_name=session_data.scaling_group_name, @@ -1349,8 +1349,8 @@ async def enqueue_session( scaling_group=kernel.scaling_group, domain_name=kernel.domain_name, group_id=kernel.group_id, - user_uuid=kernel.user_uuid, - access_key=kernel.access_key, + user_uuid=kernel.owner_id, + access_key=kernel.main_access_key, image=kernel.image, architecture=kernel.architecture, registry=kernel.registry, @@ -1387,7 +1387,7 @@ async def enqueue_session( element_type=RBACElementType.SESSION, scope_ref=RBACElementRef( element_type=RBACElementType.USER, - element_id=str(session_data.user_uuid), + element_id=str(session_data.owner_id), ), additional_scope_refs=[ RBACElementRef( @@ -1857,7 +1857,7 @@ async def allocate_sessions( ScheduledSessionData( session_id=allocation.session_id, creation_id=creation_id, - access_key=access_key, + main_access_key=access_key, reason="triggered-by-scheduler", ) ) @@ -2917,7 +2917,7 @@ async def _get_sessions_by_statuses( scheduled_session = ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), + main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), reason="triggered-by-scheduler", ) scheduled_sessions.append(scheduled_session) @@ -2962,7 +2962,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) + main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), reason="triggered-by-scheduler", diff --git a/src/ai/backend/manager/repositories/scheduler/options.py b/src/ai/backend/manager/repositories/scheduler/options.py index 7ceb9b68088..7a27cb148cf 100644 --- a/src/ai/backend/manager/repositories/scheduler/options.py +++ b/src/ai/backend/manager/repositories/scheduler/options.py @@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner - @staticmethod - def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") - else: - condition = SessionRow.access_key.like(f"%{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() - else: - condition = SessionRow.access_key == spec.value - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") - else: - condition = SessionRow.access_key.like(f"{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") - else: - condition = SessionRow.access_key.like(f"%{spec.value}") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: - """Factory for user UUID equality filter.""" + def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: + """Factory for owner_id equality filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: @@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: - """Factory for user UUID IN filter.""" + def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: + """Factory for owner_id IN filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: diff --git a/src/ai/backend/manager/repositories/scheduler/types/allocation.py b/src/ai/backend/manager/repositories/scheduler/types/allocation.py index d25c0440d52..2b9629ba7ce 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/allocation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/allocation.py @@ -69,8 +69,7 @@ class SessionAllocation: kernel_allocations: list[KernelAllocation] # List of agent allocations for this session agent_allocations: list[AgentAllocation] - # Keypair associated with the session - access_key: AccessKey + main_access_key: AccessKey # Phases that passed during scheduling passed_phases: list[SchedulingPredicate] = field(default_factory=list) # Phases that failed during scheduling (normally empty for successful allocations) diff --git a/src/ai/backend/manager/repositories/scheduler/types/results.py b/src/ai/backend/manager/repositories/scheduler/types/results.py index 75be947cd08..40fff71ddf5 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/results.py +++ b/src/ai/backend/manager/repositories/scheduler/types/results.py @@ -13,5 +13,5 @@ class ScheduledSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey reason: str diff --git a/src/ai/backend/manager/repositories/scheduler/types/session.py b/src/ai/backend/manager/repositories/scheduler/types/session.py index fe975b3d4e8..3075817a81c 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session.py @@ -44,9 +44,9 @@ class PendingSessionData: """Pending session data for scheduling.""" id: SessionId - access_key: AccessKey + main_access_key: AccessKey requested_slots: ResourceSlot - user_uuid: UUID + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str @@ -64,9 +64,9 @@ def to_session_workload(self) -> SessionWorkload: kernel_workloads = [k.to_kernel_workload() for k in self.kernels] return SessionWorkload( session_id=self.id, - access_key=self.access_key, + access_key=self.main_access_key, requested_slots=self.requested_slots, - user_uuid=self.user_uuid, + user_uuid=self.owner_id, group_id=self.group_id, domain_name=self.domain_name, scaling_group=self.scaling_group_name, @@ -90,12 +90,12 @@ class PendingSessions: @cached_property def access_keys(self) -> set[AccessKey]: """Extract unique access keys from pending sessions.""" - return {s.access_key for s in self.sessions} + return {s.main_access_key for s in self.sessions} @cached_property - def user_uuids(self) -> set[UUID]: - """Extract unique user UUIDs from pending sessions.""" - return {s.user_uuid for s in self.sessions} + def owner_ids(self) -> set[UUID]: + """Extract unique owner (user) UUIDs from pending sessions.""" + return {s.owner_id for s in self.sessions} @cached_property def group_ids(self) -> set[UUID]: @@ -125,7 +125,7 @@ class TerminatingSessionData: """Data for a session that needs to be terminated.""" session_id: SessionId - access_key: AccessKey + main_access_key: AccessKey creation_id: str status: SessionStatus status_info: str @@ -161,7 +161,7 @@ class SweptSessionInfo: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey @dataclass diff --git a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py index 3dfcab2a6ae..a2187381587 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py @@ -222,8 +222,8 @@ class KernelEnqueueData: scaling_group: str domain_name: str group_id: UUID - user_uuid: UUID - access_key: AccessKey + owner_id: UUID + main_access_key: AccessKey image: str # Canonical image name architecture: str registry: str @@ -268,8 +268,8 @@ class SessionEnqueueData: id: SessionId creation_id: str name: str - access_key: AccessKey - user_uuid: UUID + main_access_key: AccessKey + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str diff --git a/src/ai/backend/manager/repositories/stream/db_source/db_source.py b/src/ai/backend/manager/repositories/stream/db_source/db_source.py index f49c7f3bbcc..bc475de9bff 100644 --- a/src/ai/backend/manager/repositories/stream/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/stream/db_source/db_source.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from ai.backend.common.types import AccessKey -from ai.backend.manager.errors.kernel import SessionNotFound +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -23,7 +23,7 @@ async def get_streaming_session( sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) ) if owner_id is None: - raise SessionNotFound(f"Unknown access_key: {access_key}") + raise UserNotFound(f"No user with main_access_key={access_key}") return await SessionRow.get_session( db_sess, session_name, diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py index 6de246211f5..6b18a9b0645 100644 --- a/src/ai/backend/manager/scheduler/drf.py +++ b/src/ai/backend/manager/scheduler/drf.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import uuid from collections import defaultdict from collections.abc import Mapping, Sequence from decimal import Decimal @@ -9,7 +10,6 @@ import trafaret as t from ai.backend.common.types import ( - AccessKey, ResourceSlot, SessionId, ) @@ -24,7 +24,7 @@ class DRFScheduler(AbstractScheduler): - per_user_dominant_share: dict[AccessKey, Decimal] + per_user_dominant_share: dict[uuid.UUID, Decimal] total_capacity: ResourceSlot def __init__( @@ -60,23 +60,22 @@ def pick_session( slot_share = Decimal(value) / slot_cap if dominant_share < slot_share: dominant_share = slot_share - raw_access_key = existing_sess.access_key - if raw_access_key is not None: - access_key = AccessKey(raw_access_key) - if self.per_user_dominant_share[access_key] < dominant_share: - self.per_user_dominant_share[access_key] = dominant_share + owner_id = existing_sess.user_uuid + if owner_id is not None: + if self.per_user_dominant_share[owner_id] < dominant_share: + self.per_user_dominant_share[owner_id] = dominant_share log.debug("per-user dominant share: {}", dict(self.per_user_dominant_share)) # Find who has the least dominant share among the pending session. - users_with_pending_session: set[AccessKey] = { - AccessKey(pending_sess.access_key) + users_with_pending_session: set[uuid.UUID] = { + pending_sess.user_uuid for pending_sess in pending_sessions - if pending_sess.access_key is not None + if pending_sess.user_uuid is not None } if not users_with_pending_session: return None least_dominant_share_user, dshare = min( - ((akey, self.per_user_dominant_share[akey]) for akey in users_with_pending_session), + ((oid, self.per_user_dominant_share[oid]) for oid in users_with_pending_session), key=lambda item: item[1], ) log.debug("least dominant share user: {} ({})", least_dominant_share_user, dshare) @@ -84,7 +83,7 @@ def pick_session( # Pick the first pending session of the user # who has the lowest dominant share. for pending_sess in pending_sessions: - if pending_sess.access_key == least_dominant_share_user: + if pending_sess.user_uuid == least_dominant_share_user: return SessionId(pending_sess.id) return None @@ -96,10 +95,7 @@ def update_allocation( ) -> None: # In such case, we just skip updating self.per_user_dominant_share state # and the scheduler continues to pick another session within the same scaling group. - raw_access_key = scheduled_session_or_kernel.access_key - if raw_access_key is None: - return - access_key = AccessKey(raw_access_key) + owner_id = scheduled_session_or_kernel.user_uuid requested_slots = scheduled_session_or_kernel.requested_slots # Update the dominant share. @@ -114,5 +110,5 @@ def update_allocation( slot_share = Decimal(value) / slot_cap if dominant_share_from_request < slot_share: dominant_share_from_request = slot_share - if self.per_user_dominant_share[access_key] < dominant_share_from_request: - self.per_user_dominant_share[access_key] = dominant_share_from_request + if self.per_user_dominant_share[owner_id] < dominant_share_from_request: + self.per_user_dominant_share[owner_id] = dominant_share_from_request diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py index f4cb7feffce..027a1faa1c0 100644 --- a/src/ai/backend/manager/scheduler/predicates.py +++ b/src/ai/backend/manager/scheduler/predicates.py @@ -29,6 +29,12 @@ log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) +async def _resolve_main_access_key(db_sess: SASession, sess_ctx: SessionRow) -> str | None: + """Resolve the owner's main access key via UserRow join.""" + stmt = sa.select(UserRow.main_access_key).where(UserRow.uuid == sess_ctx.user_uuid) + return await db_sess.scalar(stmt) + + async def check_reserved_batch_session( db_sess: SASession, _sched_ctx: SchedulingContext, @@ -53,9 +59,16 @@ async def check_concurrency( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult( + False, + "Session owner has no main_access_key; cannot evaluate concurrency policy", + ) + async def _get_max_concurrent_sessions() -> int: resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key + KeyPairRow.access_key == main_ak ) if sess_ctx.is_private: concurrent_session_column = KeyPairResourcePolicyRow.max_concurrent_sftp_sessions @@ -69,9 +82,9 @@ async def _get_max_concurrent_sessions() -> int: max_concurrent_sessions = await execute_with_retry(_get_max_concurrent_sessions) or 0 if sess_ctx.is_private: - redis_key = f"keypair.sftp_concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.sftp_concurrency_used.{main_ak}" else: - redis_key = f"keypair.concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.concurrency_used.{main_ak}" ok, concurrency_used = await sched_ctx.registry.valkey_stat.check_keypair_concurrency( redis_key, max_concurrent_sessions, @@ -83,7 +96,7 @@ async def _get_max_concurrent_sessions() -> int: ) log.debug( "number of concurrent sessions of ak:{0} = {1} / {2}", - sess_ctx.access_key, + main_ak, concurrency_used, max_concurrent_sessions, ) @@ -135,9 +148,10 @@ async def check_keypair_resource_limit( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: - resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key - ) + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") + resouce_policy_q = sa.select(KeyPairRow.resource_policy).where(KeyPairRow.access_key == main_ak) select_query = sa.select(KeyPairResourcePolicyRow).where( KeyPairResourcePolicyRow.name == resouce_policy_q.scalar_subquery() ) @@ -146,7 +160,7 @@ async def check_keypair_resource_limit( if resource_policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) resource_policy_map = { "total_resource_slots": resource_policy.total_resource_slots, @@ -155,14 +169,11 @@ async def check_keypair_resource_limit( total_keypair_allowed = ResourceSlot.from_policy( resource_policy_map, cast(Mapping[str, Any], sched_ctx.known_slot_types) ) - - if sess_ctx.access_key is None: - return PredicateResult(False, "Session has no access key") key_occupied = await sched_ctx.registry.get_keypair_occupancy( - AccessKey(sess_ctx.access_key), db_sess=db_sess + AccessKey(main_ak), db_sess=db_sess ) - log.debug("keypair:{} current-occupancy: {}", sess_ctx.access_key, key_occupied) - log.debug("keypair:{} total-allowed: {}", sess_ctx.access_key, total_keypair_allowed) + log.debug("keypair:{} current-occupancy: {}", main_ak, key_occupied) + log.debug("keypair:{} total-allowed: {}", main_ak, total_keypair_allowed) if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed): return PredicateResult( False, @@ -300,10 +311,13 @@ async def check_pending_session_count_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -319,7 +333,7 @@ async def check_pending_session_count_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -331,7 +345,7 @@ async def check_pending_session_count_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_count_limit: int | None = policy.max_pending_session_count @@ -344,7 +358,7 @@ async def check_pending_session_count_limit( log.debug( "access key:{} number of pending sessions: {} / {}", - sess_ctx.access_key, + main_ak, len(pending_sessions), pending_count_limit, ) @@ -361,10 +375,13 @@ async def check_pending_session_resource_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -380,7 +397,7 @@ async def check_pending_session_resource_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -392,7 +409,7 @@ async def check_pending_session_resource_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_resource_limit: ResourceSlot | None = policy.max_pending_session_resource_slots @@ -413,12 +430,12 @@ async def check_pending_session_resource_limit( failure_msgs.append(msg) log.debug( "access key:{} current-occupancy of pending sessions: {}", - sess_ctx.access_key, + main_ak, current_pending_session_slots, ) log.debug( "access key:{} total-allowed of pending sessions: {}", - sess_ctx.access_key, + main_ak, pending_resource_limit, ) if not result: diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py index 61c815d3d9e..941382e08e5 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py @@ -125,8 +125,8 @@ async def prepare( id=session_id, creation_id=spec.session_creation_id, name=spec.session_name, - access_key=spec.access_key, - user_uuid=spec.user_scope.user_uuid, + main_access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, group_id=spec.user_scope.group_id, domain_name=spec.user_scope.domain_name, scaling_group_name=validated_scaling_group.name, @@ -254,8 +254,8 @@ async def _prepare_kernels( scaling_group=validated_scaling_group.name, domain_name=spec.user_scope.domain_name, group_id=spec.user_scope.group_id, - user_uuid=spec.user_scope.user_uuid, - access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, + main_access_key=spec.access_key, image=image_info.canonical if image_info else self.DEFAULT_IMAGE_NAME, architecture=image_info.architecture if image_info else self.DEFAULT_ARCHITECTURE, registry=image_info.registry if image_info else self.DEFAULT_REGISTRY, diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py index c8fd82d2d2d..927e92f599c 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py @@ -257,7 +257,7 @@ async def enqueue_session( hook_result = await self._hook_plugin_ctx.dispatch( "PRE_ENQUEUE_SESSION", - (session_data.id, session_data.name, session_data.access_key), + (session_data.id, session_data.name, session_data.main_access_key), return_when=ALL_COMPLETED, ) if hook_result.status != PASSED: @@ -295,7 +295,7 @@ async def enqueue_session( ) await self._hook_plugin_ctx.notify( "POST_ENQUEUE_SESSION", - (session_id, session_data.name, session_data.access_key), + (session_id, session_data.name, session_data.main_access_key), ) return session_id