Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11047.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Collapse the scheduler, predicate, and scheduler-option signatures to take `owner_id: UUID`. Rename the `access_key` field to `main_access_key` on `ScheduledSessionData`, `TerminatingSessionData`, `SweptSessionInfo`, and the scheduler types.
Comment thread
jopemachine marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ai.backend.manager.errors.resource import ProjectNotFound
from ai.backend.manager.models.group import groups
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.models.user import UserRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine


Expand All @@ -21,8 +22,13 @@ async def match_sessions_by_name(
access_key: AccessKey,
) -> list[SessionRow]:
async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess:
owner_id = await db_sess.scalar(
sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key)
)
if owner_id is None:
return []
return await SessionRow.match_sessions(
db_sess, session_name, access_key, allow_prefix=False
db_sess, session_name, owner_id=owner_id, allow_prefix=False
)

async def resolve_group_id(self, group_name: str) -> uuid.UUID:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ async def _fetch_pending_sessions(
if session_id not in sessions_map:
sessions_map[session_id] = PendingSessionData(
id=session_id,
access_key=row.access_key,
main_access_key=row.access_key,
requested_slots=row.requested_slots,
user_uuid=row.user_uuid,
owner_id=row.user_uuid,
group_id=row.group_id,
domain_name=row.domain_name,
scaling_group_name=row.scaling_group_name,
Expand Down Expand Up @@ -700,7 +700,7 @@ async def _fetch_user_policies(
"""Fetch user resource policies for users in pending sessions."""
user_policies: dict[UUID, UserResourcePolicy] = {}

if not pending_sessions.user_uuids:
if not pending_sessions.owner_ids:
return user_policies

user_policy_result = await db_sess.execute(
Expand All @@ -716,7 +716,7 @@ async def _fetch_user_policies(
KeyPairResourcePolicyRow,
KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name,
)
.where(UserRow.uuid.in_(pending_sessions.user_uuids))
.where(UserRow.uuid.in_(pending_sessions.owner_ids))
)

for row in user_policy_result:
Expand Down Expand Up @@ -1149,7 +1149,7 @@ async def get_terminating_sessions_by_ids(
terminating_sessions.append(
TerminatingSessionData(
session_id=session_row.id,
access_key=AccessKey(session_row.access_key)
main_access_key=AccessKey(session_row.access_key)
if session_row.access_key
else AccessKey(""),
creation_id=session_row.creation_id or "",
Expand Down Expand Up @@ -1213,7 +1213,7 @@ async def get_pending_timeout_sessions_by_ids(
SweptSessionInfo(
session_id=row.id,
creation_id=row.creation_id,
access_key=row.access_key,
main_access_key=row.access_key,
)
)

Expand Down Expand Up @@ -1302,8 +1302,8 @@ async def enqueue_session(
id=session_data.id,
creation_id=session_data.creation_id,
name=session_data.name,
access_key=session_data.access_key,
user_uuid=session_data.user_uuid,
access_key=session_data.main_access_key,
user_uuid=session_data.owner_id,
group_id=session_data.group_id,
domain_name=session_data.domain_name,
scaling_group_name=session_data.scaling_group_name,
Expand Down Expand Up @@ -1349,8 +1349,8 @@ async def enqueue_session(
scaling_group=kernel.scaling_group,
domain_name=kernel.domain_name,
group_id=kernel.group_id,
user_uuid=kernel.user_uuid,
access_key=kernel.access_key,
user_uuid=kernel.owner_id,
access_key=kernel.main_access_key,
image=kernel.image,
architecture=kernel.architecture,
registry=kernel.registry,
Expand Down Expand Up @@ -1387,7 +1387,7 @@ async def enqueue_session(
element_type=RBACElementType.SESSION,
scope_ref=RBACElementRef(
element_type=RBACElementType.USER,
element_id=str(session_data.user_uuid),
element_id=str(session_data.owner_id),
),
additional_scope_refs=[
RBACElementRef(
Expand Down Expand Up @@ -1857,7 +1857,7 @@ async def allocate_sessions(
ScheduledSessionData(
session_id=allocation.session_id,
creation_id=creation_id,
access_key=access_key,
main_access_key=access_key,
reason="triggered-by-scheduler",
)
)
Expand Down Expand Up @@ -2917,7 +2917,7 @@ async def _get_sessions_by_statuses(
scheduled_session = ScheduledSessionData(
session_id=session.id,
creation_id=session.creation_id or "",
access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""),
main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""),
reason="triggered-by-scheduler",
)
scheduled_sessions.append(scheduled_session)
Expand Down Expand Up @@ -2962,7 +2962,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes
ScheduledSessionData(
session_id=session.id,
creation_id=session.creation_id or "",
access_key=AccessKey(session.access_key)
main_access_key=AccessKey(session.access_key)
if session.access_key
else AccessKey(""),
reason="triggered-by-scheduler",
Expand Down
60 changes: 4 additions & 56 deletions src/ai/backend/manager/repositories/scheduler/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:

return inner

@staticmethod
def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition:
    """Factory for a "session access key contains substring" filter.

    ``spec.value`` is matched literally: the LIKE wildcards ``%`` and ``_``
    (and the escape character itself) are escaped before the pattern is
    built, so user-supplied text cannot widen the match.

    Args:
        spec: Match specification. ``value`` is the substring to look for,
            ``case_insensitive`` selects ILIKE instead of LIKE, and
            ``negated`` wraps the resulting condition in NOT.

    Returns:
        A zero-argument callable that builds the SQLAlchemy boolean
        expression when invoked.
    """

    def inner() -> sa.sql.expression.ColumnElement[bool]:
        # Escape the escape character first so the escapes added for the
        # wildcards are not themselves doubled.
        literal = spec.value.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{literal}%"
        if spec.case_insensitive:
            condition = SessionRow.access_key.ilike(pattern, escape="/")
        else:
            condition = SessionRow.access_key.like(pattern, escape="/")
        if spec.negated:
            condition = sa.not_(condition)
        return condition

    return inner

@staticmethod
def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition:
    """Factory for a "session access key equals value" filter.

    When ``spec.case_insensitive`` is set, both the column (via SQL
    ``lower()``) and ``spec.value`` (via ``str.lower()``) are lower-cased
    before comparison. ``spec.negated`` wraps the comparison in NOT.

    Returns a zero-argument callable that builds the SQLAlchemy boolean
    expression when invoked.
    """

    def inner() -> sa.sql.expression.ColumnElement[bool]:
        column, expected = SessionRow.access_key, spec.value
        if spec.case_insensitive:
            column, expected = sa.func.lower(column), expected.lower()
        matches = column == expected
        return sa.not_(matches) if spec.negated else matches

    return inner

@staticmethod
def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition:
    """Factory for a "session access key starts with prefix" filter.

    ``spec.value`` is matched literally: the LIKE wildcards ``%`` and ``_``
    (and the escape character itself) are escaped before the pattern is
    built, so user-supplied text cannot widen the match.

    Args:
        spec: Match specification. ``value`` is the required prefix,
            ``case_insensitive`` selects ILIKE instead of LIKE, and
            ``negated`` wraps the resulting condition in NOT.

    Returns:
        A zero-argument callable that builds the SQLAlchemy boolean
        expression when invoked.
    """

    def inner() -> sa.sql.expression.ColumnElement[bool]:
        # Escape the escape character first so the escapes added for the
        # wildcards are not themselves doubled.
        literal = spec.value.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"{literal}%"
        if spec.case_insensitive:
            condition = SessionRow.access_key.ilike(pattern, escape="/")
        else:
            condition = SessionRow.access_key.like(pattern, escape="/")
        if spec.negated:
            condition = sa.not_(condition)
        return condition

    return inner

@staticmethod
def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition:
    """Factory for a "session access key ends with suffix" filter.

    ``spec.value`` is matched literally: the LIKE wildcards ``%`` and ``_``
    (and the escape character itself) are escaped before the pattern is
    built, so user-supplied text cannot widen the match.

    Args:
        spec: Match specification. ``value`` is the required suffix,
            ``case_insensitive`` selects ILIKE instead of LIKE, and
            ``negated`` wraps the resulting condition in NOT.

    Returns:
        A zero-argument callable that builds the SQLAlchemy boolean
        expression when invoked.
    """

    def inner() -> sa.sql.expression.ColumnElement[bool]:
        # Escape the escape character first so the escapes added for the
        # wildcards are not themselves doubled.
        literal = spec.value.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{literal}"
        if spec.case_insensitive:
            condition = SessionRow.access_key.ilike(pattern, escape="/")
        else:
            condition = SessionRow.access_key.like(pattern, escape="/")
        if spec.negated:
            condition = sa.not_(condition)
        return condition

    return inner

@staticmethod
def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
Expand Down Expand Up @@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
return inner

@staticmethod
def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
"""Factory for user UUID equality filter."""
def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
"""Factory for owner_id equality filter."""

def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.negated:
Expand All @@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
return inner

@staticmethod
def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
"""Factory for user UUID IN filter."""
def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
"""Factory for owner_id IN filter."""

def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.negated:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ class SessionAllocation:
kernel_allocations: list[KernelAllocation]
# List of agent allocations for this session
agent_allocations: list[AgentAllocation]
# Keypair associated with the session
access_key: AccessKey
main_access_key: AccessKey
# Phases that passed during scheduling
passed_phases: list[SchedulingPredicate] = field(default_factory=list)
# Phases that failed during scheduling (normally empty for successful allocations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ class ScheduledSessionData:

session_id: SessionId
creation_id: str
access_key: AccessKey
main_access_key: AccessKey
reason: str
Comment thread
jopemachine marked this conversation as resolved.
20 changes: 10 additions & 10 deletions src/ai/backend/manager/repositories/scheduler/types/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class PendingSessionData:
"""Pending session data for scheduling."""

id: SessionId
access_key: AccessKey
main_access_key: AccessKey
requested_slots: ResourceSlot
user_uuid: UUID
owner_id: UUID
group_id: UUID
domain_name: str
scaling_group_name: str
Expand All @@ -64,9 +64,9 @@ def to_session_workload(self) -> SessionWorkload:
kernel_workloads = [k.to_kernel_workload() for k in self.kernels]
return SessionWorkload(
session_id=self.id,
access_key=self.access_key,
access_key=self.main_access_key,
requested_slots=self.requested_slots,
user_uuid=self.user_uuid,
user_uuid=self.owner_id,
group_id=self.group_id,
domain_name=self.domain_name,
scaling_group=self.scaling_group_name,
Expand All @@ -90,12 +90,12 @@ class PendingSessions:
@cached_property
def access_keys(self) -> set[AccessKey]:
"""Extract unique access keys from pending sessions."""
return {s.access_key for s in self.sessions}
return {s.main_access_key for s in self.sessions}
Comment thread
jopemachine marked this conversation as resolved.

@cached_property
def user_uuids(self) -> set[UUID]:
"""Extract unique user UUIDs from pending sessions."""
return {s.user_uuid for s in self.sessions}
def owner_ids(self) -> set[UUID]:
"""Extract unique owner (user) UUIDs from pending sessions."""
return {s.owner_id for s in self.sessions}

@cached_property
def group_ids(self) -> set[UUID]:
Expand Down Expand Up @@ -125,7 +125,7 @@ class TerminatingSessionData:
"""Data for a session that needs to be terminated."""

session_id: SessionId
access_key: AccessKey
main_access_key: AccessKey
creation_id: str
status: SessionStatus
status_info: str
Comment thread
jopemachine marked this conversation as resolved.
Expand Down Expand Up @@ -161,7 +161,7 @@ class SweptSessionInfo:

session_id: SessionId
creation_id: str
access_key: AccessKey
main_access_key: AccessKey

Comment thread
jopemachine marked this conversation as resolved.

@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ class KernelEnqueueData:
scaling_group: str
domain_name: str
group_id: UUID
user_uuid: UUID
access_key: AccessKey
owner_id: UUID
main_access_key: AccessKey
image: str # Canonical image name
Comment thread
jopemachine marked this conversation as resolved.
architecture: str
registry: str
Expand Down Expand Up @@ -268,8 +268,8 @@ class SessionEnqueueData:
id: SessionId
creation_id: str
name: str
access_key: AccessKey
user_uuid: UUID
main_access_key: AccessKey
owner_id: UUID
group_id: UUID
domain_name: str
scaling_group_name: str
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlalchemy as sa

from ai.backend.common.types import AccessKey
from ai.backend.manager.errors.kernel import SessionNotFound
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow
from ai.backend.manager.models.user import UserRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
Expand All @@ -23,7 +23,7 @@ async def get_streaming_session(
sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key)
)
if owner_id is None:
raise SessionNotFound(f"Unknown access_key: {access_key}")
raise UserNotFound(f"No user with main_access_key={access_key}")
return await SessionRow.get_session(
db_sess,
session_name,
Expand Down
Loading
Loading