Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
82080dd
refactor(BA-5650): add main_access_key resolver helpers
jopemachine Apr 14, 2026
e1547c1
docs: rename news fragment to assigned PR number 11041
jopemachine Apr 14, 2026
e66536b
fix(BA-5709): keep delegate_endpoint_ownership signature in slice A
jopemachine Apr 14, 2026
1be446d
refactor(BA-5709): simplify main_access_key filter helper
jopemachine Apr 14, 2026
a866afb
refactor(BA-5650): thread main_access_key through UserPermission
jopemachine Apr 14, 2026
711f136
docs: rename news fragment to assigned PR number 11043
jopemachine Apr 14, 2026
cfdf874
fix(BA-5710): restrict slice B to UserPermission-only test changes
jopemachine Apr 14, 2026
f597c42
chore(BA-5710): drop stale misc news fragment; slice is skip:changelog
jopemachine Apr 14, 2026
3e2a72f
refactor(BA-5650-C): rename SessionData user_uuid to owner_id
jopemachine Apr 14, 2026
d151254
docs: rename news fragment to 11045
jopemachine Apr 14, 2026
09cbd11
refactor(BA-5711): address slice C review feedback
jopemachine Apr 14, 2026
12ec797
docs(BA-5711): add enhance news fragment for slice C
jopemachine Apr 14, 2026
5d3575a
fix(BA-5711): update cascaded call sites in slice C
jopemachine Apr 14, 2026
830edea
fix(BA-5650-C): make slice C typecheck independently
jopemachine Apr 14, 2026
deab30e
refactor(BA-5650-C): rename SessionData user_uuid to owner_id
jopemachine Apr 14, 2026
c79c20d
refactor(BA-5650-D): switch session repository to owner_id
jopemachine Apr 14, 2026
e0069d0
docs: rename news fragment to 11046
jopemachine Apr 14, 2026
3a01145
docs(BA-5650): use enhance news fragment type for slice
jopemachine Apr 14, 2026
d832999
refactor(BA-5712): address slice D review feedback
jopemachine Apr 14, 2026
7e6e70e
refactor(BA-5712): clean up stale slice C fragment and comment
jopemachine Apr 14, 2026
82ea2d0
fix(BA-5650-D): align remaining slice D call sites with owner_id rename
jopemachine Apr 14, 2026
f92b152
chore: consolidate news fragments for PR #11046
jopemachine Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11046.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `main_access_key` resolver helpers, thread through `UserPermission`, rename `SessionData.user_uuid` to `owner_id`, and collapse `SessionRepository` signatures to take `owner_id: UUID`.
8 changes: 4 additions & 4 deletions src/ai/backend/manager/api/adapters/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,13 +933,13 @@ def _session_data_to_node(data: SessionData) -> SessionNode:
return SessionNode(
id=data.id,
domain_name=data.domain_name,
user_id=data.user_uuid,
user_id=data.owner_id,
project_id=data.group_id,
metadata=SessionMetadataInfoGQLDTO(
creation_id=data.creation_id or "",
name=data.name or "",
session_type=data.session_type.value,
access_key=str(data.access_key) if data.access_key else "",
access_key="",
cluster_mode=data.cluster_mode.name,
cluster_size=data.cluster_size,
priority=data.priority,
Expand Down Expand Up @@ -1011,8 +1011,8 @@ def _kernel_info_to_node(info: KernelInfo) -> KernelNode:
session_type=info.session.session_type.value,
),
user_info=KernelUserInfoGQLDTO(
user_id=info.user_permission.user_uuid,
access_key=info.user_permission.access_key,
user_id=info.user_permission.owner_id,
access_key=info.user_permission.main_access_key,
domain_name=info.user_permission.domain_name,
group_id=info.user_permission.group_id,
),
Expand Down
37 changes: 28 additions & 9 deletions src/ai/backend/manager/api/gql_legacy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from ai.backend.common import validators as tx
from ai.backend.common.defs.session import SESSION_PRIORITY_MAX, SESSION_PRIORITY_MIN
from ai.backend.common.exception import SessionWithInvalidStateError
from ai.backend.common.types import (
ClusterMode,
KernelId,
Expand Down Expand Up @@ -395,19 +394,25 @@ def from_dataclass(
cls,
ctx: GraphQueryContext,
session_data: SessionData,
main_access_key: str | None,
*,
permissions: Iterable[ComputeSessionPermission] | None = None,
) -> Self:
"""Build a ``ComputeSessionNode`` from session data.

``main_access_key`` must be pre-resolved by the caller (typically
via ``UserRepository.get_main_access_key_by_id(session_data.owner_id)``
or by eagerly loading ``session_data.owner``). Keeping the helper
synchronous avoids a hidden per-session DB query and lets the
caller batch the lookup across nodes.
"""
status_history = session_data.status_history or {}
raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
if not session_data.vfolder_mounts:
vfolder_mounts = []
else:
vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts]

if session_data.owner is None:
raise SessionWithInvalidStateError()

result = cls(
# identity
id=session_data.id, # auto-converted to Relay global ID
Expand All @@ -422,9 +427,9 @@ def from_dataclass(
# ownership
domain_name=session_data.domain_name,
project_id=session_data.group_id,
user_id=session_data.user_uuid,
access_key=session_data.access_key,
owner=UserNode.from_dataclass(ctx, session_data.owner),
user_id=session_data.owner_id,
access_key=main_access_key,
owner=UserNode.from_dataclass(ctx, session_data.owner) if session_data.owner else None,
# status
status=session_data.status.name,
# status_changed=row.status_changed, # FIXME: generated attribute
Expand Down Expand Up @@ -918,8 +923,14 @@ async def mutate_and_get_payload(
)
)

session_data = result.session_data
main_access_key = (
session_data.owner.main_access_key
if session_data.owner
else await graph_ctx.user_repository.get_main_access_key_by_id(session_data.owner_id)
)
return ModifyComputeSession(
ComputeSessionNode.from_dataclass(graph_ctx, result.session_data),
ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key),
input.get("client_mutation_id"),
)

Expand Down Expand Up @@ -969,8 +980,16 @@ async def mutate(
)
)
)
session_data = action_result.session_data
main_access_key = (
session_data.owner.main_access_key
if session_data.owner
else await graph_ctx.user_repository.get_main_access_key_by_id(
session_data.owner_id
)
)
session_nodes.append(
ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data)
ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key)
)

