diff --git a/changes/11046.enhance.md b/changes/11046.enhance.md new file mode 100644 index 00000000000..0fae52d9b76 --- /dev/null +++ b/changes/11046.enhance.md @@ -0,0 +1 @@ +Add `main_access_key` resolver helpers, thread through `UserPermission`, rename `SessionData.user_uuid` to `owner_id`, and collapse `SessionRepository` signatures to take `owner_id: UUID`. diff --git a/src/ai/backend/manager/api/adapters/session.py b/src/ai/backend/manager/api/adapters/session.py index 78ae9c9f807..e04aa90abb9 100644 --- a/src/ai/backend/manager/api/adapters/session.py +++ b/src/ai/backend/manager/api/adapters/session.py @@ -933,13 +933,13 @@ def _session_data_to_node(data: SessionData) -> SessionNode: return SessionNode( id=data.id, domain_name=data.domain_name, - user_id=data.user_uuid, + user_id=data.owner_id, project_id=data.group_id, metadata=SessionMetadataInfoGQLDTO( creation_id=data.creation_id or "", name=data.name or "", session_type=data.session_type.value, - access_key=str(data.access_key) if data.access_key else "", + access_key="", cluster_mode=data.cluster_mode.name, cluster_size=data.cluster_size, priority=data.priority, @@ -1011,8 +1011,8 @@ def _kernel_info_to_node(info: KernelInfo) -> KernelNode: session_type=info.session.session_type.value, ), user_info=KernelUserInfoGQLDTO( - user_id=info.user_permission.user_uuid, - access_key=info.user_permission.access_key, + user_id=info.user_permission.owner_id, + access_key=info.user_permission.main_access_key, domain_name=info.user_permission.domain_name, group_id=info.user_permission.group_id, ), diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index 6ab06ca2dfb..6db61a21c5c 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -23,7 +23,6 @@ from ai.backend.common import validators as tx from ai.backend.common.defs.session import SESSION_PRIORITY_MAX, SESSION_PRIORITY_MIN -from ai.backend.common.exception import SessionWithInvalidStateError from ai.backend.common.types import ( ClusterMode, KernelId, @@ -395,9 +394,18 @@ def from_dataclass( cls, ctx: GraphQueryContext, session_data: SessionData, + main_access_key: str | None, *, permissions: Iterable[ComputeSessionPermission] | None = None, ) -> Self: + """Build a ``ComputeSessionNode`` from session data. + + ``main_access_key`` must be pre-resolved by the caller (typically + via ``UserRepository.get_main_access_key_by_id(session_data.owner_id)`` + or by eagerly loading ``session_data.owner``). Keeping the helper + synchronous avoids a hidden per-session DB query and lets the + caller batch the lookup across nodes. + """ status_history = session_data.status_history or {} raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name) if not session_data.vfolder_mounts: @@ -405,9 +413,6 @@ def from_dataclass( else: vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts] - if session_data.owner is None: - raise SessionWithInvalidStateError() - result = cls( # identity id=session_data.id, # auto-converted to Relay global ID @@ -422,9 +427,9 @@ def from_dataclass( # ownership domain_name=session_data.domain_name, project_id=session_data.group_id, - user_id=session_data.user_uuid, - access_key=session_data.access_key, - owner=UserNode.from_dataclass(ctx, session_data.owner), + user_id=session_data.owner_id, + access_key=main_access_key, + owner=UserNode.from_dataclass(ctx, session_data.owner) if session_data.owner else None, # status status=session_data.status.name, # status_changed=row.status_changed, # FIXME: generated attribute @@ -918,8 +923,14 @@ async def mutate_and_get_payload( ) ) + session_data = result.session_data + main_access_key = ( + session_data.owner.main_access_key + if session_data.owner + else await graph_ctx.user_repository.get_main_access_key_by_id(session_data.owner_id) + ) return ModifyComputeSession( - ComputeSessionNode.from_dataclass(graph_ctx, result.session_data), + ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key), input.get("client_mutation_id"), ) @@ -969,8 +980,16 @@ async def mutate( ) ) ) + session_data = action_result.session_data + main_access_key = ( + session_data.owner.main_access_key + if session_data.owner + else await graph_ctx.user_repository.get_main_access_key_by_id( + session_data.owner_id + ) + ) session_nodes.append( - ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data) + ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key) ) return CheckAndTransitStatus(session_nodes, input.get("client_mutation_id")) diff --git a/src/ai/backend/manager/data/kernel/types.py b/src/ai/backend/manager/data/kernel/types.py index 075419d536a..ff495b2aae2 100644 --- a/src/ai/backend/manager/data/kernel/types.py +++ b/src/ai/backend/manager/data/kernel/types.py @@ -207,8 +207,8 @@ class ClusterConfig: @dataclass class UserPermission: - user_uuid: UUID - access_key: str + owner_id: UUID + main_access_key: str | None domain_name: str group_id: UUID uid: int | None diff --git a/src/ai/backend/manager/data/session/types.py b/src/ai/backend/manager/data/session/types.py index 5123d528497..9a6a116b487 100644 --- a/src/ai/backend/manager/data/session/types.py +++ b/src/ai/backend/manager/data/session/types.py @@ -12,7 +12,6 @@ from ai.backend.common.data.vfolder.types import VFolderMountData from ai.backend.common.types import ( - AccessKey, CIStrEnum, ClusterMode, ResourceSlot, @@ -155,7 +154,7 @@ class SessionData: cluster_size: int domain_name: str group_id: UUID - user_uuid: UUID + owner_id: UUID occupying_slots: Any # TODO: ResourceSlot? requested_slots: Any use_host_network: bool @@ -165,7 +164,6 @@ class SessionData: num_queries: int creation_id: str | None name: str | None - access_key: AccessKey | None agent_ids: list[str] | None images: list[str] | None tag: str | None @@ -206,8 +204,7 @@ class SessionMetadata: name: str domain_name: str group_id: UUID - user_uuid: UUID - access_key: str + owner_id: UUID session_type: SessionTypes priority: int created_at: datetime | None diff --git a/src/ai/backend/manager/models/endpoint/row.py b/src/ai/backend/manager/models/endpoint/row.py index d4a94404ba3..0c452d35fdd 100644 --- a/src/ai/backend/manager/models/endpoint/row.py +++ b/src/ai/backend/manager/models/endpoint/row.py @@ -169,15 +169,12 @@ class EndpointRow(Base): # type: ignore[misc] __table_args__ = ( sa.Index( - "ix_endpoints_unique_name_when_active", + "ix_endpoints_unique_name_when_not_destroyed", "name", "domain", "project", unique=True, - postgresql_where=sa.column("lifecycle_stage").notin_([ - EndpointLifecycle.DESTROYING.value, - EndpointLifecycle.DESTROYED.value, - ]), + postgresql_where=(sa.column("lifecycle_stage") != EndpointLifecycle.DESTROYED.value), ), sa.Index( "ix_endpoints_lifecycle_sub_step", @@ -530,7 +527,6 @@ async def delegate_endpoint_ownership( db_session: AsyncSession, owner_user_uuid: UUID, target_user_uuid: UUID, - target_access_key: AccessKey, ) -> None: from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow @@ -554,7 +550,7 @@ async def delegate_endpoint_ownership( db_session, session_ids, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS ) for session_row in session_rows: - session_row.delegate_ownership(target_user_uuid, target_access_key) + session_row.delegate_ownership(target_user_uuid) async def generate_route_info( self, db_sess: AsyncSession diff --git a/src/ai/backend/manager/models/kernel/row.py b/src/ai/backend/manager/models/kernel/row.py index c5a570039a2..7332ad3ba0b 100644 --- a/src/ai/backend/manager/models/kernel/row.py +++ b/src/ai/backend/manager/models/kernel/row.py @@ -820,9 +820,8 @@ def set_status( else: self.status_data = dict(status_data) - def delegate_ownership(self, user_uuid: uuid.UUID, access_key: AccessKey) -> None: - self.user_uuid = user_uuid - self.access_key = access_key + def delegate_ownership(self, owner_id: uuid.UUID) -> None: + self.user_uuid = owner_id @classmethod async def set_kernel_status( @@ -945,8 +944,7 @@ def from_kernel_info(cls, info: KernelInfo) -> Self: agent_addr=info.resource.agent_addr, domain_name=info.user_permission.domain_name, group_id=info.user_permission.group_id, - user_uuid=info.user_permission.user_uuid, - access_key=info.user_permission.access_key, + user_uuid=info.user_permission.owner_id, image=info.image.identifier.canonical if info.image.identifier else None, architecture=info.image.identifier.architecture if info.image.identifier else None, registry=info.image.registry, @@ -1002,8 +1000,8 @@ def to_kernel_info(self) -> KernelInfo: session_type=self.session_type, ), user_permission=UserPermission( - user_uuid=self.user_uuid, - access_key=self.access_key or "", + owner_id=self.user_uuid, + main_access_key=self.user_row.main_access_key if self.user_row else None, domain_name=self.domain_name, group_id=self.group_id, uid=self.uid, @@ -1113,12 +1111,19 @@ async def recalc_concurrency_used( ) -> None: from ai.backend.manager.models.session import PRIVATE_SESSION_TYPES + # TODO(BA-5609 phase D): kernels.access_key is removed. Resolve the + # owner_id for this access_key (via users.main_access_key) and filter by + # KernelRow.user_uuid instead. The join below is a temporary shim that + # selects kernels whose owning user has main_access_key == access_key. + owner_id_subq = ( + sa.select(users.c.uuid).where(users.c.main_access_key == access_key).scalar_subquery() + ) async with db_sess.begin_nested(): result = await db_sess.execute( sa.select(sa.func.count()) .select_from(KernelRow) .where( - (KernelRow.access_key == access_key) + (KernelRow.user_uuid == owner_id_subq) & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES)) ), @@ -1128,7 +1133,7 @@ async def recalc_concurrency_used( sa.select(sa.func.count()) .select_from(KernelRow) .where( - (KernelRow.access_key == access_key) + (KernelRow.user_uuid == owner_id_subq) & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & (KernelRow.session_type.in_(PRIVATE_SESSION_TYPES)) ), diff --git a/src/ai/backend/manager/models/keypair/row.py b/src/ai/backend/manager/models/keypair/row.py index 99837afe8d0..2dfac89bf0e 100644 --- a/src/ai/backend/manager/models/keypair/row.py +++ b/src/ai/backend/manager/models/keypair/row.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 from cryptography.hazmat.primitives.hashes import SHA256 from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection -from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.sql.expression import false from ai.backend.common import msgpack @@ -31,7 +31,6 @@ if TYPE_CHECKING: from ai.backend.manager.models.resource_policy import KeyPairResourcePolicyRow from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow - from ai.backend.manager.models.session import SessionRow from ai.backend.manager.models.user import UserRow __all__: Sequence[str] = ( @@ -48,13 +47,6 @@ MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB -# Defined for avoiding circular import -def _get_session_row_join_condition() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.session import SessionRow - - return KeyPairRow.access_key == foreign(SessionRow.access_key) - - class KeyPairRow(Base): # type: ignore[misc] __tablename__ = "keypairs" @@ -100,12 +92,6 @@ class KeyPairRow(Base): # type: ignore[misc] ) # Relationships - sessions: Mapped[list[SessionRow]] = relationship( - "SessionRow", - primaryjoin=_get_session_row_join_condition, - foreign_keys="SessionRow.access_key", - back_populates="access_key_row", - ) resource_policy_row: Mapped[KeyPairResourcePolicyRow] = relationship( "KeyPairResourcePolicyRow", back_populates="keypairs" ) diff --git a/src/ai/backend/manager/models/resource_usage.py b/src/ai/backend/manager/models/resource_usage.py index a4454bf1ac0..5f523ab9c76 100644 --- a/src/ai/backend/manager/models/resource_usage.py +++ b/src/ai/backend/manager/models/resource_usage.py @@ -524,7 +524,12 @@ async def parse_resource_usage_groups( last_stat=stat_map.get(kern.id), user_id=kern.session.user_uuid, user_email=kern.session.user.email if kern.session.user is not None else None, - access_key=kern.session.access_key, + # The old ``SessionRow.access_key`` column is being dropped in a + # later slice; source the keypair access_key from the owner's + # ``main_access_key`` instead. + access_key=( + kern.session.user.main_access_key if kern.session.user is not None else None + ), project_id=kern.session.group.id if kern.session.group is not None else None, project_name=kern.session.group.name if kern.session.group is not None else None, kernel_id=kern.id, @@ -553,7 +558,9 @@ async def parse_resource_usage_groups( SessionRow.domain_name, SessionRow.id, SessionRow.group_id, - SessionRow.access_key, + # SessionRow.access_key is deprecated (removed in a later slice); callers + # that need the keypair access_key should join UserRow and read + # users.main_access_key instead. SessionRow.images, SessionRow.cluster_mode, SessionRow.status_history, @@ -606,7 +613,12 @@ def _parse_query( session_load.options( load_only(*SESSION_RESOURCE_SELECT_COLS), joinedload(SessionRow.user).options( - load_only(UserRow.email, UserRow.username, UserRow.full_name) + load_only( + UserRow.email, + UserRow.username, + UserRow.full_name, + UserRow.main_access_key, + ) ), project_load.options(load_only(*PROJECT_RESOURCE_SELECT_COLS)), ), diff --git a/src/ai/backend/manager/models/session/conditions.py b/src/ai/backend/manager/models/session/conditions.py index 0298889188f..913a818c10d 100644 --- a/src/ai/backend/manager/models/session/conditions.py +++ b/src/ai/backend/manager/models/session/conditions.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from ai.backend.common.data.filter_specs import ( + StringInMatchSpec, StringMatchSpec, UUIDEqualMatchSpec, UUIDInMatchSpec, @@ -20,6 +21,7 @@ from ai.backend.manager.data.session.types import KernelMatchType, SessionStatus from ai.backend.manager.models.condition_utils import make_string_in_factory from ai.backend.manager.models.kernel import KernelRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.repositories.base import QueryCondition from .row import SessionRow @@ -28,6 +30,19 @@ class SessionConditions: """Query conditions for sessions.""" + @staticmethod + def _owners_where_main_access_key( + condition: sa.sql.expression.ColumnElement[bool], + ) -> sa.sql.expression.ColumnElement[bool]: + """Return a predicate matching ``SessionRow.user_uuid`` against users whose ``main_access_key`` satisfies ``condition``. + + The subquery selects ``users.uuid`` (non-null PK) so ``NOT IN`` is + well-defined. NULL ``main_access_key`` fails ``condition`` (evaluates + to NULL, not TRUE), so such users are excluded from the subquery + without needing an explicit ``IS NOT NULL`` guard. + """ + return SessionRow.user_uuid.in_(sa.select(UserRow.uuid).where(condition)) + @staticmethod def by_ids(session_ids: Collection[SessionId]) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -107,9 +122,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") + match = UserRow.main_access_key.ilike(f"%{spec.value}%") else: - condition = SessionRow.access_key.like(f"%{spec.value}%") + match = UserRow.main_access_key.like(f"%{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -120,9 +136,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() + match = sa.func.lower(UserRow.main_access_key) == spec.value.lower() else: - condition = SessionRow.access_key == spec.value + match = UserRow.main_access_key == spec.value + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -133,9 +150,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") + match = UserRow.main_access_key.ilike(f"{spec.value}%") else: - condition = SessionRow.access_key.like(f"{spec.value}%") + match = UserRow.main_access_key.like(f"{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -146,16 +164,29 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") + match = UserRow.main_access_key.ilike(f"%{spec.value}") else: - condition = SessionRow.access_key.like(f"%{spec.value}") + match = UserRow.main_access_key.like(f"%{spec.value}") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition return inner - by_access_key_in = staticmethod(make_string_in_factory(SessionRow.access_key)) + @staticmethod + def by_access_key_in(spec: StringInMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + match = sa.func.lower(UserRow.main_access_key).in_([v.lower() for v in spec.values]) + else: + match = UserRow.main_access_key.in_(spec.values) + condition = SessionConditions._owners_where_main_access_key(match) + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: diff --git a/src/ai/backend/manager/models/session/row.py b/src/ai/backend/manager/models/session/row.py index 4f7beeaa5cd..4cdfda3d11e 100644 --- a/src/ai/backend/manager/models/session/row.py +++ b/src/ai/backend/manager/models/session/row.py @@ -123,7 +123,6 @@ if TYPE_CHECKING: from ai.backend.manager.models.domain import DomainRow - from ai.backend.manager.models.keypair import KeyPairRow from ai.backend.manager.models.scaling_group import ScalingGroupRow from ai.backend.manager.models.user import UserRow @@ -494,17 +493,26 @@ async def handle_session_exception( def _build_session_fetch_query( base_cond: Any, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_stale: bool = True, for_update: bool = False, do_ordering: bool = False, max_matches: int | None = None, eager_loading_op: Sequence[_AbstractLoad] | None = None, ) -> sa.sql.Select[Any]: + from ai.backend.manager.models.user import UserRow as _UserRow + cond = base_cond - if access_key: - cond = cond & (SessionRow.access_key == access_key) + if owner_id is not None: + cond = cond & (SessionRow.user_uuid == owner_id) + if owner_access_key is not None: + # Resolve the access key to its user via the users table so sessions + # filtered by the caller's main_access_key continue to work while the + # DB-level ``sessions.access_key`` column is being phased out. + owner_subq = sa.select(_UserRow.uuid).where(_UserRow.main_access_key == owner_access_key) + cond = cond & SessionRow.user_uuid.in_(owner_subq) if not allow_stale: cond = cond & (~SessionRow.status.in_(DEAD_SESSION_STATUSES)) query = ( @@ -528,8 +536,9 @@ def _build_session_fetch_query( async def _match_sessions_by_id( db_session: SASession, session_id_or_list: SessionId | list[SessionId], - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -546,7 +555,8 @@ async def _match_sessions_by_id( cond = SessionRow.id == session_id_or_list query = _build_session_fetch_query( cond, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -560,8 +570,9 @@ async def _match_sessions_by_id( async def _match_sessions_by_name( db_session: SASession, session_name: str, - access_key: AccessKey, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -575,7 +586,8 @@ async def _match_sessions_by_name( cond = SessionRow.name == session_name query = _build_session_fetch_query( cond, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -595,20 +607,6 @@ class ConcurrencyUsed: compute_session_ids: set[SessionId] = field(default_factory=set) system_session_ids: set[SessionId] = field(default_factory=set) - @property - def compute_concurrency_used_key(self) -> str: - return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" - - @property - def system_concurrency_used_key(self) -> str: - return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" - - def to_cnt_map(self) -> Mapping[str, int]: - return { - self.compute_concurrency_used_key: len(self.compute_session_ids), - self.system_concurrency_used_key: len(self.system_session_ids), - } - class SessionOp(enum.StrEnum): CREATE = "create_session" @@ -637,13 +635,6 @@ class KernelLoadingStrategy(enum.StrEnum): } -# Defined for avoiding circular import -def _get_keypair_row_join_condition() -> sa.sql.elements.ColumnElement[Any]: - from ai.backend.manager.models.keypair import KeyPairRow - - return KeyPairRow.access_key == foreign(SessionRow.access_key) - - def _get_user_row_join_condition() -> sa.sql.elements.ColumnElement[Any]: from ai.backend.manager.models.user import UserRow @@ -731,14 +722,7 @@ class SessionRow(Base): # type: ignore[misc] back_populates="sessions", foreign_keys=[user_uuid], ) - access_key: Mapped[str | None] = mapped_column("access_key", sa.String(length=20)) - access_key_row: Mapped[KeyPairRow | None] = relationship( - "KeyPairRow", - primaryjoin=_get_keypair_row_join_condition, - back_populates="sessions", - foreign_keys=[access_key], - ) # `image` column is identical to kernels `image` column. images: Mapped[list[str] | None] = mapped_column("images", sa.ARRAY(sa.String), nullable=True) @@ -884,7 +868,7 @@ class SessionRow(Base): # type: ignore[misc] sa.Index("ix_session_status_with_priority", "status", "priority"), # Unique index for session names per user excluding terminal statuses sa.Index( - "ix_sessions_unique_name_per_user_nonterminal", + "ix_sessions_unique_name_per_owner_nonterminal", "name", "user_uuid", unique=True, @@ -923,8 +907,7 @@ def from_dataclass(cls, session_data: SessionData) -> SessionRow: target_sgroup_names=session_data.target_sgroup_names, domain_name=session_data.domain_name, group_id=session_data.group_id, - user_uuid=session_data.user_uuid, - access_key=session_data.access_key, + user_uuid=session_data.owner_id, images=session_data.images, tag=session_data.tag, occupying_slots=session_data.occupying_slots, @@ -968,8 +951,7 @@ def to_dataclass(self, owner: UserData | None = None) -> SessionData: target_sgroup_names=self.target_sgroup_names, domain_name=self.domain_name, group_id=self.group_id, - user_uuid=self.user_uuid, - access_key=AccessKey(self.access_key) if self.access_key else None, + owner_id=self.user_uuid, images=self.images, tag=self.tag, occupying_slots=self.occupying_slots, @@ -1017,8 +999,7 @@ def from_session_info(cls, info: SessionInfo) -> Self: target_sgroup_names=info.resource.target_sgroup_names, domain_name=info.metadata.domain_name, group_id=info.metadata.group_id, - user_uuid=info.metadata.user_uuid, - access_key=info.metadata.access_key, + user_uuid=info.metadata.owner_id, images=info.image.images, tag=info.image.tag or info.metadata.tag, occupying_slots=info.resource.occupying_slots, @@ -1059,8 +1040,7 @@ def to_session_info(self) -> SessionInfo: name=self.name or "", domain_name=self.domain_name, group_id=self.group_id, - user_uuid=self.user_uuid, - access_key=self.access_key or "", + owner_id=self.user_uuid, session_type=self.session_type, priority=self.priority, created_at=self.created_at, @@ -1284,11 +1264,10 @@ def set_status( if _status_info is not None: self.status_info = _status_info - def delegate_ownership(self, user_uuid: UUID, access_key: AccessKey) -> None: - self.user_uuid = user_uuid - self.access_key = access_key + def delegate_ownership(self, owner_id: UUID) -> None: + self.user_uuid = owner_id for kernel_row in self.kernels: - kernel_row.delegate_ownership(user_uuid, access_key) + kernel_row.delegate_ownership(owner_id) @staticmethod async def delete_by_user_id(user_uuid: UUID, *, db_session: SASession) -> None: @@ -1357,8 +1336,9 @@ async def match_sessions( cls, db_session: SASession, session_reference: str | UUID | list[UUID], - access_key: AccessKey | None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -1367,7 +1347,8 @@ async def match_sessions( ) -> list[SessionRow]: """ Match the prefix of session ID or session name among the sessions - that belongs to the given access key, and return the list of SessionRow. + that belong to the given owner (``owner_id``), and return the list + of ``SessionRow``. """ if isinstance(session_reference, list): @@ -1412,7 +1393,8 @@ async def match_sessions( for fetch_func in query_list: rows = await fetch_func( db_session, - access_key=access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, allow_stale=allow_stale, for_update=for_update, max_matches=max_matches, @@ -1428,8 +1410,9 @@ async def get_session( cls, db_session: SASession, session_name_or_id: str | UUID, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_stale: bool = False, for_update: bool = False, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, @@ -1437,12 +1420,12 @@ async def get_session( ) -> SessionRow: """ Retrieve the session information by session's UUID, - or session's name paired with access_key. + or session's name paired with ``owner_id``. This will return the information of the session and the sibling kernel(s). :param db_session: Database connection to use when fetching row. :param session_name_or_id: Name or ID (UUID) of session to look up. - :param access_key: Access key used to create session. + :param owner_id: UUID of the session owner; required when ``session_name_or_id`` is a name. :param allow_stale: If set to True, filter "inactive" sessions as well as "active" ones. Otherwise filter "active" sessions only. :param for_update: Apply for_update during executing select query. @@ -1474,7 +1457,8 @@ async def get_session( session_list = await cls.match_sessions( db_session, session_name_or_id, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, allow_stale=allow_stale, for_update=for_update, eager_loading_op=_eager_loading_op, @@ -1499,8 +1483,8 @@ async def list_sessions( cls, db_session: SASession, session_ids: list[UUID], - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, allow_stale: bool = False, for_update: bool = False, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, @@ -1531,7 +1515,7 @@ async def list_sessions( session_list = await cls.match_sessions( db_session, session_ids, - access_key, + owner_id=owner_id, allow_stale=allow_stale, for_update=for_update, eager_loading_op=_eager_loading_op, @@ -1547,8 +1531,8 @@ async def get_session_by_id( cls, db_session: SASession, session_id: SessionId, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, max_matches: int | None = None, allow_stale: bool = True, for_update: bool = False, @@ -1557,7 +1541,7 @@ async def get_session_by_id( sessions = await _match_sessions_by_id( db_session, session_id, - access_key, + owner_id=owner_id, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -1586,7 +1570,6 @@ async def get_sgroup_managed_sessions( noload("*"), selectinload(SessionRow.group).options(noload("*")), selectinload(SessionRow.domain).options(noload("*")), - selectinload(SessionRow.access_key_row).options(noload("*")), selectinload(SessionRow.kernels).options(noload("*")), ) ) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 9fa12a3e836..d2840577bad 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -116,6 +116,7 @@ from ai.backend.manager.plugin.network import NetworkPluginContext from ai.backend.manager.repositories.resource_slot import ResourceSlotRepository from ai.backend.manager.repositories.scheduler.types.session_creation import SessionCreationSpec +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.scheduling_controller import SchedulingController from .agent_cache import AgentRPCCache @@ -221,6 +222,7 @@ def __init__( hook_plugin_ctx: HookPluginContext, network_plugin_ctx: NetworkPluginContext, scheduling_controller: SchedulingController, + user_repository: UserRepository, *, debug: bool = False, manager_public_key: PublicKey, @@ -460,7 +462,7 @@ async def create_session( sess = await SessionRow.get_session( db_session, session_name, - owner_access_key, + owner_id=user_scope.user_uuid, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) if sess.main_kernel.image is None: @@ -687,7 +689,7 @@ async def create_cluster( await SessionRow.get_session( db_sess, session_name, - owner_access_key, + owner_id=user_scope.user_uuid, ) except SessionNotFound: pass diff --git a/src/ai/backend/manager/repositories/model_serving/repository.py b/src/ai/backend/manager/repositories/model_serving/repository.py index 20c7eca71ce..001a70d1d18 100644 --- a/src/ai/backend/manager/repositories/model_serving/repository.py +++ b/src/ai/backend/manager/repositories/model_serving/repository.py @@ -738,7 +738,7 @@ async def get_session_by_id( async with self._db.begin_readonly_session_read_committed() as session: try: return await SessionRow.get_session( - session, session_id, None, kernel_loading_strategy=kernel_loading_strategy + session, session_id, kernel_loading_strategy=kernel_loading_strategy ) except NoResultFound: return None diff --git a/src/ai/backend/manager/repositories/session/creators.py b/src/ai/backend/manager/repositories/session/creators.py index bb1605aefcc..be89fe4c3fd 100644 --- a/src/ai/backend/manager/repositories/session/creators.py +++ b/src/ai/backend/manager/repositories/session/creators.py @@ -18,7 +18,7 @@ class SessionRowCreatorSpec(CreatorSpec[SessionRow]): SessionRow instances. It simply returns the provided row in build_row(). For scope information needed by RBACEntityCreator, use the row's user_uuid - field as the scope_id with ScopeType.USER. + field (the owner's UUID) as the scope_id with ScopeType.USER. """ row: SessionRow diff --git a/src/ai/backend/manager/repositories/session/db_source/db_source.py b/src/ai/backend/manager/repositories/session/db_source/db_source.py index b544b94dd89..7eed18c3170 100644 --- a/src/ai/backend/manager/repositories/session/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/session/db_source/db_source.py @@ -63,7 +63,7 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData: async def get_session_validated( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, eager_loading_op: Sequence[_AbstractLoad] | None = None, @@ -73,7 +73,7 @@ async def get_session_validated( return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=kernel_loading_strategy, allow_stale=allow_stale, eager_loading_op=list(eager_loading_op) if eager_loading_op else None, @@ -82,13 +82,13 @@ async def get_session_validated( async def match_sessions( self, id_or_name_prefix: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> list[SessionRow]: async with self._db.begin_readonly_session_read_committed() as db_sess: return await SessionRow.match_sessions( db_sess, id_or_name_prefix, - owner_access_key, + owner_id=owner_id, ) async def get_session_to_determine_status( @@ -132,7 +132,7 @@ async def update_session_name( self, session_name_or_id: str | SessionId, new_name: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: async def _update(db_session: AsyncSession) -> SessionRow: # Check if new name already exists for this owner @@ -140,7 +140,7 @@ async def _update(db_session: AsyncSession) -> SessionRow: await SessionRow.get_session( db_session, new_name, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, ) raise SessionAlreadyExists(f"Session with name '{new_name}' already exists") @@ -151,7 +151,7 @@ async def _update(db_session: AsyncSession) -> SessionRow: session_row = await SessionRow.get_session( db_session, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) @@ -305,13 +305,12 @@ async def modify_session( if session_row is None: raise SessionNotFound(f"Session not found (id:{session_id})") - if session_name and session_row.access_key is not None: - # Check the owner of the target session has any session with the same name + if session_name: try: sess = await SessionRow.get_session( db_session, session_name, - AccessKey(session_row.access_key), + owner_id=session_row.user_uuid, ) except SessionNotFound: pass @@ -371,7 +370,7 @@ async def _find_dependent_sessions( self, db_sess: AsyncSession, root_session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, allow_stale: bool = False, ) -> tuple[uuid.UUID, set[uuid.UUID]]: """ @@ -379,7 +378,7 @@ async def _find_dependent_sessions( :param db_sess: Database session :param root_session_name_or_id: Root session name or ID - :param access_key: Access key of the session owner + :param owner_id: UUID of the session owner :param allow_stale: Whether to allow stale sessions :return: Tuple of (root_session_id, set of dependent session IDs) """ @@ -401,7 +400,7 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]: root_session = await SessionRow.get_session( db_sess, root_session_name_or_id, - access_key=access_key, + owner_id=owner_id, allow_stale=allow_stale, ) root_session_id = cast(uuid.UUID, root_session.id) @@ -412,14 +411,14 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]: async def get_target_session_ids( self, session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, recursive: bool = False, ) -> list[SessionId]: """ Get list of session IDs including dependent sessions if recursive. :param session_name_or_id: Name or ID of the primary session - :param access_key: Access key of the session owner + :param owner_id: User UUID of the session owner :param recursive: If True, include dependent sessions :return: List of session IDs """ @@ -430,7 +429,7 @@ async def get_target_session_ids( root_id, dependent_ids = await self._find_dependent_sessions( db_sess, session_name_or_id, - access_key, + owner_id, allow_stale=True, ) # Return dependent sessions first, then root session @@ -441,7 +440,7 @@ async def get_target_session_ids( session = await SessionRow.get_session( db_sess, session_name_or_id, - access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, allow_stale=True, ) @@ -454,19 +453,19 @@ async def get_target_session_ids( async def find_dependency_sessions( self, session_name_or_id: uuid.UUID | str, - access_key: AccessKey, + owner_id: uuid.UUID, ) -> dict[str, list[Any] | str]: async with self._db.begin_readonly_session_read_committed() as db_sess: return await find_dependency_sessions( session_name_or_id, db_sess, - access_key, + owner_id, ) async def get_session_with_group( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, ) -> SessionRow: @@ -475,7 +474,7 @@ async def get_session_with_group( return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=kernel_loading_strategy, allow_stale=allow_stale, eager_loading_op=[selectinload(SessionRow.group)], @@ -484,14 +483,14 @@ async def get_session_with_group( async def get_session_with_routing_minimal( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: """Get session with minimal routing information""" async with self._db.begin_readonly_session_read_committed() as db_sess: return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, eager_loading_op=[ selectinload(SessionRow.routing).options(noload("*")), @@ -594,7 +593,7 @@ async def search_kernels( KernelListResult with items, total count, and pagination info """ async with self._db.begin_readonly_session() as db_sess: - query = sa.select(KernelRow) + query = sa.select(KernelRow).options(selectinload(KernelRow.user_row)) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/session/dependency_graph.py b/src/ai/backend/manager/repositories/session/dependency_graph.py index d9234416bb2..3ebdded333f 100644 --- a/src/ai/backend/manager/repositories/session/dependency_graph.py +++ b/src/ai/backend/manager/repositories/session/dependency_graph.py @@ -23,12 +23,15 @@ async def _find_dependency_sessions( session_name_or_id: UUID | str, db_session: SASession, - access_key: AccessKey, + owner: UUID | AccessKey, ) -> dict[str, list[Any] | str]: + owner_id: UUID | None = owner if isinstance(owner, UUID) else None + owner_access_key = owner if not isinstance(owner, UUID) else None sessions = await SessionRow.match_sessions( db_session, session_name_or_id, - access_key=access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, ) if len(sessions) < 1: @@ -66,7 +69,7 @@ async def _find_dependency_sessions( "status": str(kernel_query_result[0]), "status_changed": str(kernel_query_result[1]), "depends_on": [ - await _find_dependency_sessions(dependency_session_id, db_session, access_key) + await _find_dependency_sessions(dependency_session_id, db_session, owner) for dependency_session_id in dependency_session_ids ], } @@ -77,15 +80,15 @@ async def _find_dependency_sessions( async def find_dependency_sessions( session_name_or_id: UUID | str, db_session: SASession, - access_key: AccessKey, + owner: UUID | AccessKey, ) -> dict[str, list[Any] | str]: - return await _find_dependency_sessions(session_name_or_id, db_session, access_key) + return await _find_dependency_sessions(session_name_or_id, db_session, owner) async def find_dependent_sessions( root_session_name_or_id: str | UUID, db_session: SASession, - access_key: AccessKey, + owner_id: UUID, *, allow_stale: bool = False, ) -> set[UUID]: @@ -108,7 +111,7 @@ async def _find_dependent_sessions(session_id: UUID) -> set[UUID]: root_session = await SessionRow.get_session( db_session, root_session_name_or_id, - access_key=access_key, + owner_id=owner_id, allow_stale=allow_stale, ) return await _find_dependent_sessions(cast(UUID, root_session.id)) diff --git a/src/ai/backend/manager/repositories/session/repository.py b/src/ai/backend/manager/repositories/session/repository.py index 805604170e8..2ac82f731a8 100644 --- a/src/ai/backend/manager/repositories/session/repository.py +++ b/src/ai/backend/manager/repositories/session/repository.py @@ -57,14 +57,14 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData: async def get_session_validated( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, eager_loading_op: Sequence[_AbstractLoad] | None = None, ) -> SessionRow: return await self._db_source.get_session_validated( session_name_or_id, - owner_access_key, + owner_id, kernel_loading_strategy, allow_stale, eager_loading_op, @@ -74,9 +74,9 @@ async def get_session_validated( async def match_sessions( self, id_or_name_prefix: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> list[SessionRow]: - return await self._db_source.match_sessions(id_or_name_prefix, owner_access_key) + return await self._db_source.match_sessions(id_or_name_prefix, owner_id) @session_repository_resilience.apply() async def get_session_to_determine_status( @@ -104,11 +104,9 @@ async def update_session_name( self, session_name_or_id: str | SessionId, new_name: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: - return await self._db_source.update_session_name( - session_name_or_id, new_name, owner_access_key - ) + return await self._db_source.update_session_name(session_name_or_id, new_name, owner_id) @session_repository_resilience.apply() async def get_container_registry( @@ -210,52 +208,48 @@ async def query_userinfo( async def get_target_session_ids( self, session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, recursive: bool = False, ) -> list[SessionId]: """ Get list of session IDs including dependent sessions if recursive. :param session_name_or_id: Name or ID of the primary session - :param access_key: Access key of the session owner + :param owner_id: User UUID of the session owner :param recursive: If True, include dependent sessions :return: List of session IDs """ - return await self._db_source.get_target_session_ids( - session_name_or_id, access_key, recursive - ) + return await self._db_source.get_target_session_ids(session_name_or_id, owner_id, recursive) @session_repository_resilience.apply() async def find_dependency_sessions( self, session_name_or_id: uuid.UUID | str, - access_key: AccessKey, + owner_id: uuid.UUID, ) -> dict[str, list[Any] | str]: - return await self._db_source.find_dependency_sessions(session_name_or_id, access_key) + return await self._db_source.find_dependency_sessions(session_name_or_id, owner_id) @session_repository_resilience.apply() async def get_session_with_group( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, ) -> SessionRow: """Get session with group information eagerly loaded""" return await self._db_source.get_session_with_group( - session_name_or_id, owner_access_key, kernel_loading_strategy, allow_stale + session_name_or_id, owner_id, kernel_loading_strategy, allow_stale ) @session_repository_resilience.apply() async def get_session_with_routing_minimal( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: """Get session with minimal routing information""" - return await self._db_source.get_session_with_routing_minimal( - session_name_or_id, owner_access_key - ) + return await self._db_source.get_session_with_routing_minimal(session_name_or_id, owner_id) @session_repository_resilience.apply() async def search( diff --git a/src/ai/backend/manager/repositories/stream/db_source/db_source.py b/src/ai/backend/manager/repositories/stream/db_source/db_source.py index 8bfee9d611e..f49c7f3bbcc 100644 --- a/src/ai/backend/manager/repositories/stream/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/stream/db_source/db_source.py @@ -1,5 +1,9 @@ +import sqlalchemy as sa + from ai.backend.common.types import AccessKey +from ai.backend.manager.errors.kernel import SessionNotFound from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -15,9 +19,14 @@ async def get_streaming_session( access_key: AccessKey, ) -> SessionRow: async with self._db.begin_readonly_session() as db_sess: + owner_id = await db_sess.scalar( + sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) + ) + if owner_id is None: + raise SessionNotFound(f"Unknown access_key: {access_key}") return await SessionRow.get_session( db_sess, session_name, - access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) diff --git a/src/ai/backend/manager/repositories/user/db_source/db_source.py b/src/ai/backend/manager/repositories/user/db_source/db_source.py index f3ff8eef41b..2b0d3fd70d8 100644 --- a/src/ai/backend/manager/repositories/user/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/user/db_source/db_source.py @@ -137,6 +137,13 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: user_row = await self._get_user_by_uuid(db_session, user_uuid) return user_row.to_data() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + async with self._db.begin_readonly_session() as db_session: + return await db_session.scalar( + sa.select(UserRow.main_access_key).where(UserRow.uuid == user_uuid) + ) + async def get_by_email_validated( self, email: str, @@ -667,11 +674,14 @@ async def delegate_endpoint_ownership( target_user_uuid: UUID, target_main_access_key: AccessKey, ) -> None: - """Delegate endpoint ownership to another user.""" + """Delegate endpoint ownership to another user. + + ``target_main_access_key`` is kept on the facade for caller compatibility + but is no longer required by ``EndpointRow.delegate_endpoint_ownership``. + """ + del target_main_access_key # unused async with self._db.begin_session() as session: - await EndpointRow.delegate_endpoint_ownership( - session, user_uuid, target_user_uuid, target_main_access_key - ) + await EndpointRow.delegate_endpoint_ownership(session, user_uuid, target_user_uuid) async def delete_endpoints( self, diff --git a/src/ai/backend/manager/repositories/user/repository.py b/src/ai/backend/manager/repositories/user/repository.py index 9df169da1f9..9907b828595 100644 --- a/src/ai/backend/manager/repositories/user/repository.py +++ b/src/ai/backend/manager/repositories/user/repository.py @@ -77,6 +77,11 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: """ return await self._db_source.get_user_by_uuid(user_uuid) + @user_repository_resilience.apply() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + return await self._db_source.get_main_access_key_by_id(user_uuid) + @user_repository_resilience.apply() async def get_by_email_validated( self, diff --git a/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py b/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py index 746ec15b527..67e17f5b158 100644 --- a/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py +++ b/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py @@ -484,7 +484,7 @@ def _generate_slice_specs( spec = KernelUsageRecordCreatorSpec( kernel_id=UUID(str(kernel.id)), session_id=UUID(kernel.session.session_id), - user_uuid=kernel.user_permission.user_uuid, + user_uuid=kernel.user_permission.owner_id, project_id=kernel.user_permission.group_id, domain_name=kernel.user_permission.domain_name, resource_group=scaling_group, diff --git a/tests/unit/manager/api/compute_sessions/test_handler.py b/tests/unit/manager/api/compute_sessions/test_handler.py index c66a3db2431..30a8c14a176 100644 --- a/tests/unit/manager/api/compute_sessions/test_handler.py +++ b/tests/unit/manager/api/compute_sessions/test_handler.py @@ -70,7 +70,7 @@ def create_session_data( cluster_size=1, domain_name="default", group_id=uuid4(), - user_uuid=uuid4(), + owner_id=uuid4(), occupying_slots=ResourceSlot({"cpu": Decimal("2.0"), "mem": Decimal("4294967296")}), requested_slots=ResourceSlot({"cpu": Decimal("4.0"), "mem": Decimal("8589934592")}), use_host_network=False, @@ -80,7 +80,6 @@ def create_session_data( num_queries=0, creation_id="test-creation-id", name=name, - access_key=None, agent_ids=["agent-001"], images=images or ["cr.backend.ai/stable/python:3.11"], tag=None, @@ -123,8 +122,8 @@ def create_kernel_info( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=uuid4(), - access_key="TESTKEY", + owner_id=uuid4(), + main_access_key="TESTKEY", domain_name="default", group_id=uuid4(), uid=None, diff --git a/tests/unit/manager/services/session/test_session_service.py b/tests/unit/manager/services/session/test_session_service.py index 20999c0e168..c6a869c49df 100644 --- a/tests/unit/manager/services/session/test_session_service.py +++ b/tests/unit/manager/services/session/test_session_service.py @@ -245,8 +245,7 @@ def sample_session_data( agent_ids=["i-ubuntu"], domain_name="default", group_id=sample_group_id, - user_uuid=sample_user_id, - access_key=sample_access_key, + owner_id=sample_user_id, images=["cr.backend.ai/stable/python:latest"], tag=None, occupying_slots=ResourceSlot({"cpu": 1, "mem": 1024}), @@ -349,8 +348,7 @@ async def test_multiple_matches( agent_ids=[], domain_name="default", group_id=sample_group_id, - user_uuid=sample_user_id, - access_key=sample_access_key, + owner_id=sample_user_id, images=["python:latest"], tag=None, occupying_slots=ResourceSlot({}), @@ -1712,8 +1710,8 @@ def sample_kernel_info(self) -> KernelInfo: session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=user_id, - access_key="TESTKEY", + owner_id=user_id, + main_access_key="TESTKEY", domain_name="default", group_id=group_id, uid=1000, diff --git a/tests/unit/manager/sokovan/scheduler/handlers/conftest.py b/tests/unit/manager/sokovan/scheduler/handlers/conftest.py index 97cc68b2458..e1c9848b12c 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/conftest.py @@ -106,8 +106,7 @@ def _create_session( name=f"session-{sid}", domain_name="default", group_id=group_id, - user_uuid=user_uuid, - access_key=access_key, + owner_id=user_uuid, session_type=session_type, priority=0, created_at=now, @@ -160,8 +159,8 @@ def _create_session( session_type=session_type, ), user_permission=UserPermission( - user_uuid=user_uuid, - access_key=access_key, + owner_id=user_uuid, + main_access_key=None, domain_name="default", group_id=group_id, uid=None, @@ -261,8 +260,8 @@ def _create_kernel( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=user_uuid, - access_key="test-access-key", + owner_id=user_uuid, + main_access_key=None, domain_name="default", group_id=group_id, uid=None, @@ -545,7 +544,7 @@ def _create(sessions: list[SessionWithKernels]) -> ScheduleResult: ScheduledSessionData( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + access_key=AccessKey("test-access-key"), reason="scheduled-successfully", ) for s in sessions @@ -564,7 +563,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForPullWithImages: SessionDataForPull( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + access_key=AccessKey("test-access-key"), kernels=[ KernelBindingData( kernel_id=KernelId(k.id), @@ -609,7 +608,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForStartWithImages: SessionDataForStart( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + access_key=AccessKey("test-access-key"), session_type=s.session_info.identity.session_type, name=s.session_info.identity.name, cluster_mode=ClusterMode(s.session_info.resource.cluster_mode), @@ -624,7 +623,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForStartWithImages: ) for k in s.kernel_infos ], - user_uuid=s.session_info.metadata.user_uuid, + user_uuid=s.session_info.metadata.owner_id, user_email="test@example.com", user_name="test-user", environ={}, @@ -660,7 +659,7 @@ def _create(sessions: list[SessionWithKernels]) -> list[TerminatingSessionData]: return [ TerminatingSessionData( session_id=s.session_info.identity.id, - access_key=AccessKey(s.session_info.metadata.access_key), + access_key=AccessKey("test-access-key"), creation_id=s.session_info.identity.creation_id, status=s.session_info.lifecycle.status, status_info="user-requested", diff --git a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py index f959b16b790..e33c6b66917 100644 --- a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py @@ -213,8 +213,8 @@ def _create_kernel_info( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=uuid4(), - access_key="test-access-key", + owner_id=uuid4(), + main_access_key="test-access-key", domain_name="default", group_id=uuid4(), uid=None,