Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
7c209f1
refactor(BA-5650-D): switch session repository to owner_id
jopemachine Apr 14, 2026
adec0d8
refactor(BA-5650-E): collapse scheduler signatures to owner_id
jopemachine Apr 14, 2026
f44bad9
docs: rename news fragment to 11047
jopemachine Apr 14, 2026
df67b48
docs(BA-5650): use enhance news fragment type for slice
jopemachine Apr 14, 2026
ca71661
refactor(BA-5713): address slice E review feedback
jopemachine Apr 14, 2026
603c3d0
refactor(BA-5713): address new slice E review feedback
jopemachine Apr 14, 2026
b3067da
fix(BA-5650-E): align remaining slice E call sites with renamed fields
jopemachine Apr 14, 2026
f23ed6c
refactor(BA-5650-D): switch session repository to owner_id
jopemachine Apr 14, 2026
3ad2159
refactor(BA-5650-E): collapse scheduler signatures to owner_id
jopemachine Apr 14, 2026
4df2e52
refactor(BA-5650-F): propagate owner_id rename into sokovan
jopemachine Apr 14, 2026
65a8938
docs: rename news fragment to 11048
jopemachine Apr 14, 2026
7adff9a
docs(BA-5650): use enhance news fragment type for slice
jopemachine Apr 14, 2026
a072335
refactor(BA-5714): drop stray deployment changes and restore dropped …
jopemachine Apr 14, 2026
707fd7b
refactor(BA-5714): resolve rebase conflicts and add resolve_main_acce…
jopemachine Apr 14, 2026
efa2a91
fix(BA-5650-F): align slice F call sites with renamed sokovan fields
jopemachine Apr 14, 2026
6a8f6fa
chore: consolidate news fragments for PR #11048
jopemachine Apr 15, 2026
05c7571
fix: apply ruff format
jopemachine Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion changes/11045.enhance.md

This file was deleted.

1 change: 1 addition & 0 deletions changes/11048.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Collapse scheduler and sokovan signatures to use `owner_id` and `main_access_key`, propagate rename into sokovan handlers, coordinator, and sequencers.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ai.backend.manager.errors.resource import ProjectNotFound
from ai.backend.manager.models.group import groups
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.models.user import UserRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine


Expand All @@ -21,8 +22,13 @@ async def match_sessions_by_name(
access_key: AccessKey,
) -> list[SessionRow]:
async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess:
owner_id = await db_sess.scalar(
sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key)
)
if owner_id is None:
return []
return await SessionRow.match_sessions(
db_sess, session_name, access_key, allow_prefix=False
db_sess, session_name, owner_id=owner_id, allow_prefix=False
)