return CheckAndTransitStatus(session_nodes, input.get("client_mutation_id"))
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/manager/data/kernel/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ class ClusterConfig:

@dataclass
class UserPermission:
user_uuid: UUID
access_key: str
owner_id: UUID
main_access_key: str | None
domain_name: str
group_id: UUID
uid: int | None
Expand Down
7 changes: 2 additions & 5 deletions src/ai/backend/manager/data/session/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from ai.backend.common.data.vfolder.types import VFolderMountData
from ai.backend.common.types import (
AccessKey,
CIStrEnum,
ClusterMode,
ResourceSlot,
Expand Down Expand Up @@ -155,7 +154,7 @@ class SessionData:
cluster_size: int
domain_name: str
group_id: UUID
user_uuid: UUID
owner_id: UUID
occupying_slots: Any # TODO: ResourceSlot?
requested_slots: Any
use_host_network: bool
Expand All @@ -165,7 +164,6 @@ class SessionData:
num_queries: int
creation_id: str | None
name: str | None
access_key: AccessKey | None
agent_ids: list[str] | None
images: list[str] | None
tag: str | None
Expand Down Expand Up @@ -206,8 +204,7 @@ class SessionMetadata:
name: str
domain_name: str
group_id: UUID
user_uuid: UUID
access_key: str
owner_id: UUID
session_type: SessionTypes
priority: int
created_at: datetime | None
Expand Down
10 changes: 3 additions & 7 deletions src/ai/backend/manager/models/endpoint/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,12 @@ class EndpointRow(Base): # type: ignore[misc]

__table_args__ = (
sa.Index(
"ix_endpoints_unique_name_when_active",
"ix_endpoints_unique_name_when_not_destroyed",
"name",
"domain",
"project",
unique=True,
postgresql_where=sa.column("lifecycle_stage").notin_([
EndpointLifecycle.DESTROYING.value,
EndpointLifecycle.DESTROYED.value,
]),
postgresql_where=(sa.column("lifecycle_stage") != EndpointLifecycle.DESTROYED.value),
),
sa.Index(
"ix_endpoints_lifecycle_sub_step",
Expand Down Expand Up @@ -530,7 +527,6 @@ async def delegate_endpoint_ownership(
db_session: AsyncSession,
owner_user_uuid: UUID,
target_user_uuid: UUID,
target_access_key: AccessKey,
) -> None:
from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow

Expand All @@ -554,7 +550,7 @@ async def delegate_endpoint_ownership(
db_session, session_ids, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS
)
for session_row in session_rows:
session_row.delegate_ownership(target_user_uuid, target_access_key)
session_row.delegate_ownership(target_user_uuid)

async def generate_route_info(
self, db_sess: AsyncSession
Expand Down
23 changes: 14 additions & 9 deletions src/ai/backend/manager/models/kernel/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,8 @@ def set_status(
else:
self.status_data = dict(status_data)

def delegate_ownership(self, user_uuid: uuid.UUID, access_key: AccessKey) -> None:
self.user_uuid = user_uuid
self.access_key = access_key
def delegate_ownership(self, owner_id: uuid.UUID) -> None:
self.user_uuid = owner_id

@classmethod
async def set_kernel_status(
Expand Down Expand Up @@ -945,8 +944,7 @@ def from_kernel_info(cls, info: KernelInfo) -> Self:
agent_addr=info.resource.agent_addr,
domain_name=info.user_permission.domain_name,
group_id=info.user_permission.group_id,
user_uuid=info.user_permission.user_uuid,
access_key=info.user_permission.access_key,
user_uuid=info.user_permission.owner_id,
image=info.image.identifier.canonical if info.image.identifier else None,
architecture=info.image.identifier.architecture if info.image.identifier else None,
registry=info.image.registry,
Expand Down Expand Up @@ -1002,8 +1000,8 @@ def to_kernel_info(self) -> KernelInfo:
session_type=self.session_type,
),
user_permission=UserPermission(
user_uuid=self.user_uuid,
access_key=self.access_key or "",
owner_id=self.user_uuid,
main_access_key=self.user_row.main_access_key if self.user_row else None,
domain_name=self.domain_name,
group_id=self.group_id,
uid=self.uid,
Expand Down Expand Up @@ -1113,12 +1111,19 @@ async def recalc_concurrency_used(
) -> None:
from ai.backend.manager.models.session import PRIVATE_SESSION_TYPES

# TODO(BA-5609 phase D): kernels.access_key is removed. Resolve the
# owner_id for this access_key (via users.main_access_key) and filter by
# KernelRow.user_uuid instead. The join below is a temporary shim that
# selects kernels whose owning user has main_access_key == access_key.
owner_id_subq = (
sa.select(users.c.uuid).where(users.c.main_access_key == access_key).scalar_subquery()
)
async with db_sess.begin_nested():
result = await db_sess.execute(
sa.select(sa.func.count())
.select_from(KernelRow)
.where(
(KernelRow.access_key == access_key)
(KernelRow.user_uuid == owner_id_subq)
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
& (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES))
),
Expand All @@ -1128,7 +1133,7 @@ async def recalc_concurrency_used(
sa.select(sa.func.count())
.select_from(KernelRow)
.where(
(KernelRow.access_key == access_key)
(KernelRow.user_uuid == owner_id_subq)
& (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
& (KernelRow.session_type.in_(PRIVATE_SESSION_TYPES))
),
Expand Down
16 changes: 1 addition & 15 deletions src/ai/backend/manager/models/keypair/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA256
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql.expression import false

from ai.backend.common import msgpack
Expand All @@ -31,7 +31,6 @@
if TYPE_CHECKING:
from ai.backend.manager.models.resource_policy import KeyPairResourcePolicyRow
from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.models.user import UserRow

__all__: Sequence[str] = (
Expand All @@ -48,13 +47,6 @@
MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB


# Defined for avoiding circular import
def _get_session_row_join_condition() -> sa.ColumnElement[bool]:
from ai.backend.manager.models.session import SessionRow

return KeyPairRow.access_key == foreign(SessionRow.access_key)


class KeyPairRow(Base): # type: ignore[misc]
__tablename__ = "keypairs"

Expand Down Expand Up @@ -100,12 +92,6 @@ class KeyPairRow(Base): # type: ignore[misc]
)

# Relationships
sessions: Mapped[list[SessionRow]] = relationship(
"SessionRow",
primaryjoin=_get_session_row_join_condition,
foreign_keys="SessionRow.access_key",
back_populates="access_key_row",
)
resource_policy_row: Mapped[KeyPairResourcePolicyRow] = relationship(
"KeyPairResourcePolicyRow", back_populates="keypairs"
)
Expand Down
18 changes: 15 additions & 3 deletions src/ai/backend/manager/models/resource_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,12 @@ async def parse_resource_usage_groups(
last_stat=stat_map.get(kern.id),
user_id=kern.session.user_uuid,
user_email=kern.session.user.email if kern.session.user is not None else None,
access_key=kern.session.access_key,
# The old ``SessionRow.access_key`` column is being dropped in a
# later slice; source the keypair access_key from the owner's
# ``main_access_key`` instead.
access_key=(
kern.session.user.main_access_key if kern.session.user is not None else None
),
project_id=kern.session.group.id if kern.session.group is not None else None,
project_name=kern.session.group.name if kern.session.group is not None else None,
kernel_id=kern.id,
Expand Down Expand Up @@ -553,7 +558,9 @@ async def parse_resource_usage_groups(
SessionRow.domain_name,
SessionRow.id,
SessionRow.group_id,
SessionRow.access_key,
# SessionRow.access_key is deprecated (removed in a later slice); callers
# that need the keypair access_key should join UserRow and read
# users.main_access_key instead.
SessionRow.images,
SessionRow.cluster_mode,
SessionRow.status_history,
Expand Down Expand Up @@ -606,7 +613,12 @@ def _parse_query(
session_load.options(
load_only(*SESSION_RESOURCE_SELECT_COLS),
joinedload(SessionRow.user).options(
load_only(UserRow.email, UserRow.username, UserRow.full_name)
load_only(
UserRow.email,
UserRow.username,
UserRow.full_name,
UserRow.main_access_key,
)
),
project_load.options(load_only(*PROJECT_RESOURCE_SELECT_COLS)),
),
Expand Down
Loading
Loading