async def resolve_group_id(self, group_name: str) -> uuid.UUID:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ async def _fetch_pending_sessions(
if session_id not in sessions_map:
sessions_map[session_id] = PendingSessionData(
id=session_id,
access_key=row.access_key,
main_access_key=row.access_key,
requested_slots=row.requested_slots,
user_uuid=row.user_uuid,
owner_id=row.user_uuid,
group_id=row.group_id,
domain_name=row.domain_name,
scaling_group_name=row.scaling_group_name,
Expand Down Expand Up @@ -700,7 +700,7 @@ async def _fetch_user_policies(
"""Fetch user resource policies for users in pending sessions."""
user_policies: dict[UUID, UserResourcePolicy] = {}

if not pending_sessions.user_uuids:
if not pending_sessions.owner_ids:
return user_policies

user_policy_result = await db_sess.execute(
Expand All @@ -716,7 +716,7 @@ async def _fetch_user_policies(
KeyPairResourcePolicyRow,
KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name,
)
.where(UserRow.uuid.in_(pending_sessions.user_uuids))
.where(UserRow.uuid.in_(pending_sessions.owner_ids))
)

for row in user_policy_result:
Expand Down Expand Up @@ -1149,7 +1149,7 @@ async def get_terminating_sessions_by_ids(
terminating_sessions.append(
TerminatingSessionData(
session_id=session_row.id,
access_key=AccessKey(session_row.access_key)
main_access_key=AccessKey(session_row.access_key)
if session_row.access_key
else AccessKey(""),
creation_id=session_row.creation_id or "",
Expand Down Expand Up @@ -1213,7 +1213,7 @@ async def get_pending_timeout_sessions_by_ids(
SweptSessionInfo(
session_id=row.id,
creation_id=row.creation_id,
access_key=row.access_key,
main_access_key=row.access_key,
)
)

Expand Down Expand Up @@ -1302,8 +1302,8 @@ async def enqueue_session(
id=session_data.id,
creation_id=session_data.creation_id,
name=session_data.name,
access_key=session_data.access_key,
user_uuid=session_data.user_uuid,
access_key=session_data.main_access_key,
user_uuid=session_data.owner_id,
group_id=session_data.group_id,
domain_name=session_data.domain_name,
scaling_group_name=session_data.scaling_group_name,
Expand Down Expand Up @@ -1349,8 +1349,8 @@ async def enqueue_session(
scaling_group=kernel.scaling_group,
domain_name=kernel.domain_name,
group_id=kernel.group_id,
user_uuid=kernel.user_uuid,
access_key=kernel.access_key,
user_uuid=kernel.owner_id,
access_key=kernel.main_access_key,
image=kernel.image,
architecture=kernel.architecture,
registry=kernel.registry,
Expand Down Expand Up @@ -1387,7 +1387,7 @@ async def enqueue_session(
element_type=RBACElementType.SESSION,
scope_ref=RBACElementRef(
element_type=RBACElementType.USER,
element_id=str(session_data.user_uuid),
element_id=str(session_data.owner_id),
),
additional_scope_refs=[
RBACElementRef(
Expand Down Expand Up @@ -1857,7 +1857,7 @@ async def allocate_sessions(
ScheduledSessionData(
session_id=allocation.session_id,
creation_id=creation_id,
access_key=access_key,
main_access_key=access_key,
reason="triggered-by-scheduler",
)
)
Expand Down Expand Up @@ -2917,7 +2917,9 @@ async def _get_sessions_by_statuses(
scheduled_session = ScheduledSessionData(
session_id=session.id,
creation_id=session.creation_id or "",
access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""),
main_access_key=AccessKey(session.access_key)
if session.access_key
else AccessKey(""),
reason="triggered-by-scheduler",
)
scheduled_sessions.append(scheduled_session)
Expand Down Expand Up @@ -2962,7 +2964,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes
ScheduledSessionData(
session_id=session.id,
creation_id=session.creation_id or "",
access_key=AccessKey(session.access_key)
main_access_key=AccessKey(session.access_key)
if session.access_key
else AccessKey(""),
reason="triggered-by-scheduler",
Expand Down Expand Up @@ -3102,7 +3104,7 @@ async def _get_sessions_for_pull(
sessions_map[session_id] = SessionDataForPull(
session_id=session_id,
creation_id=row.creation_id,
access_key=row.access_key,
main_access_key=row.access_key,
kernels=[],
)

Expand Down Expand Up @@ -3294,13 +3296,13 @@ async def _get_sessions_for_start(
SessionDataForStart(
session_id=session_info["id"],
creation_id=session_info["creation_id"],
access_key=session_info["access_key"],
main_access_key=session_info["access_key"],
session_type=session_info["session_type"],
name=session_info["name"],
cluster_mode=session_info["cluster_mode"],
kernels=kernel_bindings,
environ=session_info.get("environ", {}),
user_uuid=session_info["user_uuid"],
owner_id=session_info["user_uuid"],
user_email=user_info.email,
user_name=user_info.username,
)
Expand Down Expand Up @@ -4074,7 +4076,7 @@ async def _fetch_sessions_for_pull_by_ids(
sessions_map[session_id] = SessionDataForPull(
session_id=session_id,
creation_id=row.creation_id,
access_key=row.access_key,
main_access_key=row.access_key,
kernels=[],
)

Expand Down Expand Up @@ -4293,13 +4295,13 @@ async def _fetch_sessions_for_start_by_ids(
SessionDataForStart(
session_id=session_info["id"],
creation_id=session_info["creation_id"],
access_key=session_info["access_key"],
main_access_key=session_info["access_key"],
session_type=session_info["session_type"],
name=session_info["name"],
cluster_mode=session_info["cluster_mode"],
kernels=kernel_bindings,
environ=session_info.get("environ", {}),
user_uuid=session_info["user_uuid"],
owner_id=session_info["user_uuid"],
user_email=user_info.email,
user_name=user_info.username,
)
Expand Down Expand Up @@ -4369,7 +4371,7 @@ async def search_sessions_with_kernels(
sessions_map[row.id] = SessionDataForPull(
session_id=row.id,
creation_id=row.creation_id,
access_key=row.access_key,
main_access_key=row.access_key,
kernels=[],
)

Expand Down Expand Up @@ -4625,13 +4627,13 @@ async def search_sessions_with_kernels_and_user(
SessionDataForStart(
session_id=session_info["id"],
creation_id=session_info["creation_id"],
access_key=session_info["access_key"],
main_access_key=session_info["access_key"],
session_type=session_info["session_type"],
name=session_info["name"],
cluster_mode=session_info["cluster_mode"],
kernels=session_info["kernels"],
environ=session_info.get("environ") or {},
user_uuid=session_info["user_uuid"],
owner_id=session_info["user_uuid"],
user_email=user_info.email,
user_name=user_info.username,
)
Expand Down Expand Up @@ -4774,6 +4776,26 @@ async def get_db_now(self) -> datetime:
result = await conn.execute(sa.select(sa.func.now()))
return result.scalar_one()

async def resolve_main_access_keys(
self, session_ids: Sequence[SessionId]
) -> dict[SessionId, AccessKey]:
"""Resolve the main access key for each session's owner.

Joins ``sessions`` → ``users`` to look up the owner's
``main_access_key``. Sessions whose owner has no configured
main access key are omitted from the returned mapping.
"""
if not session_ids:
return {}
async with self._db.begin_readonly_session() as db_sess:
stmt = (
sa.select(SessionRow.id, UserRow.main_access_key)
.join(UserRow, SessionRow.user_uuid == UserRow.uuid)
.where(SessionRow.id.in_([sid for sid in session_ids]))
)
rows = (await db_sess.execute(stmt)).all()
return {SessionId(row[0]): AccessKey(row[1]) for row in rows if row[1] is not None}

async def _get_db_now_in_session(self, db_sess: SASession) -> datetime:
"""Get the current timestamp from the database within an existing session.

Expand Down
60 changes: 4 additions & 56 deletions src/ai/backend/manager/repositories/scheduler/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:

return inner

@staticmethod
def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"%{spec.value}%")
else:
condition = SessionRow.access_key.like(f"%{spec.value}%")
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

@staticmethod
def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = sa.func.lower(SessionRow.access_key) == spec.value.lower()
else:
condition = SessionRow.access_key == spec.value
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

@staticmethod
def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"{spec.value}%")
else:
condition = SessionRow.access_key.like(f"{spec.value}%")
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

@staticmethod
def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"%{spec.value}")
else:
condition = SessionRow.access_key.like(f"%{spec.value}")
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

@staticmethod
def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
Expand Down Expand Up @@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
return inner

@staticmethod
def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
"""Factory for user UUID equality filter."""
def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition:
"""Factory for owner_id equality filter."""

def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.negated:
Expand All @@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
return inner

@staticmethod
def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
"""Factory for user UUID IN filter."""
def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition:
"""Factory for owner_id IN filter."""

def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.negated:
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/repositories/scheduler/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,9 @@ async def get_db_now(self) -> datetime:
Current database timestamp with timezone
"""
return await self._db_source.get_db_now()

async def resolve_main_access_keys(
self, session_ids: Sequence[SessionId]
) -> dict[SessionId, AccessKey]:
"""Resolve the main access key for each session's owner."""
return await self._db_source.resolve_main_access_keys(session_ids)
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ class SessionAllocation:
kernel_allocations: list[KernelAllocation]
# List of agent allocations for this session
agent_allocations: list[AgentAllocation]
# Keypair associated with the session
access_key: AccessKey
main_access_key: AccessKey
# Phases that passed during scheduling
passed_phases: list[SchedulingPredicate] = field(default_factory=list)
# Phases that failed during scheduling (normally empty for successful allocations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ class ScheduledSessionData:

session_id: SessionId
creation_id: str
access_key: AccessKey
main_access_key: AccessKey
reason: str
Loading
Loading