diff --git a/changes/11046.enhance.md b/changes/11046.enhance.md new file mode 100644 index 00000000000..87451acbf19 --- /dev/null +++ b/changes/11046.enhance.md @@ -0,0 +1 @@ +Collapse `SessionRepository` / `SessionDBSource` signatures to take `owner_id: UUID` instead of `owner_access_key: AccessKey`. No external behavior change. diff --git a/changes/11048.enhance.md b/changes/11048.enhance.md new file mode 100644 index 00000000000..8ec0b8ce190 --- /dev/null +++ b/changes/11048.enhance.md @@ -0,0 +1 @@ +Propagate the `owner_id` / `main_access_key` signature rename into the sokovan data classes, scheduler handlers, provisioner validators, launcher, scheduling controller, and sequencers. diff --git a/changes/11050.breaking.md b/changes/11050.breaking.md new file mode 100644 index 00000000000..089e0382bff --- /dev/null +++ b/changes/11050.breaking.md @@ -0,0 +1 @@ +Remove `owner_access_key` from REST v1 session API; resolve `owner_id` via `current_user()` in the service layer. Clients must migrate to `owner_id` (user UUID) for delegation. diff --git a/src/ai/backend/common/dto/manager/session/request.py b/src/ai/backend/common/dto/manager/session/request.py index 0c783cd505e..8b3d1412476 100644 --- a/src/ai/backend/common/dto/manager/session/request.py +++ b/src/ai/backend/common/dto/manager/session/request.py @@ -135,7 +135,7 @@ class CreateFromTemplateRequest(BaseRequestModel): default=None, validation_alias=AliasChoices("callback_url", "callbackUrl", "callbackURL"), ) - owner_access_key: str | None = None + owner_id: UUID | None = None class CreateFromParamsRequest(BaseRequestModel): @@ -214,7 +214,7 @@ class CreateFromParamsRequest(BaseRequestModel): default=None, validation_alias=AliasChoices("callback_url", "callbackUrl", "callbackURL"), ) - owner_access_key: str | None = None + owner_id: UUID | None = None class CreateClusterRequest(BaseRequestModel): @@ -252,7 +252,7 @@ class CreateClusterRequest(BaseRequestModel): ge=0, validation_alias=AliasChoices("max_wait_seconds", "maxWaitSeconds"), ) - owner_access_key: str | None = None + owner_id: UUID | None = None # --------------------------------------------------------------------------- @@ -352,14 +352,11 @@ class DestroySessionRequest(BaseRequestModel): forced: bool = False recursive: bool = False - owner_access_key: str | None = None class RestartSessionRequest(BaseRequestModel): """PATCH ``/{session_name}``""" - owner_access_key: str | None = None - class MatchSessionsRequest(BaseRequestModel): """GET ``/_/match``""" @@ -419,10 +416,6 @@ class ListFilesRequest(BaseRequestModel): class GetContainerLogsRequest(BaseRequestModel): """GET ``/{session_name}/logs``""" - owner_access_key: str | None = Field( - default=None, - validation_alias=AliasChoices("owner_access_key", "ownerAccessKey"), - ) kernel_id: UUID | None = Field( default=None, validation_alias=AliasChoices("kernel_id", "kernelId"), @@ -441,5 +434,3 @@ class GetTaskLogsRequest(BaseRequestModel): class GetStatusHistoryRequest(BaseRequestModel): """GET ``/{session_name}/status-history``""" - - owner_access_key: str | None = None diff --git a/src/ai/backend/manager/api/adapters/session.py b/src/ai/backend/manager/api/adapters/session.py index 78ae9c9f807..bf07f8a3126 100644 --- a/src/ai/backend/manager/api/adapters/session.py +++ b/src/ai/backend/manager/api/adapters/session.py @@ -223,8 +223,8 @@ async def enqueue( When ``input.owner_id`` is set, the session is created on behalf of the target user: their main access key, role, and domain are used in place - of the caller's. Resolution and authorization of the delegated user - are handled by the downstream session service, not by this adapter. + of the caller's. The target user must be loadable via the user + processor (RBAC enforced). """ batch_spec: SessionBatchSpec | None = None if input.batch is not None: @@ -849,12 +849,10 @@ async def shutdown_service( self, session_id: UUID, input: ShutdownSessionServiceInput, - access_key: str, ) -> None: """Shut down a service in a session.""" action = ShutdownServiceAction( session_name=str(session_id), - owner_access_key=AccessKey(access_key), service_name=input.service, ) await self._processors.session.shutdown_service.wait_for_complete(action) @@ -866,13 +864,11 @@ async def shutdown_service( async def get_logs( self, session_id: UUID, - access_key: str, kernel_id: UUID | None = None, ) -> SessionLogsPayload: """Get container logs for a session.""" action = GetContainerLogsAction( session_name=str(session_id), - owner_access_key=AccessKey(access_key), kernel_id=KernelId(kernel_id) if kernel_id else None, ) result = await self._processors.session.get_container_logs.wait_for_complete(action) @@ -887,14 +883,12 @@ async def update( self, session_id: UUID, input: UpdateSessionInput, - access_key: str, ) -> UpdateSessionPayload: """Update session fields (currently supports rename only).""" if input.name is not None: action = RenameSessionAction( session_name=str(session_id), new_name=input.name, - owner_access_key=AccessKey(access_key), ) result = await self._processors.session.rename_session.wait_for_complete(action) return UpdateSessionPayload(session=self._session_data_to_node(result.session_data)) @@ -933,13 +927,13 @@ def _session_data_to_node(data: SessionData) -> SessionNode: return SessionNode( id=data.id, domain_name=data.domain_name, - user_id=data.user_uuid, + user_id=data.owner_id, project_id=data.group_id, metadata=SessionMetadataInfoGQLDTO( creation_id=data.creation_id or "", name=data.name or "", session_type=data.session_type.value, - access_key=str(data.access_key) if data.access_key else "", + access_key="", cluster_mode=data.cluster_mode.name, cluster_size=data.cluster_size, priority=data.priority, @@ -1011,8 +1005,8 @@ def _kernel_info_to_node(info: KernelInfo) -> KernelNode: session_type=info.session.session_type.value, ), user_info=KernelUserInfoGQLDTO( - user_id=info.user_permission.user_uuid, - access_key=info.user_permission.access_key, + user_id=info.user_permission.owner_id, + access_key=info.user_permission.main_access_key, domain_name=info.user_permission.domain_name, group_id=info.user_permission.group_id, ), diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index 6ab06ca2dfb..6db61a21c5c 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -23,7 +23,6 @@ from ai.backend.common import validators as tx from ai.backend.common.defs.session import SESSION_PRIORITY_MAX, SESSION_PRIORITY_MIN -from ai.backend.common.exception import SessionWithInvalidStateError from ai.backend.common.types import ( ClusterMode, KernelId, @@ -395,9 +394,18 @@ def from_dataclass( cls, ctx: GraphQueryContext, session_data: SessionData, + main_access_key: str | None, *, permissions: Iterable[ComputeSessionPermission] | None = None, ) -> Self: + """Build a ``ComputeSessionNode`` from session data. + + ``main_access_key`` must be pre-resolved by the caller (typically + via ``UserRepository.get_main_access_key_by_id(session_data.owner_id)`` + or by eagerly loading ``session_data.owner``). Keeping the helper + synchronous avoids a hidden per-session DB query and lets the + caller batch the lookup across nodes. + """ status_history = session_data.status_history or {} raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name) if not session_data.vfolder_mounts: @@ -405,9 +413,6 @@ def from_dataclass( else: vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts] - if session_data.owner is None: - raise SessionWithInvalidStateError() - result = cls( # identity id=session_data.id, # auto-converted to Relay global ID @@ -422,9 +427,9 @@ def from_dataclass( # ownership domain_name=session_data.domain_name, project_id=session_data.group_id, - user_id=session_data.user_uuid, - access_key=session_data.access_key, - owner=UserNode.from_dataclass(ctx, session_data.owner), + user_id=session_data.owner_id, + access_key=main_access_key, + owner=UserNode.from_dataclass(ctx, session_data.owner) if session_data.owner else None, # status status=session_data.status.name, # status_changed=row.status_changed, # FIXME: generated attribute @@ -918,8 +923,14 @@ async def mutate_and_get_payload( ) ) + session_data = result.session_data + main_access_key = ( + session_data.owner.main_access_key + if session_data.owner + else await graph_ctx.user_repository.get_main_access_key_by_id(session_data.owner_id) + ) return ModifyComputeSession( - ComputeSessionNode.from_dataclass(graph_ctx, result.session_data), + ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key), input.get("client_mutation_id"), ) @@ -969,8 +980,16 @@ async def mutate( ) ) ) + session_data = action_result.session_data + main_access_key = ( + session_data.owner.main_access_key + if session_data.owner + else await graph_ctx.user_repository.get_main_access_key_by_id( + session_data.owner_id + ) + ) session_nodes.append( - ComputeSessionNode.from_dataclass(graph_ctx, action_result.session_data) + ComputeSessionNode.from_dataclass(graph_ctx, session_data, main_access_key) ) return CheckAndTransitStatus(session_nodes, input.get("client_mutation_id")) diff --git a/src/ai/backend/manager/api/rest/session/handler.py b/src/ai/backend/manager/api/rest/session/handler.py index 407a59fcc83..0a6ac7dab65 100644 --- a/src/ai/backend/manager/api/rest/session/handler.py +++ b/src/ai/backend/manager/api/rest/session/handler.py @@ -97,9 +97,6 @@ from ai.backend.manager.services.agent.actions.sync_agent_registry import ( SyncAgentRegistryAction, ) -from ai.backend.manager.services.auth.actions.resolve_access_key_scope import ( - ResolveAccessKeyScopeAction, -) from ai.backend.manager.services.session.actions.check_and_transit_status import ( CheckAndTransitStatusAction, ) @@ -182,7 +179,6 @@ if TYPE_CHECKING: from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.services.agent.processors import AgentProcessors - from ai.backend.manager.services.auth.processors import AuthProcessors from ai.backend.manager.services.session.processors import SessionProcessors from ai.backend.manager.services.vfolder.processors.vfolder import VFolderProcessors @@ -243,13 +239,11 @@ class SessionHandler: def __init__( self, *, - auth: AuthProcessors, session: SessionProcessors, agent: AgentProcessors, vfolder: VFolderProcessors, config_provider: ManagerConfigProvider, ) -> None: - self._auth = auth self._session = session self._agent = agent self._vfolder = vfolder @@ -277,20 +271,13 @@ async def create_from_template( template=True, ) - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) + owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( "GET_OR_CREATE (ak:{0}/{1}, img:{2}, s:{3})", requester_access_key, - owner_access_key if owner_access_key != requester_access_key else "*", + owner_id if owner_id != request["user"]["uuid"] else "*", params.image, params.session_name, ) @@ -348,7 +335,7 @@ async def create_from_template( batch_timeout=( timedelta(seconds=params.batch_timeout) if params.batch_timeout else None ), - owner_access_key=owner_access_key, + owner_id=owner_id, ), user_id=request["user"]["uuid"], user_role=request["user"]["role"], @@ -390,19 +377,12 @@ async def create_from_params( ) domain_name = params.domain or request["user"]["domain_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) + owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( "GET_OR_CREATE (ak:{0}/{1}, img:{2}, s:{3})", requester_access_key, - owner_access_key if owner_access_key != requester_access_key else "*", + owner_id if owner_id != request["user"]["uuid"] else "*", params.image, params.session_name, ) @@ -436,7 +416,7 @@ async def create_from_params( batch_timeout=( timedelta(seconds=params.batch_timeout) if params.batch_timeout else None ), - owner_access_key=owner_access_key, + owner_id=owner_id, ), user_id=request["user"]["uuid"], user_role=request["user"]["role"], @@ -460,19 +440,12 @@ async def create_cluster( params = body.parsed domain_name = params.domain or request["user"]["domain_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) + owner_id = params.owner_id if params.owner_id is not None else request["user"]["uuid"] log.info( "CREAT_CLUSTER (ak:{0}/{1}, s:{2})", requester_access_key, - owner_access_key if owner_access_key != requester_access_key else "*", + owner_id if owner_id != request["user"]["uuid"] else "*", params.session_name, ) @@ -484,7 +457,7 @@ async def create_cluster( domain_name=domain_name, group_name=params.group, requester_access_key=requester_access_key, - owner_access_key=owner_access_key, + owner_id=owner_id, scaling_group_name=params.scaling_group or "", tag=params.tag or "", session_type=params.session_type, @@ -508,28 +481,18 @@ async def match_sessions( ) -> APIResponse: request = ctx.request params = query.parsed - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) user = current_user() if user is None: raise UserNotFound("User not found in context") log.info( - "MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})", + "MATCH_SESSIONS(ak:{0}, prefix:{1})", requester_access_key, - owner_access_key, params.id, ) result = await self._session.match_sessions.wait_for_complete( MatchSessionsAction( id_or_name_prefix=params.id, - owner_access_key=owner_access_key, user_id=user.user_id, ) ) @@ -546,20 +509,11 @@ async def sync_agent_registry( ) -> APIResponse: request = ctx.request params = body.parsed - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) agent_id = AgentId(params.agent) log.info( - "SYNC_AGENT_REGISTRY (ak:{}/{}, a:{})", + "SYNC_AGENT_REGISTRY (ak:{}, a:{})", requester_access_key, - owner_access_key, agent_id, ) await self._agent.sync_agent_registry.wait_for_complete( @@ -581,19 +535,10 @@ async def check_and_transit_status( session_ids = [SessionId(id_) for id_ in params.ids] user_role = cast(UserRole, request["user"]["role"]) user_id = cast(UUID, request["user"]["uuid"]) - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "TRANSIT_STATUS (ak:{}/{}, s:{})", + "TRANSIT_STATUS (ak:{}, s:{})", requester_access_key, - owner_access_key, session_ids, ) @@ -619,26 +564,16 @@ async def check_and_transit_status( async def get_info(self, ctx: RequestCtx) -> APIResponse: request = ctx.request session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "GET_INFO (ak:{0}/{1}, s:{2})", + "GET_INFO (ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) try: result = await self._session.get_session_info.wait_for_complete( GetSessionInfoAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) except BackendAIError: @@ -660,28 +595,18 @@ async def restart( ctx: RequestCtx, ) -> web.Response: request = ctx.request - params = query.parsed + _ = query.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "RESTART (ak:{0}/{1}, s:{2})", + "RESTART (ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) try: await self._session.restart_session.wait_for_complete( RestartSessionAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) except BackendAIError: @@ -702,25 +627,11 @@ async def destroy( params = query.parsed session_name = request.match_info["session_name"] user_role = cast(UserRole, request["user"]["role"]) - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key - if requester_access_key != owner_access_key and user_role not in ( - UserRole.ADMIN, - UserRole.SUPERADMIN, - ): - raise InsufficientPrivilege("You are not allowed to force-terminate others's sessions") + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "DESTROY (ak:{0}/{1}, s:{2}, forced:{3}, recursive: {4})", + "DESTROY (ak:{0}, s:{1}, forced:{2}, recursive: {3})", requester_access_key, - owner_access_key, session_name, params.forced, params.recursive, @@ -729,7 +640,6 @@ async def destroy( result = await self._session.destroy_session.wait_for_complete( DestroySessionAction( session_name=session_name, - owner_access_key=owner_access_key, user_role=user_role, forced=params.forced, recursive=params.recursive, @@ -749,26 +659,16 @@ async def execute( request = ctx.request params = body.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "EXECUTE(ak:{0}/{1}, s:{2})", + "EXECUTE(ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) result = await self._session.execute_session.wait_for_complete( ExecuteSessionAction( session_name=session_name, - owner_access_key=owner_access_key, api_version=request["api_version"], params=ExecuteSessionActionParams( mode=params.mode, @@ -787,26 +687,16 @@ async def execute( async def interrupt(self, ctx: RequestCtx) -> web.Response: request = ctx.request session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "INTERRUPT(ak:{0}/{1}, s:{2})", + "INTERRUPT(ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) try: await self._session.interrupt.wait_for_complete( InterruptSessionAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) except BackendAIError: @@ -826,26 +716,16 @@ async def complete( request = ctx.request params = body.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "COMPLETE(ak:{0}/{1}, s:{2})", + "COMPLETE(ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) action_result = await self._session.complete.wait_for_complete( CompleteAction( session_name=session_name, - owner_access_key=owner_access_key, code=params.code or "", options=params.options or {}, ) @@ -899,26 +779,16 @@ async def shutdown_service( request = ctx.request params = body.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "SHUTDOWN_SERVICE (ak:{0}/{1}, s:{2})", + "SHUTDOWN_SERVICE (ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) try: await self._session.shutdown_service.wait_for_complete( ShutdownServiceAction( session_name=session_name, - owner_access_key=owner_access_key, service_name=params.service_name, ) ) @@ -935,26 +805,16 @@ async def upload_files(self, ctx: RequestCtx) -> web.Response: request = ctx.request reader = await request.multipart() session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "UPLOAD_FILE (ak:{0}/{1}, s:{2})", + "UPLOAD_FILE (ak:{0}, s:{1})", requester_access_key, - owner_access_key, session_name, ) try: await self._session.upload_files.wait_for_complete( UploadFilesAction( session_name=session_name, - owner_access_key=owner_access_key, reader=reader, ) ) @@ -975,26 +835,16 @@ async def download_files( request = ctx.request params = body.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "DOWNLOAD_FILE (ak:{0}/{1}, s:{2}, path:{3!r})", + "DOWNLOAD_FILE (ak:{0}, s:{1}, path:{2!r})", requester_access_key, - owner_access_key, session_name, params.files[0], ) result = await self._session.download_files.wait_for_complete( DownloadFilesAction( user_id=request["user"]["uuid"], - owner_access_key=owner_access_key, session_name=session_name, files=params.files, ) @@ -1013,19 +863,10 @@ async def download_single( request = ctx.request params = query.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "DOWNLOAD_SINGLE (ak:{0}/{1}, s:{2}, path:{3!r})", + "DOWNLOAD_SINGLE (ak:{0}, s:{1}, path:{2!r})", requester_access_key, - owner_access_key, session_name, params.file, ) @@ -1033,7 +874,6 @@ async def download_single( DownloadFileAction( user_id=request["user"]["uuid"], session_name=session_name, - owner_access_key=owner_access_key, file=params.file, ) ) @@ -1051,19 +891,10 @@ async def list_files( request = ctx.request params = query.parsed session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "LIST_FILES (ak:{0}/{1}, s:{2}, path:{3})", + "LIST_FILES (ak:{0}, s:{1}, path:{2})", requester_access_key, - owner_access_key, session_name, params.path, ) @@ -1072,7 +903,6 @@ async def list_files( user_id=request["user"]["uuid"], path=params.path, session_name=session_name, - owner_access_key=owner_access_key, ) ) return APIResponse.build(HTTPStatus.OK, ListFilesResponse(dict(result.result))) @@ -1090,19 +920,10 @@ async def rename_session( params = body.parsed session_name = request.match_info["session_name"] new_name = params.session_name - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "RENAME_SESSION (ak:{0}/{1}, s:{2}, newname:{3})", + "RENAME_SESSION (ak:{0}, s:{1}, newname:{2})", requester_access_key, - owner_access_key, session_name, new_name, ) @@ -1110,7 +931,6 @@ async def rename_session( RenameSessionAction( session_name=session_name, new_name=new_name, - owner_access_key=owner_access_key, ) ) return web.Response(status=HTTPStatus.NO_CONTENT) @@ -1127,25 +947,15 @@ async def commit_session( request = ctx.request params = query.parsed session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "COMMIT_SESSION (ak:{}/{}, s:{})", + "COMMIT_SESSION (ak:{}, s:{})", requester_access_key, - owner_access_key, session_name, ) action_result = await self._session.commit_session.wait_for_complete( CommitSessionAction( session_name=session_name, - owner_access_key=owner_access_key, filename=params.filename, ) ) @@ -1166,25 +976,15 @@ async def convert_session_to_image( request = ctx.request params = body.parsed session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "CONVERT_SESSION_TO_IMAGE (ak:{}/{}, s:{})", + "CONVERT_SESSION_TO_IMAGE (ak:{}, s:{})", requester_access_key, - owner_access_key, session_name, ) result = await self._session.convert_session_to_image.wait_for_complete( ConvertSessionToImageAction( session_name=session_name, - owner_access_key=owner_access_key, image_name=params.image_name, image_visibility=params.image_visibility, image_owner_id=request["user"]["uuid"], @@ -1210,28 +1010,18 @@ async def get_commit_status( ) -> APIResponse: request = ctx.request session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) myself = asyncio.current_task() if myself is None: raise NoCurrentTaskContext("No current task context") log.info( - "GET_COMMIT_STATUS (ak:{}/{}, s:{})", + "GET_COMMIT_STATUS (ak:{}, s:{})", requester_access_key, - owner_access_key, session_name, ) result = await self._session.get_commit_status.wait_for_complete( GetCommitStatusAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) return APIResponse.build( @@ -1250,25 +1040,15 @@ async def get_abusing_report( ) -> APIResponse: request = ctx.request session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "GET_ABUSING_REPORT (ak:{}/{}, s:{})", + "GET_ABUSING_REPORT (ak:{}, s:{})", requester_access_key, - owner_access_key, session_name, ) result = await self._session.get_abusing_report.wait_for_complete( GetAbusingReportAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) return APIResponse.build( @@ -1288,27 +1068,17 @@ async def get_status_history( ctx: RequestCtx, ) -> APIResponse: request = ctx.request - params = query.parsed + _ = query.parsed session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "GET_STATUS_HISTORY (ak:{}/{}, s:{})", + "GET_STATUS_HISTORY (ak:{}, s:{})", requester_access_key, - owner_access_key, session_name, ) result = await self._session.get_status_history.wait_for_complete( GetStatusHistoryAction( session_name=session_name, - owner_access_key=request["keypair"]["access_key"], ) ) return APIResponse.build( @@ -1323,19 +1093,9 @@ async def get_status_history( async def get_direct_access_info(self, ctx: RequestCtx) -> APIResponse: request = ctx.request session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - owner_access_key = scope.owner_access_key result = await self._session.get_direct_access_info.wait_for_complete( GetDirectAccessInfoAction( session_name=session_name, - owner_access_key=owner_access_key, ) ) return APIResponse.build( @@ -1355,20 +1115,11 @@ async def get_container_logs( request = ctx.request params = query.parsed session_name: str = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=params.owner_access_key, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) kernel_id = KernelId(params.kernel_id) if params.kernel_id is not None else None log.info( - "GET_CONTAINER_LOG (ak:{}/{}, s:{}, k:{})", + "GET_CONTAINER_LOG (ak:{}, s:{}, k:{})", requester_access_key, - owner_access_key, session_name, kernel_id, ) @@ -1376,15 +1127,13 @@ async def get_container_logs( result = await self._session.get_container_logs.wait_for_complete( GetContainerLogsAction( session_name=session_name, - owner_access_key=owner_access_key, kernel_id=kernel_id, ) ) except BackendAIError: log.exception( - "GET_CONTAINER_LOG(ak:{}/{}, kernel_id: {}, s:{}): unexpected error", + "GET_CONTAINER_LOG(ak:{}, kernel_id: {}, s:{}): unexpected error", requester_access_key, - owner_access_key, kernel_id, session_name, ) @@ -1433,25 +1182,15 @@ async def get_task_logs( async def get_dependency_graph(self, ctx: RequestCtx) -> APIResponse: request = ctx.request root_session_name = request.match_info["session_name"] - scope = await self._auth.resolve_access_key_scope.wait_for_complete( - ResolveAccessKeyScopeAction( - requester_access_key=request["keypair"]["access_key"], - requester_role=request["user"]["role"], - requester_domain=request["user"]["domain_name"], - owner_access_key=None, - ) - ) - requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + requester_access_key = AccessKey(request["keypair"]["access_key"]) log.info( - "GET_DEPENDENCY_GRAPH (ak:{0}/{1}, s:{2})", + "GET_DEPENDENCY_GRAPH (ak:{0}, s:{1})", requester_access_key, - owner_access_key, root_session_name, ) result = await self._session.get_dependency_graph.wait_for_complete( GetDependencyGraphAction( root_session_name=root_session_name, - owner_access_key=owner_access_key, ) ) return APIResponse.build( diff --git a/src/ai/backend/manager/api/rest/tree.py b/src/ai/backend/manager/api/rest/tree.py index 532e35a9299..b7ef8abf61c 100644 --- a/src/ai/backend/manager/api/rest/tree.py +++ b/src/ai/backend/manager/api/rest/tree.py @@ -206,7 +206,6 @@ def build_api_routes( model_serving_auto_scaling=processors.model_serving_auto_scaling, ) session_handler = SessionHandler( - auth=processors.auth, session=processors.session, agent=processors.agent, vfolder=processors.vfolder, diff --git a/src/ai/backend/manager/api/rest/v2/session/handler.py b/src/ai/backend/manager/api/rest/v2/session/handler.py index 0eca25bd8c8..359278d0b22 100644 --- a/src/ai/backend/manager/api/rest/v2/session/handler.py +++ b/src/ai/backend/manager/api/rest/v2/session/handler.py @@ -153,38 +153,30 @@ async def start_service( async def shutdown_service( self, - user_ctx: UserContext, path: PathParam[SessionIdPathParamDTO], body: BodyParam[ShutdownSessionServiceInput], ) -> APIResponse: """Shut down a service in a session.""" - await self._adapter.shutdown_service( - path.parsed.session_id, body.parsed, access_key=user_ctx.access_key - ) + await self._adapter.shutdown_service(path.parsed.session_id, body.parsed) return APIResponse.no_content(status_code=HTTPStatus.NO_CONTENT) async def get_logs( self, - user_ctx: UserContext, path: PathParam[SessionIdPathParamDTO], query: QueryParam[GetSessionLogsQuery], ) -> APIResponse: """Get container logs for a session.""" result = await self._adapter.get_logs( path.parsed.session_id, - access_key=user_ctx.access_key, kernel_id=query.parsed.kernel_id, ) return APIResponse.build(status_code=HTTPStatus.OK, response_model=result) async def update( self, - user_ctx: UserContext, path: PathParam[SessionIdPathParamDTO], body: BodyParam[UpdateSessionInput], ) -> APIResponse: """Update a session.""" - result = await self._adapter.update( - path.parsed.session_id, body.parsed, access_key=user_ctx.access_key - ) + result = await self._adapter.update(path.parsed.session_id, body.parsed) return APIResponse.build(status_code=HTTPStatus.OK, response_model=result) diff --git a/src/ai/backend/manager/data/kernel/types.py b/src/ai/backend/manager/data/kernel/types.py index 075419d536a..ff495b2aae2 100644 --- a/src/ai/backend/manager/data/kernel/types.py +++ b/src/ai/backend/manager/data/kernel/types.py @@ -207,8 +207,8 @@ class ClusterConfig: @dataclass class UserPermission: - user_uuid: UUID - access_key: str + owner_id: UUID + main_access_key: str | None domain_name: str group_id: UUID uid: int | None diff --git a/src/ai/backend/manager/data/session/types.py b/src/ai/backend/manager/data/session/types.py index 5123d528497..9a6a116b487 100644 --- a/src/ai/backend/manager/data/session/types.py +++ b/src/ai/backend/manager/data/session/types.py @@ -12,7 +12,6 @@ from ai.backend.common.data.vfolder.types import VFolderMountData from ai.backend.common.types import ( - AccessKey, CIStrEnum, ClusterMode, ResourceSlot, @@ -155,7 +154,7 @@ class SessionData: cluster_size: int domain_name: str group_id: UUID - user_uuid: UUID + owner_id: UUID occupying_slots: Any # TODO: ResourceSlot? requested_slots: Any use_host_network: bool @@ -165,7 +164,6 @@ class SessionData: num_queries: int creation_id: str | None name: str | None - access_key: AccessKey | None agent_ids: list[str] | None images: list[str] | None tag: str | None @@ -206,8 +204,7 @@ class SessionMetadata: name: str domain_name: str group_id: UUID - user_uuid: UUID - access_key: str + owner_id: UUID session_type: SessionTypes priority: int created_at: datetime | None diff --git a/src/ai/backend/manager/dependencies/agents/composer.py b/src/ai/backend/manager/dependencies/agents/composer.py index 5228e1c837b..ff88bd0afd7 100644 --- a/src/ai/backend/manager/dependencies/agents/composer.py +++ b/src/ai/backend/manager/dependencies/agents/composer.py @@ -20,6 +20,7 @@ from ai.backend.manager.registry import AgentRegistry from ai.backend.manager.repositories.deployment.repository import DeploymentRepository from ai.backend.manager.repositories.scheduler.repository import SchedulerRepository +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.deployment.deployment_controller import DeploymentController from ai.backend.manager.sokovan.deployment.revision_generator.registry import ( RevisionGeneratorRegistry, @@ -188,6 +189,7 @@ async def compose( hook_plugin_ctx=setup_input.hook_plugin_ctx, network_plugin_ctx=setup_input.network_plugin_ctx, scheduling_controller=scheduling_controller, + user_repository=UserRepository(setup_input.db), debug=setup_input.config_provider.config.debug.enabled, manager_public_key=setup_input.agent_cache.manager_public_key, manager_secret_key=setup_input.agent_cache.manager_secret_key, diff --git a/src/ai/backend/manager/dependencies/agents/registry.py b/src/ai/backend/manager/dependencies/agents/registry.py index 57ce148dfdf..ecfd9066717 100644 --- a/src/ai/backend/manager/dependencies/agents/registry.py +++ b/src/ai/backend/manager/dependencies/agents/registry.py @@ -19,6 +19,7 @@ from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.plugin.network import NetworkPluginContext from ai.backend.manager.registry import AgentRegistry +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.scheduling_controller.scheduling_controller import ( SchedulingController, ) @@ -41,6 +42,7 @@ class AgentRegistryInput: hook_plugin_ctx: HookPluginContext network_plugin_ctx: NetworkPluginContext scheduling_controller: SchedulingController + user_repository: UserRepository debug: bool manager_public_key: PublicKey manager_secret_key: SecretKey @@ -82,6 +84,7 @@ async def provide(self, setup_input: AgentRegistryInput) -> AsyncIterator[AgentR setup_input.hook_plugin_ctx, setup_input.network_plugin_ctx, setup_input.scheduling_controller, + setup_input.user_repository, debug=setup_input.debug, manager_public_key=setup_input.manager_public_key, manager_secret_key=setup_input.manager_secret_key, diff --git a/src/ai/backend/manager/models/endpoint/row.py b/src/ai/backend/manager/models/endpoint/row.py index d4a94404ba3..0c452d35fdd 100644 --- a/src/ai/backend/manager/models/endpoint/row.py +++ b/src/ai/backend/manager/models/endpoint/row.py @@ -169,15 +169,12 @@ class EndpointRow(Base): # type: ignore[misc] __table_args__ = ( sa.Index( - "ix_endpoints_unique_name_when_active", + "ix_endpoints_unique_name_when_not_destroyed", "name", "domain", "project", unique=True, - postgresql_where=sa.column("lifecycle_stage").notin_([ - EndpointLifecycle.DESTROYING.value, - EndpointLifecycle.DESTROYED.value, - ]), + postgresql_where=(sa.column("lifecycle_stage") != EndpointLifecycle.DESTROYED.value), ), sa.Index( "ix_endpoints_lifecycle_sub_step", @@ -530,7 +527,6 @@ async def delegate_endpoint_ownership( db_session: AsyncSession, owner_user_uuid: UUID, target_user_uuid: UUID, - target_access_key: AccessKey, ) -> None: from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow @@ -554,7 +550,7 @@ async def delegate_endpoint_ownership( db_session, session_ids, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS ) for session_row in session_rows: - session_row.delegate_ownership(target_user_uuid, target_access_key) + session_row.delegate_ownership(target_user_uuid) async def generate_route_info( self, db_sess: AsyncSession diff --git a/src/ai/backend/manager/models/kernel/row.py b/src/ai/backend/manager/models/kernel/row.py index c5a570039a2..7332ad3ba0b 100644 --- a/src/ai/backend/manager/models/kernel/row.py +++ b/src/ai/backend/manager/models/kernel/row.py @@ -820,9 +820,8 @@ def set_status( else: self.status_data = dict(status_data) - def delegate_ownership(self, user_uuid: uuid.UUID, access_key: AccessKey) -> None: - self.user_uuid = user_uuid - self.access_key = access_key + def delegate_ownership(self, owner_id: uuid.UUID) -> None: + self.user_uuid = owner_id @classmethod async def set_kernel_status( @@ -945,8 +944,7 @@ def from_kernel_info(cls, info: KernelInfo) -> Self: agent_addr=info.resource.agent_addr, domain_name=info.user_permission.domain_name, group_id=info.user_permission.group_id, - user_uuid=info.user_permission.user_uuid, - access_key=info.user_permission.access_key, + user_uuid=info.user_permission.owner_id, image=info.image.identifier.canonical if info.image.identifier else None, architecture=info.image.identifier.architecture if info.image.identifier else None, registry=info.image.registry, @@ -1002,8 +1000,8 @@ def to_kernel_info(self) -> KernelInfo: session_type=self.session_type, ), user_permission=UserPermission( - user_uuid=self.user_uuid, - access_key=self.access_key or "", + owner_id=self.user_uuid, + main_access_key=self.user_row.main_access_key if self.user_row else None, domain_name=self.domain_name, group_id=self.group_id, uid=self.uid, @@ -1113,12 +1111,19 @@ async def recalc_concurrency_used( ) -> None: from ai.backend.manager.models.session import PRIVATE_SESSION_TYPES + # TODO(BA-5609 phase D): kernels.access_key is removed. Resolve the + # owner_id for this access_key (via users.main_access_key) and filter by + # KernelRow.user_uuid instead. The join below is a temporary shim that + # selects kernels whose owning user has main_access_key == access_key. + owner_id_subq = ( + sa.select(users.c.uuid).where(users.c.main_access_key == access_key).scalar_subquery() + ) async with db_sess.begin_nested(): result = await db_sess.execute( sa.select(sa.func.count()) .select_from(KernelRow) .where( - (KernelRow.access_key == access_key) + (KernelRow.user_uuid == owner_id_subq) & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & (KernelRow.session_type.not_in(PRIVATE_SESSION_TYPES)) ), @@ -1128,7 +1133,7 @@ async def recalc_concurrency_used( sa.select(sa.func.count()) .select_from(KernelRow) .where( - (KernelRow.access_key == access_key) + (KernelRow.user_uuid == owner_id_subq) & (KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & (KernelRow.session_type.in_(PRIVATE_SESSION_TYPES)) ), diff --git a/src/ai/backend/manager/models/keypair/row.py b/src/ai/backend/manager/models/keypair/row.py index 99837afe8d0..2dfac89bf0e 100644 --- a/src/ai/backend/manager/models/keypair/row.py +++ b/src/ai/backend/manager/models/keypair/row.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 from cryptography.hazmat.primitives.hashes import SHA256 from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection -from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.sql.expression import false from ai.backend.common import msgpack @@ -31,7 +31,6 @@ if TYPE_CHECKING: from ai.backend.manager.models.resource_policy import KeyPairResourcePolicyRow from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow - from ai.backend.manager.models.session import SessionRow from ai.backend.manager.models.user import UserRow __all__: Sequence[str] = ( @@ -48,13 +47,6 @@ MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB -# Defined for avoiding circular import -def _get_session_row_join_condition() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.session import SessionRow - - return KeyPairRow.access_key == foreign(SessionRow.access_key) - - class KeyPairRow(Base): # type: ignore[misc] __tablename__ = "keypairs" @@ -100,12 +92,6 @@ class KeyPairRow(Base): # type: ignore[misc] ) # Relationships - sessions: Mapped[list[SessionRow]] = relationship( - "SessionRow", - primaryjoin=_get_session_row_join_condition, - foreign_keys="SessionRow.access_key", - back_populates="access_key_row", - ) resource_policy_row: Mapped[KeyPairResourcePolicyRow] = relationship( "KeyPairResourcePolicyRow", back_populates="keypairs" ) diff --git a/src/ai/backend/manager/models/resource_usage.py b/src/ai/backend/manager/models/resource_usage.py index a4454bf1ac0..5f523ab9c76 100644 --- a/src/ai/backend/manager/models/resource_usage.py +++ b/src/ai/backend/manager/models/resource_usage.py @@ -524,7 +524,12 @@ async def parse_resource_usage_groups( last_stat=stat_map.get(kern.id), user_id=kern.session.user_uuid, user_email=kern.session.user.email if kern.session.user is not None else None, - access_key=kern.session.access_key, + # The old ``SessionRow.access_key`` column is being dropped in a + # later slice; source the keypair access_key from the owner's + # ``main_access_key`` instead. + access_key=( + kern.session.user.main_access_key if kern.session.user is not None else None + ), project_id=kern.session.group.id if kern.session.group is not None else None, project_name=kern.session.group.name if kern.session.group is not None else None, kernel_id=kern.id, @@ -553,7 +558,9 @@ async def parse_resource_usage_groups( SessionRow.domain_name, SessionRow.id, SessionRow.group_id, - SessionRow.access_key, + # SessionRow.access_key is deprecated (removed in a later slice); callers + # that need the keypair access_key should join UserRow and read + # users.main_access_key instead. SessionRow.images, SessionRow.cluster_mode, SessionRow.status_history, @@ -606,7 +613,12 @@ def _parse_query( session_load.options( load_only(*SESSION_RESOURCE_SELECT_COLS), joinedload(SessionRow.user).options( - load_only(UserRow.email, UserRow.username, UserRow.full_name) + load_only( + UserRow.email, + UserRow.username, + UserRow.full_name, + UserRow.main_access_key, + ) ), project_load.options(load_only(*PROJECT_RESOURCE_SELECT_COLS)), ), diff --git a/src/ai/backend/manager/models/session/conditions.py b/src/ai/backend/manager/models/session/conditions.py index 0298889188f..913a818c10d 100644 --- a/src/ai/backend/manager/models/session/conditions.py +++ b/src/ai/backend/manager/models/session/conditions.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from ai.backend.common.data.filter_specs import ( + StringInMatchSpec, StringMatchSpec, UUIDEqualMatchSpec, UUIDInMatchSpec, @@ -20,6 +21,7 @@ from ai.backend.manager.data.session.types import KernelMatchType, SessionStatus from ai.backend.manager.models.condition_utils import make_string_in_factory from ai.backend.manager.models.kernel import KernelRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.repositories.base import QueryCondition from .row import SessionRow @@ -28,6 +30,19 @@ class SessionConditions: """Query conditions for sessions.""" + @staticmethod + def _owners_where_main_access_key( + condition: sa.sql.expression.ColumnElement[bool], + ) -> sa.sql.expression.ColumnElement[bool]: + """Return a predicate matching ``SessionRow.user_uuid`` against users whose ``main_access_key`` satisfies ``condition``. + + The subquery selects ``users.uuid`` (non-null PK) so ``NOT IN`` is + well-defined. NULL ``main_access_key`` fails ``condition`` (evaluates + to NULL, not TRUE), so such users are excluded from the subquery + without needing an explicit ``IS NOT NULL`` guard. + """ + return SessionRow.user_uuid.in_(sa.select(UserRow.uuid).where(condition)) + @staticmethod def by_ids(session_ids: Collection[SessionId]) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -107,9 +122,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") + match = UserRow.main_access_key.ilike(f"%{spec.value}%") else: - condition = SessionRow.access_key.like(f"%{spec.value}%") + match = UserRow.main_access_key.like(f"%{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -120,9 +136,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() + match = sa.func.lower(UserRow.main_access_key) == spec.value.lower() else: - condition = SessionRow.access_key == spec.value + match = UserRow.main_access_key == spec.value + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -133,9 +150,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") + match = UserRow.main_access_key.ilike(f"{spec.value}%") else: - condition = SessionRow.access_key.like(f"{spec.value}%") + match = UserRow.main_access_key.like(f"{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -146,16 +164,29 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") + match = UserRow.main_access_key.ilike(f"%{spec.value}") else: - condition = SessionRow.access_key.like(f"%{spec.value}") + match = UserRow.main_access_key.like(f"%{spec.value}") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition return inner - by_access_key_in = staticmethod(make_string_in_factory(SessionRow.access_key)) + @staticmethod + def by_access_key_in(spec: StringInMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + match = sa.func.lower(UserRow.main_access_key).in_([v.lower() for v in spec.values]) + else: + match = UserRow.main_access_key.in_(spec.values) + condition = SessionConditions._owners_where_main_access_key(match) + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: diff --git a/src/ai/backend/manager/models/session/row.py b/src/ai/backend/manager/models/session/row.py index 4f7beeaa5cd..4cdfda3d11e 100644 --- a/src/ai/backend/manager/models/session/row.py +++ b/src/ai/backend/manager/models/session/row.py @@ -123,7 +123,6 @@ if TYPE_CHECKING: from ai.backend.manager.models.domain import DomainRow - from ai.backend.manager.models.keypair import KeyPairRow from ai.backend.manager.models.scaling_group import ScalingGroupRow from ai.backend.manager.models.user import UserRow @@ -494,17 +493,26 @@ async def handle_session_exception( def _build_session_fetch_query( base_cond: Any, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_stale: bool = True, for_update: bool = False, do_ordering: bool = False, max_matches: int | None = None, eager_loading_op: Sequence[_AbstractLoad] | None = None, ) -> sa.sql.Select[Any]: + from ai.backend.manager.models.user import UserRow as _UserRow + cond = base_cond - if access_key: - cond = cond & (SessionRow.access_key == access_key) + if owner_id is not None: + cond = cond & (SessionRow.user_uuid == owner_id) + if owner_access_key is not None: + # Resolve the access key to its user via the users table so sessions + # filtered by the caller's main_access_key continue to work while the + # DB-level ``sessions.access_key`` column is being phased out. + owner_subq = sa.select(_UserRow.uuid).where(_UserRow.main_access_key == owner_access_key) + cond = cond & SessionRow.user_uuid.in_(owner_subq) if not allow_stale: cond = cond & (~SessionRow.status.in_(DEAD_SESSION_STATUSES)) query = ( @@ -528,8 +536,9 @@ def _build_session_fetch_query( async def _match_sessions_by_id( db_session: SASession, session_id_or_list: SessionId | list[SessionId], - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -546,7 +555,8 @@ async def _match_sessions_by_id( cond = SessionRow.id == session_id_or_list query = _build_session_fetch_query( cond, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -560,8 +570,9 @@ async def _match_sessions_by_id( async def _match_sessions_by_name( db_session: SASession, session_name: str, - access_key: AccessKey, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -575,7 +586,8 @@ async def _match_sessions_by_name( cond = SessionRow.name == session_name query = _build_session_fetch_query( cond, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -595,20 +607,6 @@ class ConcurrencyUsed: compute_session_ids: set[SessionId] = field(default_factory=set) system_session_ids: set[SessionId] = field(default_factory=set) - @property - def compute_concurrency_used_key(self) -> str: - return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" - - @property - def system_concurrency_used_key(self) -> str: - return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" - - def to_cnt_map(self) -> Mapping[str, int]: - return { - self.compute_concurrency_used_key: len(self.compute_session_ids), - self.system_concurrency_used_key: len(self.system_session_ids), - } - class SessionOp(enum.StrEnum): CREATE = "create_session" @@ -637,13 +635,6 @@ class KernelLoadingStrategy(enum.StrEnum): } -# Defined for avoiding circular import -def _get_keypair_row_join_condition() -> sa.sql.elements.ColumnElement[Any]: - from ai.backend.manager.models.keypair import KeyPairRow - - return KeyPairRow.access_key == foreign(SessionRow.access_key) - - def _get_user_row_join_condition() -> sa.sql.elements.ColumnElement[Any]: from ai.backend.manager.models.user import UserRow @@ -731,14 +722,7 @@ class SessionRow(Base): # type: ignore[misc] back_populates="sessions", foreign_keys=[user_uuid], ) - access_key: Mapped[str | None] = mapped_column("access_key", sa.String(length=20)) - access_key_row: Mapped[KeyPairRow | None] = relationship( - "KeyPairRow", - primaryjoin=_get_keypair_row_join_condition, - back_populates="sessions", - foreign_keys=[access_key], - ) # `image` column is identical to kernels `image` column. images: Mapped[list[str] | None] = mapped_column("images", sa.ARRAY(sa.String), nullable=True) @@ -884,7 +868,7 @@ class SessionRow(Base): # type: ignore[misc] sa.Index("ix_session_status_with_priority", "status", "priority"), # Unique index for session names per user excluding terminal statuses sa.Index( - "ix_sessions_unique_name_per_user_nonterminal", + "ix_sessions_unique_name_per_owner_nonterminal", "name", "user_uuid", unique=True, @@ -923,8 +907,7 @@ def from_dataclass(cls, session_data: SessionData) -> SessionRow: target_sgroup_names=session_data.target_sgroup_names, domain_name=session_data.domain_name, group_id=session_data.group_id, - user_uuid=session_data.user_uuid, - access_key=session_data.access_key, + user_uuid=session_data.owner_id, images=session_data.images, tag=session_data.tag, occupying_slots=session_data.occupying_slots, @@ -968,8 +951,7 @@ def to_dataclass(self, owner: UserData | None = None) -> SessionData: target_sgroup_names=self.target_sgroup_names, domain_name=self.domain_name, group_id=self.group_id, - user_uuid=self.user_uuid, - access_key=AccessKey(self.access_key) if self.access_key else None, + owner_id=self.user_uuid, images=self.images, tag=self.tag, occupying_slots=self.occupying_slots, @@ -1017,8 +999,7 @@ def from_session_info(cls, info: SessionInfo) -> Self: target_sgroup_names=info.resource.target_sgroup_names, domain_name=info.metadata.domain_name, group_id=info.metadata.group_id, - user_uuid=info.metadata.user_uuid, - access_key=info.metadata.access_key, + user_uuid=info.metadata.owner_id, images=info.image.images, tag=info.image.tag or info.metadata.tag, occupying_slots=info.resource.occupying_slots, @@ -1059,8 +1040,7 @@ def to_session_info(self) -> SessionInfo: name=self.name or "", domain_name=self.domain_name, group_id=self.group_id, - user_uuid=self.user_uuid, - access_key=self.access_key or "", + owner_id=self.user_uuid, session_type=self.session_type, priority=self.priority, created_at=self.created_at, @@ -1284,11 +1264,10 @@ def set_status( if _status_info is not None: self.status_info = _status_info - def delegate_ownership(self, user_uuid: UUID, access_key: AccessKey) -> None: - self.user_uuid = user_uuid - self.access_key = access_key + def delegate_ownership(self, owner_id: UUID) -> None: + self.user_uuid = owner_id for kernel_row in self.kernels: - kernel_row.delegate_ownership(user_uuid, access_key) + kernel_row.delegate_ownership(owner_id) @staticmethod async def delete_by_user_id(user_uuid: UUID, *, db_session: SASession) -> None: @@ -1357,8 +1336,9 @@ async def match_sessions( cls, db_session: SASession, session_reference: str | UUID | list[UUID], - access_key: AccessKey | None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_prefix: bool = False, allow_stale: bool = True, for_update: bool = False, @@ -1367,7 +1347,8 @@ async def match_sessions( ) -> list[SessionRow]: """ Match the prefix of session ID or session name among the sessions - that belongs to the given access key, and return the list of SessionRow. + that belong to the given owner (``owner_id``), and return the list + of ``SessionRow``. """ if isinstance(session_reference, list): @@ -1412,7 +1393,8 @@ async def match_sessions( for fetch_func in query_list: rows = await fetch_func( db_session, - access_key=access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, allow_stale=allow_stale, for_update=for_update, max_matches=max_matches, @@ -1428,8 +1410,9 @@ async def get_session( cls, db_session: SASession, session_name_or_id: str | UUID, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, + owner_access_key: Any = None, allow_stale: bool = False, for_update: bool = False, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, @@ -1437,12 +1420,12 @@ async def get_session( ) -> SessionRow: """ Retrieve the session information by session's UUID, - or session's name paired with access_key. + or session's name paired with ``owner_id``. This will return the information of the session and the sibling kernel(s). :param db_session: Database connection to use when fetching row. :param session_name_or_id: Name or ID (UUID) of session to look up. - :param access_key: Access key used to create session. + :param owner_id: UUID of the session owner; required when ``session_name_or_id`` is a name. :param allow_stale: If set to True, filter "inactive" sessions as well as "active" ones. Otherwise filter "active" sessions only. :param for_update: Apply for_update during executing select query. @@ -1474,7 +1457,8 @@ async def get_session( session_list = await cls.match_sessions( db_session, session_name_or_id, - access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, allow_stale=allow_stale, for_update=for_update, eager_loading_op=_eager_loading_op, @@ -1499,8 +1483,8 @@ async def list_sessions( cls, db_session: SASession, session_ids: list[UUID], - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, allow_stale: bool = False, for_update: bool = False, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, @@ -1531,7 +1515,7 @@ async def list_sessions( session_list = await cls.match_sessions( db_session, session_ids, - access_key, + owner_id=owner_id, allow_stale=allow_stale, for_update=for_update, eager_loading_op=_eager_loading_op, @@ -1547,8 +1531,8 @@ async def get_session_by_id( cls, db_session: SASession, session_id: SessionId, - access_key: AccessKey | None = None, *, + owner_id: UUID | None = None, max_matches: int | None = None, allow_stale: bool = True, for_update: bool = False, @@ -1557,7 +1541,7 @@ async def get_session_by_id( sessions = await _match_sessions_by_id( db_session, session_id, - access_key, + owner_id=owner_id, max_matches=max_matches, allow_stale=allow_stale, for_update=for_update, @@ -1586,7 +1570,6 @@ async def get_sgroup_managed_sessions( noload("*"), selectinload(SessionRow.group).options(noload("*")), selectinload(SessionRow.domain).options(noload("*")), - selectinload(SessionRow.access_key_row).options(noload("*")), selectinload(SessionRow.kernels).options(noload("*")), ) ) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 9fa12a3e836..72663586670 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -116,6 +116,7 @@ from ai.backend.manager.plugin.network import NetworkPluginContext from ai.backend.manager.repositories.resource_slot import ResourceSlotRepository from ai.backend.manager.repositories.scheduler.types.session_creation import SessionCreationSpec +from ai.backend.manager.repositories.user.repository import UserRepository from ai.backend.manager.sokovan.scheduling_controller import SchedulingController from .agent_cache import AgentRPCCache @@ -221,6 +222,7 @@ def __init__( hook_plugin_ctx: HookPluginContext, network_plugin_ctx: NetworkPluginContext, scheduling_controller: SchedulingController, + user_repository: UserRepository, *, debug: bool = False, manager_public_key: PublicKey, @@ -252,6 +254,7 @@ def __init__( event_producer, hook_plugin_ctx, self, + user_repository, ) self._client_pool = ClientPool(tcp_client_session_factory) @@ -460,7 +463,7 @@ async def create_session( sess = await SessionRow.get_session( db_session, session_name, - owner_access_key, + owner_id=user_scope.user_uuid, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) if sess.main_kernel.image is None: @@ -687,7 +690,7 @@ async def create_cluster( await SessionRow.get_session( db_sess, session_name, - owner_access_key, + owner_id=user_scope.user_uuid, ) except SessionNotFound: pass diff --git a/src/ai/backend/manager/repositories/events/db_source/db_source.py b/src/ai/backend/manager/repositories/events/db_source/db_source.py index c163a8ce0c7..ec66a8f0432 100644 --- a/src/ai/backend/manager/repositories/events/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/events/db_source/db_source.py @@ -6,6 +6,7 @@ from ai.backend.manager.errors.resource import ProjectNotFound from ai.backend.manager.models.group import groups from ai.backend.manager.models.session import SessionRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -21,8 +22,13 @@ async def match_sessions_by_name( access_key: AccessKey, ) -> list[SessionRow]: async with self._db.begin_readonly_session(isolation_level="READ COMMITTED") as db_sess: + owner_id = await db_sess.scalar( + sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) + ) + if owner_id is None: + return [] return await SessionRow.match_sessions( - db_sess, session_name, access_key, allow_prefix=False + db_sess, session_name, owner_id=owner_id, allow_prefix=False ) async def resolve_group_id(self, group_name: str) -> uuid.UUID: diff --git a/src/ai/backend/manager/repositories/model_serving/repository.py b/src/ai/backend/manager/repositories/model_serving/repository.py index 20c7eca71ce..9866f696e42 100644 --- a/src/ai/backend/manager/repositories/model_serving/repository.py +++ b/src/ai/backend/manager/repositories/model_serving/repository.py @@ -5,7 +5,6 @@ import sqlalchemy as sa from pydantic import HttpUrl -from ruamel.yaml import YAML from sqlalchemy.exc import IntegrityError, NoResultFound, StatementError from sqlalchemy.ext.asyncio import AsyncSession as SASession from sqlalchemy.orm import selectinload @@ -40,7 +39,7 @@ UserData, ) from ai.backend.manager.data.permission.types import RBACElementRef -from ai.backend.manager.data.vfolder.types import VFolderLocation, VFolderOwnershipType +from ai.backend.manager.data.vfolder.types import VFolderOwnershipType from ai.backend.manager.errors.common import ObjectNotFound from ai.backend.manager.errors.resource import DatabaseConnectionUnavailable from ai.backend.manager.errors.service import EndpointNotFound @@ -81,9 +80,6 @@ execute_rbac_entity_creator, ) from ai.backend.manager.repositories.deployment.creators import DeploymentPolicyCreatorSpec -from ai.backend.manager.repositories.deployment.storage_source.storage_source import ( - DeploymentStorageSource, -) from ai.backend.manager.repositories.model_serving.updaters import EndpointUpdaterSpec from ai.backend.manager.services.model_serving.actions.modify_endpoint import ModifyEndpointAction from ai.backend.manager.services.model_serving.exceptions import ( @@ -738,7 +734,7 @@ async def get_session_by_id( async with self._db.begin_readonly_session_read_committed() as session: try: return await SessionRow.get_session( - session, session_id, None, kernel_loading_strategy=kernel_loading_strategy + session, session_id, kernel_loading_strategy=kernel_loading_strategy ) except NoResultFound: return None @@ -832,15 +828,6 @@ async def _do_mutate() -> MutationResult: if current_rev is None: raise InvalidAPIParameters("Endpoint has no current revision") - # Re-read model definition from vfolder to pick up file changes - refreshed_model_definition = await self._fetch_model_definition_from_vfolder( - db_session, - storage_manager, - current_rev.model, - spec.model_definition_path.optional_value() - or current_rev.model_definition_path, - ) - # Resolve image if changed image_id = current_rev.image image_ref = spec.image.optional_value() @@ -873,7 +860,6 @@ async def _do_mutate() -> MutationResult: if spec.model_definition_path.optional_value() is not None else current_rev.model_definition_path ), - model_definition=refreshed_model_definition or current_rev.model_definition, resource_group=endpoint_row.resource_group, resource_opts=( spec.resource_opts.optional_value() @@ -955,51 +941,6 @@ async def _do_mutate() -> MutationResult: except Exception: raise - async def _fetch_model_definition_from_vfolder( - self, - db_session: SASession, - storage_manager: StorageSessionManager, - vfolder_id: uuid.UUID | None, - model_definition_path: str | None, - ) -> dict[str, Any] | None: - """Re-read model definition file from the vfolder storage. - - Returns the parsed YAML content, or None if the file cannot be read. - """ - if vfolder_id is None: - return None - try: - vf_query = sa.select( - VFolderRow.id, - VFolderRow.host, - VFolderRow.quota_scope_id, - VFolderRow.ownership_type, - VFolderRow.usage_mode, - ).where(VFolderRow.id == vfolder_id) - vf_result = await db_session.execute(vf_query) - vf_row = vf_result.one_or_none() - if vf_row is None: - return None - - vfolder_location = VFolderLocation( - id=vf_row.id, - quota_scope_id=vf_row.quota_scope_id, - host=vf_row.host, - ownership_type=vf_row.ownership_type, - usage_mode=vf_row.usage_mode, - ) - candidates = ( - [model_definition_path] - if model_definition_path - else ["model-definition.yaml", "model-definition.yml"] - ) - storage_source = DeploymentStorageSource(storage_manager) - content = await storage_source.fetch_definition_file(vfolder_location, candidates) - yaml = YAML() - return cast(dict[str, Any], yaml.load(content)) - except Exception: - return None - @model_serving_repository_resilience.apply() async def search_auto_scaling_rules( self, diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index 0af02fa9c20..c1a2a9e8ea5 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -351,9 +351,9 @@ async def _fetch_pending_sessions( if session_id not in sessions_map: sessions_map[session_id] = PendingSessionData( id=session_id, - access_key=row.access_key, + main_access_key=row.access_key, requested_slots=row.requested_slots, - user_uuid=row.user_uuid, + owner_id=row.user_uuid, group_id=row.group_id, domain_name=row.domain_name, scaling_group_name=row.scaling_group_name, @@ -700,7 +700,7 @@ async def _fetch_user_policies( """Fetch user resource policies for users in pending sessions.""" user_policies: dict[UUID, UserResourcePolicy] = {} - if not pending_sessions.user_uuids: + if not pending_sessions.owner_ids: return user_policies user_policy_result = await db_sess.execute( @@ -716,7 +716,7 @@ async def _fetch_user_policies( KeyPairResourcePolicyRow, KeyPairRow.resource_policy == KeyPairResourcePolicyRow.name, ) - .where(UserRow.uuid.in_(pending_sessions.user_uuids)) + .where(UserRow.uuid.in_(pending_sessions.owner_ids)) ) for row in user_policy_result: @@ -1146,11 +1146,12 @@ async def get_terminating_sessions_by_ids( for kernel in session_row.kernels ] + owner_main_ak = session_row.user.main_access_key if session_row.user else None terminating_sessions.append( TerminatingSessionData( session_id=session_row.id, - access_key=AccessKey(session_row.access_key) - if session_row.access_key + main_access_key=AccessKey(owner_main_ak) + if owner_main_ak else AccessKey(""), creation_id=session_row.creation_id or "", status=session_row.status, @@ -1183,11 +1184,12 @@ async def get_pending_timeout_sessions_by_ids( sa.select( SessionRow.id, SessionRow.creation_id, - SessionRow.access_key, + UserRow.main_access_key, SessionRow.created_at, ScalingGroupRow.scheduler_opts, ) .select_from(SessionRow) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) .join(ScalingGroupRow, SessionRow.scaling_group_name == ScalingGroupRow.name) .where( SessionRow.id.in_(session_ids), @@ -1213,7 +1215,7 @@ async def get_pending_timeout_sessions_by_ids( SweptSessionInfo( session_id=row.id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, ) ) @@ -1302,8 +1304,8 @@ async def enqueue_session( id=session_data.id, creation_id=session_data.creation_id, name=session_data.name, - access_key=session_data.access_key, - user_uuid=session_data.user_uuid, + access_key=session_data.main_access_key, + user_uuid=session_data.owner_id, group_id=session_data.group_id, domain_name=session_data.domain_name, scaling_group_name=session_data.scaling_group_name, @@ -1349,8 +1351,8 @@ async def enqueue_session( scaling_group=kernel.scaling_group, domain_name=kernel.domain_name, group_id=kernel.group_id, - user_uuid=kernel.user_uuid, - access_key=kernel.access_key, + user_uuid=kernel.owner_id, + access_key=kernel.main_access_key, image=kernel.image, architecture=kernel.architecture, registry=kernel.registry, @@ -1387,7 +1389,7 @@ async def enqueue_session( element_type=RBACElementType.SESSION, scope_ref=RBACElementRef( element_type=RBACElementType.USER, - element_id=str(session_data.user_uuid), + element_id=str(session_data.owner_id), ), additional_scope_refs=[ RBACElementRef( @@ -1843,21 +1845,26 @@ async def allocate_sessions( # First, fetch session data to get creation_id and access_key session_ids = {alloc.session_id for alloc in allocation_batch.allocations} if session_ids: - query = sa.select( - SessionRow.id, SessionRow.creation_id, SessionRow.access_key - ).where(SessionRow.id.in_(session_ids)) + query = ( + sa.select(SessionRow.id, SessionRow.creation_id, UserRow.main_access_key) + .select_from(SessionRow) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) + .where(SessionRow.id.in_(session_ids)) + ) result = await db_sess.execute(query) - session_data_map = {row.id: (row.creation_id, row.access_key) for row in result} + session_data_map = { + row.id: (row.creation_id, row.main_access_key) for row in result + } # Create SessionEventData for each allocated session for allocation in allocation_batch.allocations: if session_data := session_data_map.get(allocation.session_id): - creation_id, access_key = session_data + creation_id, main_access_key = session_data scheduled_sessions.append( ScheduledSessionData( session_id=allocation.session_id, creation_id=creation_id, - access_key=access_key, + main_access_key=main_access_key, reason="triggered-by-scheduler", ) ) @@ -2917,7 +2924,9 @@ async def _get_sessions_by_statuses( scheduled_session = ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), + main_access_key=AccessKey(session.access_key) + if session.access_key + else AccessKey(""), reason="triggered-by-scheduler", ) scheduled_sessions.append(scheduled_session) @@ -2962,7 +2971,7 @@ async def _get_scheduled_sessions(self, db_sess: SASession) -> list[ScheduledSes ScheduledSessionData( session_id=session.id, creation_id=session.creation_id or "", - access_key=AccessKey(session.access_key) + main_access_key=AccessKey(session.access_key) if session.access_key else AccessKey(""), reason="triggered-by-scheduler", @@ -3102,7 +3111,7 @@ async def _get_sessions_for_pull( sessions_map[session_id] = SessionDataForPull( session_id=session_id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -3294,13 +3303,13 @@ async def _get_sessions_for_start( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=kernel_bindings, environ=session_info.get("environ", {}), - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4074,7 +4083,7 @@ async def _fetch_sessions_for_pull_by_ids( sessions_map[session_id] = SessionDataForPull( session_id=session_id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -4293,13 +4302,13 @@ async def _fetch_sessions_for_start_by_ids( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=kernel_bindings, environ=session_info.get("environ", {}), - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4369,7 +4378,7 @@ async def search_sessions_with_kernels( sessions_map[row.id] = SessionDataForPull( session_id=row.id, creation_id=row.creation_id, - access_key=row.access_key, + main_access_key=row.access_key, kernels=[], ) @@ -4625,13 +4634,13 @@ async def search_sessions_with_kernels_and_user( SessionDataForStart( session_id=session_info["id"], creation_id=session_info["creation_id"], - access_key=session_info["access_key"], + main_access_key=session_info["access_key"], session_type=session_info["session_type"], name=session_info["name"], cluster_mode=session_info["cluster_mode"], kernels=session_info["kernels"], environ=session_info.get("environ") or {}, - user_uuid=session_info["user_uuid"], + owner_id=session_info["user_uuid"], user_email=user_info.email, user_name=user_info.username, ) @@ -4774,6 +4783,26 @@ async def get_db_now(self) -> datetime: result = await conn.execute(sa.select(sa.func.now())) return result.scalar_one() + async def resolve_main_access_keys( + self, session_ids: Sequence[SessionId] + ) -> dict[SessionId, AccessKey]: + """Resolve the main access key for each session's owner. + + Joins ``sessions`` → ``users`` to look up the owner's + ``main_access_key``. Sessions whose owner has no configured + main access key are omitted from the returned mapping. + """ + if not session_ids: + return {} + async with self._db.begin_readonly_session() as db_sess: + stmt = ( + sa.select(SessionRow.id, UserRow.main_access_key) + .join(UserRow, SessionRow.user_uuid == UserRow.uuid) + .where(SessionRow.id.in_([sid for sid in session_ids])) + ) + rows = (await db_sess.execute(stmt)).all() + return {SessionId(row[0]): AccessKey(row[1]) for row in rows if row[1] is not None} + async def _get_db_now_in_session(self, db_sess: SASession) -> datetime: """Get the current timestamp from the database within an existing session. diff --git a/src/ai/backend/manager/repositories/scheduler/options.py b/src/ai/backend/manager/repositories/scheduler/options.py index 7ceb9b68088..7a27cb148cf 100644 --- a/src/ai/backend/manager/repositories/scheduler/options.py +++ b/src/ai/backend/manager/repositories/scheduler/options.py @@ -107,58 +107,6 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner - @staticmethod - def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") - else: - condition = SessionRow.access_key.like(f"%{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() - else: - condition = SessionRow.access_key == spec.value - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") - else: - condition = SessionRow.access_key.like(f"{spec.value}%") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - - @staticmethod - def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: - def inner() -> sa.sql.expression.ColumnElement[bool]: - if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") - else: - condition = SessionRow.access_key.like(f"%{spec.value}") - if spec.negated: - condition = sa.not_(condition) - return condition - - return inner - @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -413,8 +361,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: - """Factory for user UUID equality filter.""" + def by_owner_id_filter_equals(spec: UUIDEqualMatchSpec) -> QueryCondition: + """Factory for owner_id equality filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: @@ -424,8 +372,8 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_user_uuid_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: - """Factory for user UUID IN filter.""" + def by_owner_id_filter_in(spec: UUIDInMatchSpec) -> QueryCondition: + """Factory for owner_id IN filter.""" def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.negated: diff --git a/src/ai/backend/manager/repositories/scheduler/repository.py b/src/ai/backend/manager/repositories/scheduler/repository.py index 2921a662814..0e38d59b6a9 100644 --- a/src/ai/backend/manager/repositories/scheduler/repository.py +++ b/src/ai/backend/manager/repositories/scheduler/repository.py @@ -959,3 +959,9 @@ async def get_db_now(self) -> datetime: Current database timestamp with timezone """ return await self._db_source.get_db_now() + + async def resolve_main_access_keys( + self, session_ids: Sequence[SessionId] + ) -> dict[SessionId, AccessKey]: + """Resolve the main access key for each session's owner.""" + return await self._db_source.resolve_main_access_keys(session_ids) diff --git a/src/ai/backend/manager/repositories/scheduler/types/allocation.py b/src/ai/backend/manager/repositories/scheduler/types/allocation.py index d25c0440d52..2b9629ba7ce 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/allocation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/allocation.py @@ -69,8 +69,7 @@ class SessionAllocation: kernel_allocations: list[KernelAllocation] # List of agent allocations for this session agent_allocations: list[AgentAllocation] - # Keypair associated with the session - access_key: AccessKey + main_access_key: AccessKey # Phases that passed during scheduling passed_phases: list[SchedulingPredicate] = field(default_factory=list) # Phases that failed during scheduling (normally empty for successful allocations) diff --git a/src/ai/backend/manager/repositories/scheduler/types/results.py b/src/ai/backend/manager/repositories/scheduler/types/results.py index 75be947cd08..40fff71ddf5 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/results.py +++ b/src/ai/backend/manager/repositories/scheduler/types/results.py @@ -13,5 +13,5 @@ class ScheduledSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey reason: str diff --git a/src/ai/backend/manager/repositories/scheduler/types/session.py b/src/ai/backend/manager/repositories/scheduler/types/session.py index fe975b3d4e8..8da9962fcac 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session.py @@ -44,9 +44,9 @@ class PendingSessionData: """Pending session data for scheduling.""" id: SessionId - access_key: AccessKey + main_access_key: AccessKey requested_slots: ResourceSlot - user_uuid: UUID + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str @@ -64,9 +64,9 @@ def to_session_workload(self) -> SessionWorkload: kernel_workloads = [k.to_kernel_workload() for k in self.kernels] return SessionWorkload( session_id=self.id, - access_key=self.access_key, + main_access_key=self.main_access_key, requested_slots=self.requested_slots, - user_uuid=self.user_uuid, + owner_id=self.owner_id, group_id=self.group_id, domain_name=self.domain_name, scaling_group=self.scaling_group_name, @@ -90,12 +90,12 @@ class PendingSessions: @cached_property def access_keys(self) -> set[AccessKey]: """Extract unique access keys from pending sessions.""" - return {s.access_key for s in self.sessions} + return {s.main_access_key for s in self.sessions} @cached_property - def user_uuids(self) -> set[UUID]: - """Extract unique user UUIDs from pending sessions.""" - return {s.user_uuid for s in self.sessions} + def owner_ids(self) -> set[UUID]: + """Extract unique owner (user) UUIDs from pending sessions.""" + return {s.owner_id for s in self.sessions} @cached_property def group_ids(self) -> set[UUID]: @@ -125,7 +125,7 @@ class TerminatingSessionData: """Data for a session that needs to be terminated.""" session_id: SessionId - access_key: AccessKey + main_access_key: AccessKey creation_id: str status: SessionStatus status_info: str @@ -161,7 +161,7 @@ class SweptSessionInfo: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey @dataclass diff --git a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py index 3dfcab2a6ae..a2187381587 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py @@ -222,8 +222,8 @@ class KernelEnqueueData: scaling_group: str domain_name: str group_id: UUID - user_uuid: UUID - access_key: AccessKey + owner_id: UUID + main_access_key: AccessKey image: str # Canonical image name architecture: str registry: str @@ -268,8 +268,8 @@ class SessionEnqueueData: id: SessionId creation_id: str name: str - access_key: AccessKey - user_uuid: UUID + main_access_key: AccessKey + owner_id: UUID group_id: UUID domain_name: str scaling_group_name: str diff --git a/src/ai/backend/manager/repositories/session/creators.py b/src/ai/backend/manager/repositories/session/creators.py index bb1605aefcc..be89fe4c3fd 100644 --- a/src/ai/backend/manager/repositories/session/creators.py +++ b/src/ai/backend/manager/repositories/session/creators.py @@ -18,7 +18,7 @@ class SessionRowCreatorSpec(CreatorSpec[SessionRow]): SessionRow instances. It simply returns the provided row in build_row(). For scope information needed by RBACEntityCreator, use the row's user_uuid - field as the scope_id with ScopeType.USER. + field (the owner's UUID) as the scope_id with ScopeType.USER. """ row: SessionRow diff --git a/src/ai/backend/manager/repositories/session/db_source/db_source.py b/src/ai/backend/manager/repositories/session/db_source/db_source.py index b544b94dd89..7eed18c3170 100644 --- a/src/ai/backend/manager/repositories/session/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/session/db_source/db_source.py @@ -63,7 +63,7 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData: async def get_session_validated( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, eager_loading_op: Sequence[_AbstractLoad] | None = None, @@ -73,7 +73,7 @@ async def get_session_validated( return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=kernel_loading_strategy, allow_stale=allow_stale, eager_loading_op=list(eager_loading_op) if eager_loading_op else None, @@ -82,13 +82,13 @@ async def get_session_validated( async def match_sessions( self, id_or_name_prefix: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> list[SessionRow]: async with self._db.begin_readonly_session_read_committed() as db_sess: return await SessionRow.match_sessions( db_sess, id_or_name_prefix, - owner_access_key, + owner_id=owner_id, ) async def get_session_to_determine_status( @@ -132,7 +132,7 @@ async def update_session_name( self, session_name_or_id: str | SessionId, new_name: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: async def _update(db_session: AsyncSession) -> SessionRow: # Check if new name already exists for this owner @@ -140,7 +140,7 @@ async def _update(db_session: AsyncSession) -> SessionRow: await SessionRow.get_session( db_session, new_name, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, ) raise SessionAlreadyExists(f"Session with name '{new_name}' already exists") @@ -151,7 +151,7 @@ async def _update(db_session: AsyncSession) -> SessionRow: session_row = await SessionRow.get_session( db_session, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) @@ -305,13 +305,12 @@ async def modify_session( if session_row is None: raise SessionNotFound(f"Session not found (id:{session_id})") - if session_name and session_row.access_key is not None: - # Check the owner of the target session has any session with the same name + if session_name: try: sess = await SessionRow.get_session( db_session, session_name, - AccessKey(session_row.access_key), + owner_id=session_row.user_uuid, ) except SessionNotFound: pass @@ -371,7 +370,7 @@ async def _find_dependent_sessions( self, db_sess: AsyncSession, root_session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, allow_stale: bool = False, ) -> tuple[uuid.UUID, set[uuid.UUID]]: """ @@ -379,7 +378,7 @@ async def _find_dependent_sessions( :param db_sess: Database session :param root_session_name_or_id: Root session name or ID - :param access_key: Access key of the session owner + :param owner_id: UUID of the session owner :param allow_stale: Whether to allow stale sessions :return: Tuple of (root_session_id, set of dependent session IDs) """ @@ -401,7 +400,7 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]: root_session = await SessionRow.get_session( db_sess, root_session_name_or_id, - access_key=access_key, + owner_id=owner_id, allow_stale=allow_stale, ) root_session_id = cast(uuid.UUID, root_session.id) @@ -412,14 +411,14 @@ async def _find_recursive_dependencies(session_id: uuid.UUID) -> set[uuid.UUID]: async def get_target_session_ids( self, session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, recursive: bool = False, ) -> list[SessionId]: """ Get list of session IDs including dependent sessions if recursive. :param session_name_or_id: Name or ID of the primary session - :param access_key: Access key of the session owner + :param owner_id: User UUID of the session owner :param recursive: If True, include dependent sessions :return: List of session IDs """ @@ -430,7 +429,7 @@ async def get_target_session_ids( root_id, dependent_ids = await self._find_dependent_sessions( db_sess, session_name_or_id, - access_key, + owner_id, allow_stale=True, ) # Return dependent sessions first, then root session @@ -441,7 +440,7 @@ async def get_target_session_ids( session = await SessionRow.get_session( db_sess, session_name_or_id, - access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, allow_stale=True, ) @@ -454,19 +453,19 @@ async def get_target_session_ids( async def find_dependency_sessions( self, session_name_or_id: uuid.UUID | str, - access_key: AccessKey, + owner_id: uuid.UUID, ) -> dict[str, list[Any] | str]: async with self._db.begin_readonly_session_read_committed() as db_sess: return await find_dependency_sessions( session_name_or_id, db_sess, - access_key, + owner_id, ) async def get_session_with_group( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, ) -> SessionRow: @@ -475,7 +474,7 @@ async def get_session_with_group( return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=kernel_loading_strategy, allow_stale=allow_stale, eager_loading_op=[selectinload(SessionRow.group)], @@ -484,14 +483,14 @@ async def get_session_with_group( async def get_session_with_routing_minimal( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: """Get session with minimal routing information""" async with self._db.begin_readonly_session_read_committed() as db_sess: return await SessionRow.get_session( db_sess, session_name_or_id, - owner_access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, eager_loading_op=[ selectinload(SessionRow.routing).options(noload("*")), @@ -594,7 +593,7 @@ async def search_kernels( KernelListResult with items, total count, and pagination info """ async with self._db.begin_readonly_session() as db_sess: - query = sa.select(KernelRow) + query = sa.select(KernelRow).options(selectinload(KernelRow.user_row)) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/session/dependency_graph.py b/src/ai/backend/manager/repositories/session/dependency_graph.py index d9234416bb2..3ebdded333f 100644 --- a/src/ai/backend/manager/repositories/session/dependency_graph.py +++ b/src/ai/backend/manager/repositories/session/dependency_graph.py @@ -23,12 +23,15 @@ async def _find_dependency_sessions( session_name_or_id: UUID | str, db_session: SASession, - access_key: AccessKey, + owner: UUID | AccessKey, ) -> dict[str, list[Any] | str]: + owner_id: UUID | None = owner if isinstance(owner, UUID) else None + owner_access_key = owner if not isinstance(owner, UUID) else None sessions = await SessionRow.match_sessions( db_session, session_name_or_id, - access_key=access_key, + owner_id=owner_id, + owner_access_key=owner_access_key, ) if len(sessions) < 1: @@ -66,7 +69,7 @@ async def _find_dependency_sessions( "status": str(kernel_query_result[0]), "status_changed": str(kernel_query_result[1]), "depends_on": [ - await _find_dependency_sessions(dependency_session_id, db_session, access_key) + await _find_dependency_sessions(dependency_session_id, db_session, owner) for dependency_session_id in dependency_session_ids ], } @@ -77,15 +80,15 @@ async def _find_dependency_sessions( async def find_dependency_sessions( session_name_or_id: UUID | str, db_session: SASession, - access_key: AccessKey, + owner: UUID | AccessKey, ) -> dict[str, list[Any] | str]: - return await _find_dependency_sessions(session_name_or_id, db_session, access_key) + return await _find_dependency_sessions(session_name_or_id, db_session, owner) async def find_dependent_sessions( root_session_name_or_id: str | UUID, db_session: SASession, - access_key: AccessKey, + owner_id: UUID, *, allow_stale: bool = False, ) -> set[UUID]: @@ -108,7 +111,7 @@ async def _find_dependent_sessions(session_id: UUID) -> set[UUID]: root_session = await SessionRow.get_session( db_session, root_session_name_or_id, - access_key=access_key, + owner_id=owner_id, allow_stale=allow_stale, ) return await _find_dependent_sessions(cast(UUID, root_session.id)) diff --git a/src/ai/backend/manager/repositories/session/repository.py b/src/ai/backend/manager/repositories/session/repository.py index 805604170e8..2ac82f731a8 100644 --- a/src/ai/backend/manager/repositories/session/repository.py +++ b/src/ai/backend/manager/repositories/session/repository.py @@ -57,14 +57,14 @@ async def get_session_owner(self, session_id: str | SessionId) -> UserData: async def get_session_validated( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, eager_loading_op: Sequence[_AbstractLoad] | None = None, ) -> SessionRow: return await self._db_source.get_session_validated( session_name_or_id, - owner_access_key, + owner_id, kernel_loading_strategy, allow_stale, eager_loading_op, @@ -74,9 +74,9 @@ async def get_session_validated( async def match_sessions( self, id_or_name_prefix: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> list[SessionRow]: - return await self._db_source.match_sessions(id_or_name_prefix, owner_access_key) + return await self._db_source.match_sessions(id_or_name_prefix, owner_id) @session_repository_resilience.apply() async def get_session_to_determine_status( @@ -104,11 +104,9 @@ async def update_session_name( self, session_name_or_id: str | SessionId, new_name: str, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: - return await self._db_source.update_session_name( - session_name_or_id, new_name, owner_access_key - ) + return await self._db_source.update_session_name(session_name_or_id, new_name, owner_id) @session_repository_resilience.apply() async def get_container_registry( @@ -210,52 +208,48 @@ async def query_userinfo( async def get_target_session_ids( self, session_name_or_id: str | uuid.UUID, - access_key: AccessKey, + owner_id: uuid.UUID, recursive: bool = False, ) -> list[SessionId]: """ Get list of session IDs including dependent sessions if recursive. :param session_name_or_id: Name or ID of the primary session - :param access_key: Access key of the session owner + :param owner_id: User UUID of the session owner :param recursive: If True, include dependent sessions :return: List of session IDs """ - return await self._db_source.get_target_session_ids( - session_name_or_id, access_key, recursive - ) + return await self._db_source.get_target_session_ids(session_name_or_id, owner_id, recursive) @session_repository_resilience.apply() async def find_dependency_sessions( self, session_name_or_id: uuid.UUID | str, - access_key: AccessKey, + owner_id: uuid.UUID, ) -> dict[str, list[Any] | str]: - return await self._db_source.find_dependency_sessions(session_name_or_id, access_key) + return await self._db_source.find_dependency_sessions(session_name_or_id, owner_id) @session_repository_resilience.apply() async def get_session_with_group( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.MAIN_KERNEL_ONLY, allow_stale: bool = False, ) -> SessionRow: """Get session with group information eagerly loaded""" return await self._db_source.get_session_with_group( - session_name_or_id, owner_access_key, kernel_loading_strategy, allow_stale + session_name_or_id, owner_id, kernel_loading_strategy, allow_stale ) @session_repository_resilience.apply() async def get_session_with_routing_minimal( self, session_name_or_id: str | SessionId, - owner_access_key: AccessKey, + owner_id: uuid.UUID, ) -> SessionRow: """Get session with minimal routing information""" - return await self._db_source.get_session_with_routing_minimal( - session_name_or_id, owner_access_key - ) + return await self._db_source.get_session_with_routing_minimal(session_name_or_id, owner_id) @session_repository_resilience.apply() async def search( diff --git a/src/ai/backend/manager/repositories/stream/db_source/db_source.py b/src/ai/backend/manager/repositories/stream/db_source/db_source.py index 8bfee9d611e..bc475de9bff 100644 --- a/src/ai/backend/manager/repositories/stream/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/stream/db_source/db_source.py @@ -1,5 +1,9 @@ +import sqlalchemy as sa + from ai.backend.common.types import AccessKey +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import KernelLoadingStrategy, SessionRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -15,9 +19,14 @@ async def get_streaming_session( access_key: AccessKey, ) -> SessionRow: async with self._db.begin_readonly_session() as db_sess: + owner_id = await db_sess.scalar( + sa.select(UserRow.uuid).where(UserRow.main_access_key == access_key) + ) + if owner_id is None: + raise UserNotFound(f"No user with main_access_key={access_key}") return await SessionRow.get_session( db_sess, session_name, - access_key, + owner_id=owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) diff --git a/src/ai/backend/manager/repositories/user/db_source/db_source.py b/src/ai/backend/manager/repositories/user/db_source/db_source.py index f3ff8eef41b..2b0d3fd70d8 100644 --- a/src/ai/backend/manager/repositories/user/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/user/db_source/db_source.py @@ -137,6 +137,13 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: user_row = await self._get_user_by_uuid(db_session, user_uuid) return user_row.to_data() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + async with self._db.begin_readonly_session() as db_session: + return await db_session.scalar( + sa.select(UserRow.main_access_key).where(UserRow.uuid == user_uuid) + ) + async def get_by_email_validated( self, email: str, @@ -667,11 +674,14 @@ async def delegate_endpoint_ownership( target_user_uuid: UUID, target_main_access_key: AccessKey, ) -> None: - """Delegate endpoint ownership to another user.""" + """Delegate endpoint ownership to another user. + + ``target_main_access_key`` is kept on the facade for caller compatibility + but is no longer required by ``EndpointRow.delegate_endpoint_ownership``. + """ + del target_main_access_key # unused async with self._db.begin_session() as session: - await EndpointRow.delegate_endpoint_ownership( - session, user_uuid, target_user_uuid, target_main_access_key - ) + await EndpointRow.delegate_endpoint_ownership(session, user_uuid, target_user_uuid) async def delete_endpoints( self, diff --git a/src/ai/backend/manager/repositories/user/repository.py b/src/ai/backend/manager/repositories/user/repository.py index 9df169da1f9..9907b828595 100644 --- a/src/ai/backend/manager/repositories/user/repository.py +++ b/src/ai/backend/manager/repositories/user/repository.py @@ -77,6 +77,11 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: """ return await self._db_source.get_user_by_uuid(user_uuid) + @user_repository_resilience.apply() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + return await self._db_source.get_main_access_key_by_id(user_uuid) + @user_repository_resilience.apply() async def get_by_email_validated( self, diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py index 6de246211f5..6b18a9b0645 100644 --- a/src/ai/backend/manager/scheduler/drf.py +++ b/src/ai/backend/manager/scheduler/drf.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import uuid from collections import defaultdict from collections.abc import Mapping, Sequence from decimal import Decimal @@ -9,7 +10,6 @@ import trafaret as t from ai.backend.common.types import ( - AccessKey, ResourceSlot, SessionId, ) @@ -24,7 +24,7 @@ class DRFScheduler(AbstractScheduler): - per_user_dominant_share: dict[AccessKey, Decimal] + per_user_dominant_share: dict[uuid.UUID, Decimal] total_capacity: ResourceSlot def __init__( @@ -60,23 +60,22 @@ def pick_session( slot_share = Decimal(value) / slot_cap if dominant_share < slot_share: dominant_share = slot_share - raw_access_key = existing_sess.access_key - if raw_access_key is not None: - access_key = AccessKey(raw_access_key) - if self.per_user_dominant_share[access_key] < dominant_share: - self.per_user_dominant_share[access_key] = dominant_share + owner_id = existing_sess.user_uuid + if owner_id is not None: + if self.per_user_dominant_share[owner_id] < dominant_share: + self.per_user_dominant_share[owner_id] = dominant_share log.debug("per-user dominant share: {}", dict(self.per_user_dominant_share)) # Find who has the least dominant share among the pending session. - users_with_pending_session: set[AccessKey] = { - AccessKey(pending_sess.access_key) + users_with_pending_session: set[uuid.UUID] = { + pending_sess.user_uuid for pending_sess in pending_sessions - if pending_sess.access_key is not None + if pending_sess.user_uuid is not None } if not users_with_pending_session: return None least_dominant_share_user, dshare = min( - ((akey, self.per_user_dominant_share[akey]) for akey in users_with_pending_session), + ((oid, self.per_user_dominant_share[oid]) for oid in users_with_pending_session), key=lambda item: item[1], ) log.debug("least dominant share user: {} ({})", least_dominant_share_user, dshare) @@ -84,7 +83,7 @@ def pick_session( # Pick the first pending session of the user # who has the lowest dominant share. for pending_sess in pending_sessions: - if pending_sess.access_key == least_dominant_share_user: + if pending_sess.user_uuid == least_dominant_share_user: return SessionId(pending_sess.id) return None @@ -96,10 +95,7 @@ def update_allocation( ) -> None: # In such case, we just skip updating self.per_user_dominant_share state # and the scheduler continues to pick another session within the same scaling group. - raw_access_key = scheduled_session_or_kernel.access_key - if raw_access_key is None: - return - access_key = AccessKey(raw_access_key) + owner_id = scheduled_session_or_kernel.user_uuid requested_slots = scheduled_session_or_kernel.requested_slots # Update the dominant share. @@ -114,5 +110,5 @@ def update_allocation( slot_share = Decimal(value) / slot_cap if dominant_share_from_request < slot_share: dominant_share_from_request = slot_share - if self.per_user_dominant_share[access_key] < dominant_share_from_request: - self.per_user_dominant_share[access_key] = dominant_share_from_request + if self.per_user_dominant_share[owner_id] < dominant_share_from_request: + self.per_user_dominant_share[owner_id] = dominant_share_from_request diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py index f4cb7feffce..027a1faa1c0 100644 --- a/src/ai/backend/manager/scheduler/predicates.py +++ b/src/ai/backend/manager/scheduler/predicates.py @@ -29,6 +29,12 @@ log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) +async def _resolve_main_access_key(db_sess: SASession, sess_ctx: SessionRow) -> str | None: + """Resolve the owner's main access key via UserRow join.""" + stmt = sa.select(UserRow.main_access_key).where(UserRow.uuid == sess_ctx.user_uuid) + return await db_sess.scalar(stmt) + + async def check_reserved_batch_session( db_sess: SASession, _sched_ctx: SchedulingContext, @@ -53,9 +59,16 @@ async def check_concurrency( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult( + False, + "Session owner has no main_access_key; cannot evaluate concurrency policy", + ) + async def _get_max_concurrent_sessions() -> int: resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key + KeyPairRow.access_key == main_ak ) if sess_ctx.is_private: concurrent_session_column = KeyPairResourcePolicyRow.max_concurrent_sftp_sessions @@ -69,9 +82,9 @@ async def _get_max_concurrent_sessions() -> int: max_concurrent_sessions = await execute_with_retry(_get_max_concurrent_sessions) or 0 if sess_ctx.is_private: - redis_key = f"keypair.sftp_concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.sftp_concurrency_used.{main_ak}" else: - redis_key = f"keypair.concurrency_used.{sess_ctx.access_key}" + redis_key = f"keypair.concurrency_used.{main_ak}" ok, concurrency_used = await sched_ctx.registry.valkey_stat.check_keypair_concurrency( redis_key, max_concurrent_sessions, @@ -83,7 +96,7 @@ async def _get_max_concurrent_sessions() -> int: ) log.debug( "number of concurrent sessions of ak:{0} = {1} / {2}", - sess_ctx.access_key, + main_ak, concurrency_used, max_concurrent_sessions, ) @@ -135,9 +148,10 @@ async def check_keypair_resource_limit( sched_ctx: SchedulingContext, sess_ctx: SessionRow, ) -> PredicateResult: - resouce_policy_q = sa.select(KeyPairRow.resource_policy).where( - KeyPairRow.access_key == sess_ctx.access_key - ) + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") + resouce_policy_q = sa.select(KeyPairRow.resource_policy).where(KeyPairRow.access_key == main_ak) select_query = sa.select(KeyPairResourcePolicyRow).where( KeyPairResourcePolicyRow.name == resouce_policy_q.scalar_subquery() ) @@ -146,7 +160,7 @@ async def check_keypair_resource_limit( if resource_policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) resource_policy_map = { "total_resource_slots": resource_policy.total_resource_slots, @@ -155,14 +169,11 @@ async def check_keypair_resource_limit( total_keypair_allowed = ResourceSlot.from_policy( resource_policy_map, cast(Mapping[str, Any], sched_ctx.known_slot_types) ) - - if sess_ctx.access_key is None: - return PredicateResult(False, "Session has no access key") key_occupied = await sched_ctx.registry.get_keypair_occupancy( - AccessKey(sess_ctx.access_key), db_sess=db_sess + AccessKey(main_ak), db_sess=db_sess ) - log.debug("keypair:{} current-occupancy: {}", sess_ctx.access_key, key_occupied) - log.debug("keypair:{} total-allowed: {}", sess_ctx.access_key, total_keypair_allowed) + log.debug("keypair:{} current-occupancy: {}", main_ak, key_occupied) + log.debug("keypair:{} total-allowed: {}", main_ak, total_keypair_allowed) if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed): return PredicateResult( False, @@ -300,10 +311,13 @@ async def check_pending_session_count_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -319,7 +333,7 @@ async def check_pending_session_count_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -331,7 +345,7 @@ async def check_pending_session_count_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_count_limit: int | None = policy.max_pending_session_count @@ -344,7 +358,7 @@ async def check_pending_session_count_limit( log.debug( "access key:{} number of pending sessions: {} / {}", - sess_ctx.access_key, + main_ak, len(pending_sessions), pending_count_limit, ) @@ -361,10 +375,13 @@ async def check_pending_session_resource_limit( result = True failure_msgs = [] + main_ak = await _resolve_main_access_key(db_sess, sess_ctx) + if main_ak is None: + return PredicateResult(False, "Session owner has no main_access_key") query = ( sa.select(SessionRow) .where( - (SessionRow.access_key == sess_ctx.access_key) + (SessionRow.user_uuid == sess_ctx.user_uuid) & (SessionRow.status == SessionStatus.PENDING) ) .options(noload("*"), load_only(SessionRow.requested_slots)) @@ -380,7 +397,7 @@ async def check_pending_session_resource_limit( policy_stmt = ( sa.select(KeyPairResourcePolicyRow) .select_from(j) - .where(KeyPairRow.access_key == sess_ctx.access_key) + .where(KeyPairRow.access_key == main_ak) .options( noload("*"), load_only( @@ -392,7 +409,7 @@ async def check_pending_session_resource_limit( if policy is None: return PredicateResult( False, - f"Resource policy not found for keypair (ak: {sess_ctx.access_key})", + f"Resource policy not found for keypair (ak: {main_ak})", ) pending_resource_limit: ResourceSlot | None = policy.max_pending_session_resource_slots @@ -413,12 +430,12 @@ async def check_pending_session_resource_limit( failure_msgs.append(msg) log.debug( "access key:{} current-occupancy of pending sessions: {}", - sess_ctx.access_key, + main_ak, current_pending_session_slots, ) log.debug( "access key:{} total-allowed of pending sessions: {}", - sess_ctx.access_key, + main_ak, pending_resource_limit, ) if not result: diff --git a/src/ai/backend/manager/services/session/actions/commit_session.py b/src/ai/backend/manager/services/session/actions/commit_session.py index ad268fb6070..9a63a53accd 100644 --- a/src/ai/backend/manager/services/session/actions/commit_session.py +++ b/src/ai/backend/manager/services/session/actions/commit_session.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class CommitSessionAction(SessionCommitAction): session_name: str - owner_access_key: AccessKey filename: str | None @override diff --git a/src/ai/backend/manager/services/session/actions/complete.py b/src/ai/backend/manager/services/session/actions/complete.py index a67eab86e11..65c67d4df38 100644 --- a/src/ai/backend/manager/services/session/actions/complete.py +++ b/src/ai/backend/manager/services/session/actions/complete.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.dto.agent.response import CodeCompletionResp -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ @dataclass class CompleteAction(SessionAction): session_name: str - owner_access_key: AccessKey code: str # TODO: Add type options: Mapping[str, Any] | None diff --git a/src/ai/backend/manager/services/session/actions/convert_session_to_image.py b/src/ai/backend/manager/services/session/actions/convert_session_to_image.py index 9c56d624815..d6892d5ba7c 100644 --- a/src/ai/backend/manager/services/session/actions/convert_session_to_image.py +++ b/src/ai/backend/manager/services/session/actions/convert_session_to_image.py @@ -3,7 +3,6 @@ from typing import override from ai.backend.common.data.session.types import CustomizedImageVisibilityScope -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -13,7 +12,6 @@ @dataclass class ConvertSessionToImageAction(SessionAction): session_name: str - owner_access_key: AccessKey image_name: str image_visibility: CustomizedImageVisibilityScope image_owner_id: uuid.UUID diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 10e843b1df7..9756bdae519 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -30,7 +30,7 @@ class CreateClusterAction(SessionScopeAction): domain_name: str scaling_group_name: str requester_access_key: AccessKey - owner_access_key: AccessKey + owner_id: uuid.UUID tag: str enqueue_only: bool keypair_resource_policy: dict[str, Any] | None diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index 68a293cf4c9..248380ef0f3 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -30,7 +30,7 @@ class CreateFromParamsActionParams: tag: str priority: int is_preemptible: bool - owner_access_key: AccessKey + owner_id: uuid.UUID enqueue_only: bool max_wait_seconds: int starts_at: str | None diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 2b721d36d3a..9eea9834b73 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -33,7 +33,7 @@ class CreateFromTemplateActionParams: tag: str | Undefined priority: int is_preemptible: bool - owner_access_key: AccessKey | Undefined + owner_id: uuid.UUID | Undefined enqueue_only: bool max_wait_seconds: int starts_at: str | None diff --git a/src/ai/backend/manager/services/session/actions/destroy_session.py b/src/ai/backend/manager/services/session/actions/destroy_session.py index 7a8d3703165..519d5efbf54 100644 --- a/src/ai/backend/manager/services/session/actions/destroy_session.py +++ b/src/ai/backend/manager/services/session/actions/destroy_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.models.user import UserRole @@ -15,7 +14,6 @@ class DestroySessionAction(SessionAction): session_name: str forced: bool recursive: bool - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/download_file.py b/src/ai/backend/manager/services/session/actions/download_file.py index 1b030460dc1..86b97843ea3 100644 --- a/src/ai/backend/manager/services/session/actions/download_file.py +++ b/src/ai/backend/manager/services/session/actions/download_file.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ class DownloadFileAction(SessionFileAction): user_id: uuid.UUID session_name: str file: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/download_files.py b/src/ai/backend/manager/services/session/actions/download_files.py index 2a86031e1f9..fec5b246eb6 100644 --- a/src/ai/backend/manager/services/session/actions/download_files.py +++ b/src/ai/backend/manager/services/session/actions/download_files.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -14,7 +13,6 @@ class DownloadFilesAction(SessionFileAction): user_id: uuid.UUID session_name: str files: list[str] - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/execute_session.py b/src/ai/backend/manager/services/session/actions/execute_session.py index 5cd34477adc..2fb685005c8 100644 --- a/src/ai/backend/manager/services/session/actions/execute_session.py +++ b/src/ai/backend/manager/services/session/actions/execute_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -21,7 +20,6 @@ class ExecuteSessionActionParams: class ExecuteSessionAction(SessionAction): session_name: str api_version: tuple[Any, ...] - owner_access_key: AccessKey params: ExecuteSessionActionParams @override diff --git a/src/ai/backend/manager/services/session/actions/get_abusing_report.py b/src/ai/backend/manager/services/session/actions/get_abusing_report.py index c6be8be7fb2..a88273e1268 100644 --- a/src/ai/backend/manager/services/session/actions/get_abusing_report.py +++ b/src/ai/backend/manager/services/session/actions/get_abusing_report.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AbuseReport, AccessKey +from ai.backend.common.types import AbuseReport from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +12,6 @@ @dataclass class GetAbusingReportAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_commit_status.py b/src/ai/backend/manager/services/session/actions/get_commit_status.py index 6f1818b803e..beeef9669cb 100644 --- a/src/ai/backend/manager/services/session/actions/get_commit_status.py +++ b/src/ai/backend/manager/services/session/actions/get_commit_status.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetCommitStatusAction(SessionCommitAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/get_container_logs.py b/src/ai/backend/manager/services/session/actions/get_container_logs.py index fd0bbc1bdc8..33780232cbc 100644 --- a/src/ai/backend/manager/services/session/actions/get_container_logs.py +++ b/src/ai/backend/manager/services/session/actions/get_container_logs.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey, KernelId +from ai.backend.common.types import KernelId from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +12,6 @@ @dataclass class GetContainerLogsAction(SessionAction): session_name: str - owner_access_key: AccessKey kernel_id: KernelId | None @override diff --git a/src/ai/backend/manager/services/session/actions/get_dependency_graph.py b/src/ai/backend/manager/services/session/actions/get_dependency_graph.py index 67276f5f8d0..993f31cd230 100644 --- a/src/ai/backend/manager/services/session/actions/get_dependency_graph.py +++ b/src/ai/backend/manager/services/session/actions/get_dependency_graph.py @@ -2,7 +2,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetDependencyGraphAction(SessionAction): root_session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_direct_access_info.py b/src/ai/backend/manager/services/session/actions/get_direct_access_info.py index 8d34063e881..b58990d3390 100644 --- a/src/ai/backend/manager/services/session/actions/get_direct_access_info.py +++ b/src/ai/backend/manager/services/session/actions/get_direct_access_info.py @@ -2,7 +2,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetDirectAccessInfoAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_session_info.py b/src/ai/backend/manager/services/session/actions/get_session_info.py index 9e39453f5c1..3c7b197f342 100644 --- a/src/ai/backend/manager/services/session/actions/get_session_info.py +++ b/src/ai/backend/manager/services/session/actions/get_session_info.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ @dataclass class GetSessionInfoAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/get_status_history.py b/src/ai/backend/manager/services/session/actions/get_status_history.py index bdd1067d8a0..3de466918b6 100644 --- a/src/ai/backend/manager/services/session/actions/get_status_history.py +++ b/src/ai/backend/manager/services/session/actions/get_status_history.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import EntityType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.services.session.base import SessionAction @@ -12,7 +11,6 @@ @dataclass class GetStatusHistoryAction(SessionAction): session_name: str - owner_access_key: AccessKey @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/interrupt_session.py b/src/ai/backend/manager/services/session/actions/interrupt_session.py index e92303e7a59..721c3893623 100644 --- a/src/ai/backend/manager/services/session/actions/interrupt_session.py +++ b/src/ai/backend/manager/services/session/actions/interrupt_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class InterruptSessionAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/list_files.py b/src/ai/backend/manager/services/session/actions/list_files.py index 0db43c50183..6ea242ff943 100644 --- a/src/ai/backend/manager/services/session/actions/list_files.py +++ b/src/ai/backend/manager/services/session/actions/list_files.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -15,7 +14,6 @@ class ListFilesAction(SessionFileAction): user_id: uuid.UUID path: str session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index ca91ae8bab5..a9c54f62360 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -3,7 +3,6 @@ from typing import Any, override from ai.backend.common.data.permission.types import RBACElementType, ScopeType -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.permission.types import RBACElementRef @@ -20,7 +19,6 @@ class MatchSessionsAction(SessionScopeAction): """ id_or_name_prefix: str - owner_access_key: AccessKey user_id: uuid.UUID @override diff --git a/src/ai/backend/manager/services/session/actions/rename_session.py b/src/ai/backend/manager/services/session/actions/rename_session.py index a2b12968f52..5f616f7ba10 100644 --- a/src/ai/backend/manager/services/session/actions/rename_session.py +++ b/src/ai/backend/manager/services/session/actions/rename_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -12,7 +11,6 @@ class RenameSessionAction(SessionAction): session_name: str new_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/restart_session.py b/src/ai/backend/manager/services/session/actions/restart_session.py index 10b65b3e910..af4ac516bdb 100644 --- a/src/ai/backend/manager/services/session/actions/restart_session.py +++ b/src/ai/backend/manager/services/session/actions/restart_session.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class RestartSessionAction(SessionAction): session_name: str - owner_access_key: AccessKey @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/session/actions/shutdown_service.py b/src/ai/backend/manager/services/session/actions/shutdown_service.py index a07de63baf2..4a7f1810e5a 100644 --- a/src/ai/backend/manager/services/session/actions/shutdown_service.py +++ b/src/ai/backend/manager/services/session/actions/shutdown_service.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -11,7 +10,6 @@ @dataclass class ShutdownServiceAction(SessionAppServiceAction): session_name: str - owner_access_key: AccessKey service_name: str @override diff --git a/src/ai/backend/manager/services/session/actions/upload_files.py b/src/ai/backend/manager/services/session/actions/upload_files.py index 28f76fef60f..1935ad0c3c8 100644 --- a/src/ai/backend/manager/services/session/actions/upload_files.py +++ b/src/ai/backend/manager/services/session/actions/upload_files.py @@ -3,7 +3,6 @@ from aiohttp import MultipartReader -from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData @@ -13,7 +12,6 @@ @dataclass class UploadFilesAction(SessionFileAction): session_name: str - owner_access_key: AccessKey # TODO: Refactor this. reader: MultipartReader diff --git a/src/ai/backend/manager/services/session/lifecycle.py b/src/ai/backend/manager/services/session/lifecycle.py index 1849edfe4ef..28d65740ccc 100644 --- a/src/ai/backend/manager/services/session/lifecycle.py +++ b/src/ai/backend/manager/services/session/lifecycle.py @@ -23,6 +23,7 @@ ) from ai.backend.common.plugin.hook import HookPluginContext from ai.backend.common.types import ( + AccessKey, ResourceSlot, SessionId, SessionTypes, @@ -36,6 +37,7 @@ ExtendedAsyncSAEngine, execute_with_txn_retry, ) +from ai.backend.manager.repositories.user.repository import UserRepository if TYPE_CHECKING: from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient @@ -55,6 +57,7 @@ def __init__( event_producer: EventProducer, hook_plugin_ctx: HookPluginContext, registry: AgentRegistry, + user_repository: UserRepository, ) -> None: self.db = db self.valkey_stat = valkey_stat @@ -62,6 +65,7 @@ def __init__( self.event_producer = event_producer self.hook_plugin_ctx = hook_plugin_ctx self.registry = registry + self._user_repository = user_repository def _encode(sid: SessionId) -> bytes: return sid.bytes @@ -132,14 +136,29 @@ async def _post_status_transition( await self.event_producer.anycast_event( SessionStartedAnycastEvent(session_row.id, creation_id) ) - await self.hook_plugin_ctx.notify( - "POST_START_SESSION", - ( - session_row.id, - session_row.name, - session_row.access_key, - ), + # BA-5609: resolve main_access_key from owner_id; external + # hook plugins still receive the resolved access key. If the + # owner has no main_access_key configured, skip the hook — + # calling it with ``None`` would likely break plugins that + # assume a non-null keypair identifier. + session_main_access_key = await self._user_repository.get_main_access_key_by_id( + session_row.user_uuid ) + if session_main_access_key is not None: + await self.hook_plugin_ctx.notify( + "POST_START_SESSION", + ( + session_row.id, + session_row.name, + AccessKey(session_main_access_key), + ), + ) + else: + log.warning( + "POST_START_SESSION skipped: owner {} has no main_access_key (session {})", + session_row.user_uuid, + session_row.id, + ) match session_row.session_type: case SessionTypes.BATCH: await self.registry.trigger_batch_execution(session_row) diff --git a/src/ai/backend/manager/services/session/service.py b/src/ai/backend/manager/services/session/service.py index de43d25dec2..b93fccaa8d3 100644 --- a/src/ai/backend/manager/services/session/service.py +++ b/src/ai/backend/manager/services/session/service.py @@ -19,6 +19,7 @@ from dateutil.tz import tzutc from ai.backend.common.bgtask.bgtask import BackgroundTaskManager +from ai.backend.common.contexts.user import current_user from ai.backend.common.data.session.types import CustomizedImageVisibilityScope from ai.backend.common.events.event_types.kernel.types import KernelLifecycleEventReason from ai.backend.common.events.fetcher import EventFetcher @@ -269,9 +270,39 @@ def __init__( self._rpc_ptask_group = aiotools.PersistentTaskGroup() self._webhook_ptask_group = aiotools.PersistentTaskGroup() + @staticmethod + def _requester_user_id() -> uuid.UUID: + """Return the authenticated caller's user UUID from context. + + Raises ``InternalServerError`` if no user is in context (should never + happen after the auth middleware). + """ + user = current_user() + if user is None: + raise InternalServerError("No authenticated user in request context") + return user.user_id + + async def _resolve_owner_main_access_key( + self, + owner_id: uuid.UUID, + ) -> AccessKey: + """Resolve a delegated owner UUID to that user's main access key. + + Uses the narrower ``UserRepository.get_main_access_key_by_id`` helper + so we only fetch the single scalar column we need. Raises + ``InternalServerError`` if the target user has no main access key + configured. + """ + main_access_key = await self._user_repository.get_main_access_key_by_id(owner_id) + if main_access_key is None: + raise InternalServerError( + f"Delegated owner {owner_id} has no main access key configured" + ) + return AccessKey(main_access_key) + async def commit_session(self, action: CommitSessionAction) -> CommitSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() filename = action.filename myself = asyncio.current_task() @@ -280,7 +311,7 @@ async def commit_session(self, action: CommitSessionAction) -> CommitSessionActi session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -297,13 +328,13 @@ async def commit_session(self, action: CommitSessionAction) -> CommitSessionActi async def complete(self, action: CompleteAction) -> CompleteActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() code = action.code options = action.options or {} session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -320,7 +351,7 @@ async def convert_session_to_image( self, action: ConvertSessionToImageAction ) -> ConvertSessionToImageActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() image_name = action.image_name image_visibility = action.image_visibility image_owner_id = action.image_owner_id @@ -351,7 +382,7 @@ async def convert_session_to_image( session = await self._session_repository.get_session_with_group( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -406,7 +437,7 @@ async def create_cluster(self, action: CreateClusterAction) -> CreateClusterActi sudo_session_enabled = action.sudo_session_enabled keypair_resource_policy = action.keypair_resource_policy requester_access_key = action.requester_access_key - owner_access_key = action.owner_access_key + owner_access_key = await self._resolve_owner_main_access_key(action.owner_id) domain_name = action.domain_name group_name = action.group_name scaling_group_name = action.scaling_group_name @@ -475,7 +506,7 @@ async def create_from_params( keypair_resource_policy = action.keypair_resource_policy requester_access_key = action.requester_access_key - owner_access_key = action.params.owner_access_key + owner_access_key = await self._resolve_owner_main_access_key(action.params.owner_id) domain_name = action.params.domain_name group_name = action.params.group_name config = action.params.config @@ -527,7 +558,7 @@ async def create_from_params( user_uuid=user_info.owner_uuid, user_role=user_info.owner_role.value, ), - owner_access_key, + owner_access_key if owner_access_key is not None else requester_access_key, user_info.resource_policy, session_type, config, @@ -680,7 +711,10 @@ async def create_from_template( bootstrap += cmd_builder params["bootstrap_script"] = base64.b64encode(bootstrap.encode()).decode() - owner_access_key = params["owner_access_key"] + owner_id_param = params["owner_id"] + owner_access_key: AccessKey | None = None + if owner_id_param is not None and owner_id_param is not undefined: + owner_access_key = await self._resolve_owner_main_access_key(owner_id_param) config = params["config"] cluster_size = params["cluster_size"] cluster_mode = params["cluster_mode"] @@ -710,7 +744,7 @@ async def create_from_template( keypair_resource_policy, domain_name, params["group_name"], - query_on_behalf_of=(None if owner_access_key is undefined else owner_access_key), + query_on_behalf_of=owner_access_key, ) try: @@ -731,7 +765,7 @@ async def create_from_template( user_uuid=user_info.owner_uuid, user_role=user_info.owner_role.value, ), - owner_access_key, + owner_access_key if owner_access_key is not None else requester_access_key, user_info.resource_policy, session_type, config, @@ -763,14 +797,14 @@ async def create_from_template( async def destroy_session(self, action: DestroySessionAction) -> DestroySessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() forced = action.forced recursive = action.recursive # Get session IDs to terminate (based on recursive flag) session_ids = await self._session_repository.get_target_session_ids( session_name, - owner_access_key, + owner_id, recursive=recursive, ) @@ -847,13 +881,14 @@ async def terminate_sessions_in_project( async def download_file(self, action: DownloadFileAction) -> DownloadFileActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() + owner_access_key = await self._resolve_owner_main_access_key(owner_id) user_id = action.user_id file = action.file try: session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(session) @@ -873,12 +908,12 @@ async def download_file(self, action: DownloadFileAction) -> DownloadFileActionR async def download_files(self, action: DownloadFilesAction) -> DownloadFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() user_id = action.user_id files = action.files session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -916,13 +951,13 @@ async def download_files(self, action: DownloadFilesAction) -> DownloadFilesActi async def execute_session(self, action: ExecuteSessionAction) -> ExecuteSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() api_version = action.api_version resp = {} session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) try: @@ -999,10 +1034,10 @@ async def get_abusing_report( self, action: GetAbusingReportAction ) -> GetAbusingReportActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) kernel = session.main_kernel @@ -1013,11 +1048,11 @@ async def get_abusing_report( async def get_commit_status(self, action: GetCommitStatusAction) -> GetCommitStatusActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) statuses = await self._agent_registry.get_commit_status([session.main_kernel.id]) @@ -1034,12 +1069,12 @@ async def get_container_logs( ) -> GetContainerLogsActionResult: resp = {"result": {"logs": ""}} session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() kernel_id = action.kernel_id compute_session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, allow_stale=True, kernel_loading_strategy=( KernelLoadingStrategy.MAIN_KERNEL_ONLY @@ -1080,10 +1115,10 @@ async def get_dependency_graph( self, action: GetDependencyGraphAction ) -> GetDependencyGraphActionResult: root_session_name = action.root_session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() dependency_graph = await self._session_repository.find_dependency_sessions( - root_session_name, owner_access_key + root_session_name, owner_id ) session_id = ( @@ -1106,11 +1141,11 @@ async def get_direct_access_info( self, action: GetDirectAccessInfoAction ) -> GetDirectAccessInfoActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() sess = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) resp = {} @@ -1143,11 +1178,11 @@ async def get_direct_access_info( async def get_session_info(self, action: GetSessionInfoAction) -> GetSessionInfoActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() sess = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(sess) @@ -1157,7 +1192,7 @@ async def get_session_info(self, action: GetSessionInfoAction) -> GetSessionInfo session_info = LegacySessionInfo( domain_name=sess.domain_name, group_id=sess.group_id, - user_id=sess.user_uuid, + user_id=sess.owner_id, lang=sess.main_kernel.image or "", # legacy image=sess.main_kernel.image or "", architecture=sess.main_kernel.architecture or "", @@ -1193,11 +1228,11 @@ async def get_status_history( self, action: GetStatusHistoryAction ) -> GetStatusHistoryActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session_row = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.NONE, ) result = session_row.status_history or {} @@ -1206,11 +1241,11 @@ async def get_status_history( async def interrupt(self, action: InterruptSessionAction) -> InterruptSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.increment_session_usage(session) @@ -1220,13 +1255,13 @@ async def interrupt(self, action: InterruptSessionAction) -> InterruptSessionAct async def list_files(self, action: ListFilesAction) -> ListFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() user_id = action.user_id path = action.path session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) @@ -1249,12 +1284,12 @@ async def list_files(self, action: ListFilesAction) -> ListFilesActionResult: async def match_sessions(self, action: MatchSessionsAction) -> MatchSessionsActionResult: id_or_name_prefix = action.id_or_name_prefix - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() matches: list[dict[str, Any]] = [] sessions = await self._session_repository.match_sessions( id_or_name_prefix, - owner_access_key, + owner_id, ) if sessions: matches.extend( @@ -1269,12 +1304,12 @@ async def match_sessions(self, action: MatchSessionsAction) -> MatchSessionsActi async def rename_session(self, action: RenameSessionAction) -> RenameSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() new_name = action.new_name try: compute_session = await self._session_repository.update_session_name( - session_name, new_name, owner_access_key + session_name, new_name, owner_id ) if compute_session.status != SessionStatus.RUNNING: raise InvalidAPIParameters("Can't change name of not running session") @@ -1287,11 +1322,11 @@ async def rename_session(self, action: RenameSessionAction) -> RenameSessionActi async def restart_session(self, action: RestartSessionAction) -> RestartSessionActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS, ) await self._agent_registry.increment_session_usage(session) @@ -1300,12 +1335,12 @@ async def restart_session(self, action: RestartSessionAction) -> RestartSessionA async def shutdown_service(self, action: ShutdownServiceAction) -> ShutdownServiceActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() service_name = action.service_name session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) await self._agent_registry.shutdown_service(session, service_name) @@ -1321,11 +1356,12 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR envs = action.envs login_session_token = action.login_session_token + keypair = await self._user_repository.admin_get_keypair(access_key) session = await asyncio.shield( self._database_ptask_group.create_task( self._session_repository.get_session_with_routing_minimal( session_name, - access_key, + keypair.user_id, ) ) ) @@ -1399,15 +1435,16 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR "Failed to launch the app service", extra_data=result["error"] ) + resolved_access_key = await self._resolve_owner_main_access_key(session.owner_id) body = { "login_session_token": login_session_token, "kernel_host": kernel_host, "kernel_port": host_port, "session": { "id": str(session.id), - "user_uuid": str(session.user_uuid), + "user_uuid": str(session.owner_id), "group_id": str(session.group_id), - "access_key": session.access_key, + "access_key": resolved_access_key, "domain_name": session.domain_name, }, } @@ -1430,12 +1467,12 @@ async def start_service(self, action: StartServiceAction) -> StartServiceActionR async def upload_files(self, action: UploadFilesAction) -> UploadFilesActionResult: session_name = action.session_name - owner_access_key = action.owner_access_key + owner_id = self._requester_user_id() reader = action.reader session = await self._session_repository.get_session_validated( session_name, - owner_access_key, + owner_id, kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, ) diff --git a/src/ai/backend/manager/services/session/types.py b/src/ai/backend/manager/services/session/types.py index cf806a2a1b0..929000084f8 100644 --- a/src/ai/backend/manager/services/session/types.py +++ b/src/ai/backend/manager/services/session/types.py @@ -26,7 +26,7 @@ t.Key("reuse", default=True): t.ToBool, t.Key("startup_command", default=None): t.Null | t.String, t.Key("bootstrap_script", default=None): t.Null | t.String, - t.Key("owner_access_key", default=None): t.Null | t.String, + t.Key("owner_id", default=None): t.Null | tx.UUID, tx.AliasedKey(["scaling_group", "scalingGroup"], default=None): t.Null | t.String, tx.AliasedKey(["cluster_size", "clusterSize"], default=None): t.Null | t.Int[1:], tx.AliasedKey(["cluster_mode", "clusterMode"], default="SINGLE_NODE"): tx.Enum(ClusterMode), diff --git a/src/ai/backend/manager/sokovan/data/allocation.py b/src/ai/backend/manager/sokovan/data/allocation.py index 72334089fe9..c3d3e3d477e 100644 --- a/src/ai/backend/manager/sokovan/data/allocation.py +++ b/src/ai/backend/manager/sokovan/data/allocation.py @@ -80,8 +80,8 @@ class SessionAllocation: kernel_allocations: list[KernelAllocation] # List of agent allocations for this session agent_allocations: list[AgentAllocation] - # Keypair associated with the session - access_key: AccessKey + # Owner's resolved main_access_key; required for keypair-scoped concurrency tracking and resource policy lookups. + main_access_key: AccessKey # Phases that passed during scheduling passed_phases: list[SchedulingPredicate] = field(default_factory=list) # Phases that failed during scheduling (normally empty for successful allocations) @@ -141,7 +141,7 @@ def from_agent_selections( scaling_group=scaling_group, kernel_allocations=kernel_allocations, agent_allocations=agent_allocations, - access_key=session_workload.access_key, + main_access_key=session_workload.main_access_key, ) def unique_agent_ids(self) -> list[AgentId]: diff --git a/src/ai/backend/manager/sokovan/data/lifecycle.py b/src/ai/backend/manager/sokovan/data/lifecycle.py index d332814968e..e4bf41b6a33 100644 --- a/src/ai/backend/manager/sokovan/data/lifecycle.py +++ b/src/ai/backend/manager/sokovan/data/lifecycle.py @@ -64,7 +64,7 @@ class SessionDataForPull: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey kernels: list[KernelBindingData] @@ -74,12 +74,12 @@ class SessionDataForStart: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str cluster_mode: ClusterMode kernels: list[KernelBindingData] - user_uuid: UUID + owner_id: UUID user_email: str user_name: str environ: dict[str, str] @@ -93,13 +93,13 @@ class ScheduledSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str kernels: list[KernelBindingData] # Additional fields for PREPARED sessions cluster_mode: ClusterMode | None = None - user_uuid: UUID | None = None + owner_id: UUID | None = None user_email: str | None = None user_name: str | None = None network_type: NetworkType | None = None @@ -157,12 +157,12 @@ class PreparedSessionData: session_id: SessionId creation_id: str - access_key: AccessKey + main_access_key: AccessKey session_type: SessionTypes name: str cluster_mode: ClusterMode kernels: list[KernelStartData] - user_uuid: UUID + owner_id: UUID user_email: str user_name: str network_type: NetworkType | None = None diff --git a/src/ai/backend/manager/sokovan/data/workload.py b/src/ai/backend/manager/sokovan/data/workload.py index c953818621e..ab5d5cafcec 100644 --- a/src/ai/backend/manager/sokovan/data/workload.py +++ b/src/ai/backend/manager/sokovan/data/workload.py @@ -77,12 +77,12 @@ class SessionWorkload: # Session identifier session_id: SessionId - # User identification for fairness calculation - access_key: AccessKey + # Owner's resolved main_access_key; required for keypair-scoped concurrency tracking and resource policy lookups. + main_access_key: AccessKey # Resource requirements requested_slots: ResourceSlot - # User UUID for user resource limit checks - user_uuid: UUID + # Owner (user) UUID for user resource limit checks + owner_id: UUID # Group ID for group resource limit checks group_id: UUID # Domain name for domain resource limit checks diff --git a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py index 3ce2c514987..3625eb5a81e 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py +++ b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py @@ -370,15 +370,6 @@ async def _handle_status_transitions( batch_updaters, BulkCreator(specs=all_history_specs) ) - # Record running_at in Valkey for routes that just transitioned to RUNNING - if ( - transitions.success is not None - and transitions.success.status == RouteStatus.RUNNING - and result.successes - ): - for route in result.successes: - await self._valkey_schedule.mark_route_running_at(str(route.route_id)) - async def process_if_needed(self, lifecycle_type: RouteLifecycleType) -> None: """ Process route lifecycle operation if needed (based on internal state). diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index 8e9d5058f3c..4285a9b308c 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -193,11 +193,6 @@ async def check_running_routes(self, routes: Sequence[RouteData]) -> RouteExecut with RouteRecorderContext.shared_step("fetch_kernel_connection_info"): await self._populate_replica_info(routes_missing_replica) - # Phase 4: Ensure RouteHealthRecords exist in Valkey for routes with replica info - routes_with_replica = [r for r in successes if r.replica_host and r.replica_port] - if routes_with_replica: - await self._ensure_health_records(routes_with_replica) - return RouteExecutionResult( successes=successes, errors=errors, @@ -225,29 +220,6 @@ async def _populate_replica_info(self, routes: Sequence[RouteData]) -> None: if populated_routes: await self._initialize_health_records(populated_routes, updates) - async def _ensure_health_records(self, routes: Sequence[RouteData]) -> None: - """Ensure RouteHealthRecords exist in Valkey for routes that already have replica info. - - Routes may already have replica_host/port in DB (set by a previous cycle or legacy code) - but lack a RouteHealthRecord in Valkey. This method checks and initializes missing records. - """ - route_id_strs = [str(r.route_id) for r in routes] - existing = await self._valkey_schedule.get_route_health_records_batch(route_id_strs) - missing = [r for r in routes if existing.get(str(r.route_id)) is None] - if not missing: - return - log.warning( - "RouteHealthRecord missing in Valkey for {} routes, re-initializing: {}", - len(missing), - [str(r.route_id)[:8] for r in missing], - ) - replica_info = { - r.route_id: (r.replica_host, r.replica_port) - for r in missing - if r.replica_host and r.replica_port - } - await self._initialize_health_records(missing, replica_info) - async def _initialize_health_records( self, routes: Sequence[RouteData], @@ -258,14 +230,6 @@ async def _initialize_health_records( health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( revision_ids ) - redis_time = await self._valkey_schedule.get_redis_time() - - # Read existing running_at values that were set when routes transitioned to RUNNING - # These may be in partial hashes (only running_at field), so read raw field directly - running_at_map = await self._valkey_schedule.get_route_running_at_batch([ - str(r.route_id) for r in routes - ]) - records: list[RouteHealthRecord] = [] for route in routes: host, port = replica_info[route.route_id] @@ -274,21 +238,16 @@ async def _initialize_health_records( health_path = health_config.path if health_config else "/" initial_delay = health_config.initial_delay if health_config else 60.0 created_at = int(route.created_at.timestamp()) - - # Use running_at from Valkey (set at RUNNING transition), fallback to redis_time - route_id_str = str(route.route_id) - running_at = running_at_map.get(route_id_str) or redis_time - initial_delay_until = running_at + int(initial_delay) + initial_delay_until = created_at + int(initial_delay) records.append( RouteHealthRecord( - route_id=route_id_str, + route_id=str(route.route_id), created_at=created_at, initial_delay_until=initial_delay_until, health_path=health_path, inference_port=port, replica_host=host, - running_at=running_at, ) ) diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py index 6a8af7221d0..3f0049e1675 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py @@ -73,11 +73,6 @@ async def observe(self, routes: Sequence[RouteData]) -> RouteObservationResult: targets.append((route_id_str, record)) if not targets: - if checkable: - log.warning( - "Health observer: {} checkable routes but 0 have records in Valkey", - len(checkable), - ) return RouteObservationResult(observed_count=0) # Perform HTTP health checks in parallel diff --git a/src/ai/backend/manager/sokovan/scheduler/coordinator.py b/src/ai/backend/manager/sokovan/scheduler/coordinator.py index abdf61d4ac0..a35a1da9eec 100644 --- a/src/ai/backend/manager/sokovan/scheduler/coordinator.py +++ b/src/ai/backend/manager/sokovan/scheduler/coordinator.py @@ -25,7 +25,7 @@ ) from ai.backend.common.events.types import AbstractBroadcastEvent from ai.backend.common.leader.tasks import EventTaskSpec -from ai.backend.common.types import AccessKey, AgentId, SessionId +from ai.backend.common.types import AgentId, SessionId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.data.kernel.types import KernelStatus @@ -831,6 +831,8 @@ async def _process_promotion_scaling_group( "check_kernel_status", success_detail=f"All kernels ready for {spec.success_status.value}", ): + # BA-5609: resolve main_access_key for cache invalidation consumer. + access_key_by_id = await self._repository.resolve_main_access_keys(session_ids) result = SessionExecutionResult() for session_info in session_infos: result.successes.append( @@ -839,7 +841,7 @@ async def _process_promotion_scaling_group( from_status=session_info.lifecycle.status, reason=spec.reason, creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py b/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py index 746ec15b527..67e17f5b158 100644 --- a/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py +++ b/src/ai/backend/manager/sokovan/scheduler/fair_share/aggregator.py @@ -484,7 +484,7 @@ def _generate_slice_specs( spec = KernelUsageRecordCreatorSpec( kernel_id=UUID(str(kernel.id)), session_id=UUID(kernel.session.session_id), - user_uuid=kernel.user_permission.user_uuid, + user_uuid=kernel.user_permission.owner_id, project_id=kernel.user_permission.group_id, domain_name=kernel.user_permission.domain_name, resource_group=scaling_group, diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py index 3f30adcb754..8db44812541 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/check_precondition.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -116,6 +115,11 @@ async def execute( sessions_for_pull_data.image_configs, ) + # BA-5609: source resolved main_access_key from SessionDataForPull. + access_key_by_id = { + s.session_id: s.main_access_key for s in sessions_for_pull_data.sessions + } + # Mark all sessions as success for status transition for session in sessions: session_info = session.session_info @@ -125,7 +129,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="passed-preconditions", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py index 6169cda4d8c..55d68d15e97 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/deprioritize_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from ai.backend.common.defs.session import SESSION_PRIORITY_MIN -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -112,6 +111,9 @@ async def execute( scaling_group, ) + # BA-5609: resolve main_access_key for cache invalidation consumer. + access_key_by_id = await self._repository.resolve_main_access_keys(session_ids) + # Mark all sessions as success for status transition to PENDING for session in sessions: session_info = session.session_info @@ -121,7 +123,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="deprioritized-for-rescheduling", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py index 012fd7ff1a6..ab57109618a 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/schedule_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -128,7 +127,10 @@ async def execute( from_status=session.session_info.lifecycle.status, reason="no-scheduling-data", creation_id=session.session_info.identity.creation_id, - access_key=AccessKey(session.session_info.metadata.access_key), + # BA-5609: skipped sessions are only recorded to + # scheduling history; access_key is not used by that + # consumer, so leaving it None is safe. + access_key=None, ) ) return result @@ -157,7 +159,7 @@ async def execute( from_status=from_status, reason=event_data.reason, creation_id=event_data.creation_id, - access_key=event_data.access_key, + access_key=event_data.main_access_key, ) ) @@ -171,7 +173,10 @@ async def execute( from_status=session.session_info.lifecycle.status, reason="not-scheduled-this-cycle", creation_id=session.session_info.identity.creation_id, - access_key=AccessKey(session.session_info.metadata.access_key), + # BA-5609: skipped sessions are only recorded to + # scheduling history; access_key is not used by that + # consumer, so leaving it None is safe. + access_key=None, ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py index d003f790d1f..f077fa6dba5 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/lifecycle/start_sessions.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from ai.backend.common.types import AccessKey from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import SessionStatus, StatusTransitions, TransitionStatus @@ -123,6 +122,9 @@ async def execute( sessions_data.image_configs, ) + # BA-5609: source resolved main_access_key from SessionDataForStart. + access_key_by_id = {s.session_id: s.main_access_key for s in sessions_data.sessions} + # Mark all sessions as success for status transition for session in sessions: session_info = session.session_info @@ -132,7 +134,7 @@ async def execute( from_status=session_info.lifecycle.status, reason="triggered-by-scheduler", creation_id=session_info.identity.creation_id, - access_key=AccessKey(session_info.metadata.access_key), + access_key=access_key_by_id.get(session_info.identity.id), ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py b/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py index 3a31cf55ac8..39d1cd266ed 100644 --- a/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py +++ b/src/ai/backend/manager/sokovan/scheduler/handlers/maintenance/sweep_sessions.py @@ -118,7 +118,7 @@ async def execute( from_status=session_data.session_info.lifecycle.status, reason="PENDING_TIMEOUT_EXCEEDED", creation_id=timed_out.creation_id, - access_key=timed_out.access_key, + access_key=timed_out.main_access_key, ) ) diff --git a/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py b/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py index a7f88de58a7..4c7b8538f8b 100644 --- a/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py +++ b/src/ai/backend/manager/sokovan/scheduler/launcher/launcher.py @@ -229,7 +229,7 @@ async def _start_single_session( session.session_id, session.session_type, session.name, - session.access_key, + session.main_access_key, session.cluster_mode, ) log.debug(log_fmt + "try-starting", *log_args) @@ -262,7 +262,7 @@ async def _start_single_session( } environ: dict[str, str] = { **session.environ, - "BACKENDAI_USER_UUID": str(session.user_uuid), + "BACKENDAI_USER_UUID": str(session.owner_id), "BACKENDAI_USER_EMAIL": session.user_email, "BACKENDAI_USER_NAME": session.user_name, "BACKENDAI_SESSION_ID": str(session.session_id), @@ -273,7 +273,7 @@ async def _start_single_session( k.cluster_hostname or f"{k.cluster_role}{k.cluster_idx}" for k in session.kernels ), - "BACKENDAI_ACCESS_KEY": session.access_key, + "BACKENDAI_ACCESS_KEY": session.main_access_key, # BACKENDAI_SERVICE_PORTS are set as per-kernel env-vars. "BACKENDAI_PREOPEN_PORTS": ( ",".join(str(port) for port in session.kernels[0].preopen_ports) @@ -335,7 +335,7 @@ async def create_kernels_on_agent( "image": kernel_image_config, "kernel_id": kernel_id_str, "session_id": str(session.session_id), - "owner_user_id": str(session.user_uuid), + "owner_user_id": str(session.owner_id), "owner_project_id": None, # TODO: Implement project-owned sessions "network_id": str(session.session_id), "session_type": session.session_type, diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py index f689f0d7cab..9a7d11fc32a 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/provisioner.py @@ -374,7 +374,7 @@ def _update_system_snapshot( # 1. Update resource occupancy - add the session's allocated slots # Update keypair occupancy - current_keypair = snapshot.resource_occupancy.by_keypair.get(workload.access_key) + current_keypair = snapshot.resource_occupancy.by_keypair.get(workload.main_access_key) if current_keypair is None: current_keypair = KeypairOccupancy( occupied_slots=[], session_count=0, sftp_session_count=0 @@ -389,11 +389,11 @@ def _update_system_snapshot( else: current_keypair.session_count += 1 - snapshot.resource_occupancy.by_keypair[workload.access_key] = current_keypair + snapshot.resource_occupancy.by_keypair[workload.main_access_key] = current_keypair # Update user occupancy - current_user = snapshot.resource_occupancy.by_user.get(workload.user_uuid, []) - snapshot.resource_occupancy.by_user[workload.user_uuid] = add_quantities( + current_user = snapshot.resource_occupancy.by_user.get(workload.owner_id, []) + snapshot.resource_occupancy.by_user[workload.owner_id] = add_quantities( current_user, total_quantities ) @@ -412,12 +412,20 @@ def _update_system_snapshot( # 2. Update concurrency counts if workload.is_private: # Increment SFTP session count - current_sftp = snapshot.concurrency.sftp_sessions_by_keypair.get(workload.access_key, 0) - snapshot.concurrency.sftp_sessions_by_keypair[workload.access_key] = current_sftp + 1 + current_sftp = snapshot.concurrency.sftp_sessions_by_keypair.get( + workload.main_access_key, 0 + ) + snapshot.concurrency.sftp_sessions_by_keypair[workload.main_access_key] = ( + current_sftp + 1 + ) else: # Increment regular session count - current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.access_key, 0) - snapshot.concurrency.sessions_by_keypair[workload.access_key] = current_sessions + 1 + current_sessions = snapshot.concurrency.sessions_by_keypair.get( + workload.main_access_key, 0 + ) + snapshot.concurrency.sessions_by_keypair[workload.main_access_key] = ( + current_sessions + 1 + ) async def _allocate_workload( self, diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py index 22243c2ad2a..b4ba256d656 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/drf.py @@ -62,7 +62,7 @@ async def sequence( # Sort workloads by dominant share (ascending order - lower share gets higher priority) # For users with the same dominant share, maintain original order - return sorted(workloads, key=lambda w: user_dominant_shares[w.access_key]) + return sorted(workloads, key=lambda w: user_dominant_shares[w.main_access_key]) def _calculate_dominant_share( self, resource_slots: ResourceSlot, total_capacity: ResourceSlot diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py index 1b9d462cf88..462deee1bd3 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/sequencers/fair_share.py @@ -80,7 +80,7 @@ async def sequence( # If a user doesn't have recorded factors, use default (lowest priority) return sorted( workloads, - key=lambda w: self._get_sort_key(w.user_uuid, user_factors), + key=lambda w: self._get_sort_key(w.owner_id, user_factors), ) async def _load_factors( @@ -92,7 +92,7 @@ async def _load_factors( # Group user_ids by project_id project_users: dict[UUID, set[UUID]] = defaultdict(set) for w in workloads: - project_users[w.group_id].add(w.user_uuid) + project_users[w.group_id].add(w.owner_id) # Build ProjectUserIds list project_user_ids = [ diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py index 18f926f4579..fe702ddbbce 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/concurrency.py @@ -22,15 +22,15 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return # Get current session count - current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.access_key, 0) + current_sessions = snapshot.concurrency.sessions_by_keypair.get(workload.main_access_key, 0) current_sftp_sessions = snapshot.concurrency.sftp_sessions_by_keypair.get( - workload.access_key, 0 + workload.main_access_key, 0 ) # Check the appropriate limit based on session type diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py index 88e8c4dc723..20174bf517b 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/keypair_resource_limit.py @@ -26,13 +26,13 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return # Get current keypair occupancy (occupied_slots is list[SlotQuantity]) - key_occupancy = snapshot.resource_occupancy.by_keypair.get(workload.access_key) + key_occupancy = snapshot.resource_occupancy.by_keypair.get(workload.main_access_key) if key_occupancy: key_occupied = ResourceSlot({ sq.slot_name: sq.quantity for sq in key_occupancy.occupied_slots diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py index a4ae0a3c51e..c14e64a861c 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_count_limit.py @@ -22,7 +22,7 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return @@ -34,7 +34,7 @@ def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: return # Get current pending sessions for this keypair - pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.access_key, []) + pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.main_access_key, []) current_pending_count = len(pending_sessions) # Check if creating this session would exceed the limit diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py index 758b5c577fc..0ea514479b3 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/pending_session_resource_limit.py @@ -23,7 +23,7 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the keypair's resource policy - policy = snapshot.resource_policy.keypair_policies.get(workload.access_key) + policy = snapshot.resource_policy.keypair_policies.get(workload.main_access_key) if not policy: # If no policy is defined, we can't validate - let it pass return @@ -35,7 +35,7 @@ def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: return # Calculate current pending session resource usage - pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.access_key, []) + pending_sessions = snapshot.pending_sessions.by_keypair.get(workload.main_access_key, []) current_pending_slots = ResourceSlot() for session in pending_sessions: current_pending_slots += session.requested_slots diff --git a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py index 69d8c6cb88f..fe366c99237 100644 --- a/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py +++ b/src/ai/backend/manager/sokovan/scheduler/provisioner/validators/user_resource_limit.py @@ -23,13 +23,13 @@ def success_message(self) -> str: def validate(self, snapshot: SystemSnapshot, workload: SessionWorkload) -> None: # Get the user's resource policy - policy = snapshot.resource_policy.user_policies.get(workload.user_uuid) + policy = snapshot.resource_policy.user_policies.get(workload.owner_id) if not policy: # If no user-specific policy, skip validation (no limits apply) return # Get current user occupancy (list[SlotQuantity]) and convert to ResourceSlot - user_occupied_quantities = snapshot.resource_occupancy.by_user.get(workload.user_uuid, []) + user_occupied_quantities = snapshot.resource_occupancy.by_user.get(workload.owner_id, []) user_occupied = ResourceSlot({sq.slot_name: sq.quantity for sq in user_occupied_quantities}) # Check if adding this workload would exceed the limit diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py index 61c815d3d9e..941382e08e5 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/preparers/preparer.py @@ -125,8 +125,8 @@ async def prepare( id=session_id, creation_id=spec.session_creation_id, name=spec.session_name, - access_key=spec.access_key, - user_uuid=spec.user_scope.user_uuid, + main_access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, group_id=spec.user_scope.group_id, domain_name=spec.user_scope.domain_name, scaling_group_name=validated_scaling_group.name, @@ -254,8 +254,8 @@ async def _prepare_kernels( scaling_group=validated_scaling_group.name, domain_name=spec.user_scope.domain_name, group_id=spec.user_scope.group_id, - user_uuid=spec.user_scope.user_uuid, - access_key=spec.access_key, + owner_id=spec.user_scope.user_uuid, + main_access_key=spec.access_key, image=image_info.canonical if image_info else self.DEFAULT_IMAGE_NAME, architecture=image_info.architecture if image_info else self.DEFAULT_ARCHITECTURE, registry=image_info.registry if image_info else self.DEFAULT_REGISTRY, diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py index c8fd82d2d2d..927e92f599c 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py @@ -257,7 +257,7 @@ async def enqueue_session( hook_result = await self._hook_plugin_ctx.dispatch( "PRE_ENQUEUE_SESSION", - (session_data.id, session_data.name, session_data.access_key), + (session_data.id, session_data.name, session_data.main_access_key), return_when=ALL_COMPLETED, ) if hook_result.status != PASSED: @@ -295,7 +295,7 @@ async def enqueue_session( ) await self._hook_plugin_ctx.notify( "POST_ENQUEUE_SESSION", - (session_id, session_data.name, session_data.access_key), + (session_id, session_data.name, session_data.main_access_key), ) return session_id diff --git a/tests/component/session/conftest.py b/tests/component/session/conftest.py index d8c2730b88f..b525a9ca84c 100644 --- a/tests/component/session/conftest.py +++ b/tests/component/session/conftest.py @@ -34,7 +34,6 @@ from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.session.repository import SessionRepository from ai.backend.manager.services.agent.processors import AgentProcessors -from ai.backend.manager.services.auth.processors import AuthProcessors from ai.backend.manager.services.session.processors import SessionProcessors from ai.backend.manager.services.session.service import SessionService, SessionServiceArgs from ai.backend.manager.services.vfolder.processors.vfolder import VFolderProcessors @@ -138,7 +137,6 @@ def vfolder_processors_mock() -> VFolderProcessors: def server_module_registries( route_deps: RouteDeps, config_provider: ManagerConfigProvider, - auth_processors: AuthProcessors, session_processors: SessionProcessors, agent_processors_mock: AgentProcessors, vfolder_processors_mock: VFolderProcessors, @@ -147,7 +145,6 @@ def server_module_registries( return [ register_session_routes( SessionHandler( - auth=auth_processors, session=session_processors, agent=agent_processors_mock, vfolder=vfolder_processors_mock, diff --git a/tests/component/session/test_session_create_delegation.py b/tests/component/session/test_session_create_delegation.py index 3b3df508eba..c9a3dc083dc 100644 --- a/tests/component/session/test_session_create_delegation.py +++ b/tests/component/session/test_session_create_delegation.py @@ -1,4 +1,4 @@ -"""BA-5608: ``POST /session`` with ``owner_access_key`` must build ``UserScope`` +"""BA-5608: ``POST /session`` with ``owner_id`` must build ``UserScope`` from the owner, not the requester admin.""" from __future__ import annotations @@ -24,7 +24,7 @@ class TestDelegatedSessionCreation: - """Tests for legacy ``POST /session`` with ``owner_access_key`` (BA-5608).""" + """Tests for legacy ``POST /session`` with the new ``owner_id`` (BA-5608).""" @pytest.fixture() def stub_image_row(self) -> MagicMock: @@ -84,7 +84,7 @@ async def group_name_for_fixture( assert name is not None, "group_fixture row missing name" return str(name) - async def test_admin_create_with_owner_access_key_routes_owner_into_user_scope( + async def test_admin_create_with_owner_id_routes_owner_into_user_scope( self, admin_registry: BackendAIClientRegistry, domain_fixture: str, @@ -96,7 +96,7 @@ async def test_admin_create_with_owner_access_key_routes_owner_into_user_scope( ) -> None: """ POST /session signed by the admin keypair, with - ``owner_access_key=``, must reach + ``owner_id=``, must reach ``AgentRegistry.create_session`` with a ``UserScope`` carrying the regular user's uuid — not the admin's. """ @@ -107,7 +107,7 @@ async def test_admin_create_with_owner_access_key_routes_owner_into_user_scope( session_type=SessionTypes.INTERACTIVE, domain=domain_fixture, group=group_name_for_fixture, - owner_access_key=regular_user_fixture.keypair.access_key, + owner_id=regular_user_fixture.user_uuid, # ``reuse=False`` skips the existing-session lookup branch in # ``AgentRegistry.create_session`` so we exercise the new # session creation path. diff --git a/tests/component/session/test_session_query.py b/tests/component/session/test_session_query.py index c69c28351e0..c7c24b60270 100644 --- a/tests/component/session/test_session_query.py +++ b/tests/component/session/test_session_query.py @@ -51,7 +51,6 @@ from ai.backend.manager.models.kernel import kernels from ai.backend.manager.models.session import SessionRow from ai.backend.manager.services.agent.processors import AgentProcessors -from ai.backend.manager.services.auth.processors import AuthProcessors from ai.backend.manager.services.session.processors import SessionProcessors from ai.backend.manager.services.vfolder.processors.vfolder import VFolderProcessors @@ -62,7 +61,6 @@ def server_module_registries( route_deps: RouteDeps, config_provider: ManagerConfigProvider, - auth_processors: AuthProcessors, session_processors: SessionProcessors, agent_processors_mock: AgentProcessors, vfolder_processors_mock: VFolderProcessors, @@ -75,7 +73,6 @@ def server_module_registries( return [ register_session_routes( SessionHandler( - auth=auth_processors, session=session_processors, agent=agent_processors_mock, vfolder=vfolder_processors_mock, diff --git a/tests/unit/common/dto/manager/session/test_request.py b/tests/unit/common/dto/manager/session/test_request.py index 8e64b8a4666..c4ecc18b85f 100644 --- a/tests/unit/common/dto/manager/session/test_request.py +++ b/tests/unit/common/dto/manager/session/test_request.py @@ -25,12 +25,10 @@ GetAbusingReportRequest, GetCommitStatusRequest, GetContainerLogsRequest, - GetStatusHistoryRequest, GetTaskLogsRequest, ListFilesRequest, MatchSessionsRequest, RenameSessionRequest, - RestartSessionRequest, ShutdownServiceRequest, StartServiceRequest, SyncAgentRegistryRequest, @@ -341,13 +339,6 @@ def test_defaults(self) -> None: req = DestroySessionRequest() assert req.forced is False assert req.recursive is False - assert req.owner_access_key is None - - -class TestRestartSessionRequest: - def test_defaults(self) -> None: - req = RestartSessionRequest() - assert req.owner_access_key is None class TestMatchSessionsRequest: @@ -413,16 +404,13 @@ def test_custom_path(self) -> None: class TestGetContainerLogsRequest: def test_defaults(self) -> None: req = GetContainerLogsRequest() - assert req.owner_access_key is None assert req.kernel_id is None def test_aliases(self) -> None: kid = uuid4() req = GetContainerLogsRequest.model_validate({ - "ownerAccessKey": "AKIAEXAMPLE", "kernelId": str(kid), }) - assert req.owner_access_key == "AKIAEXAMPLE" assert req.kernel_id == kid @@ -441,9 +429,3 @@ def test_task_id_alias(self) -> None: kid = uuid4() req = GetTaskLogsRequest.model_validate({"taskId": str(kid)}) assert req.kernel_id == kid - - -class TestGetStatusHistoryRequest: - def test_defaults(self) -> None: - req = GetStatusHistoryRequest() - assert req.owner_access_key is None diff --git a/tests/unit/manager/api/adapters/test_session_adapter.py b/tests/unit/manager/api/adapters/test_session_adapter.py index 65001843b10..bcc38bc33f4 100644 --- a/tests/unit/manager/api/adapters/test_session_adapter.py +++ b/tests/unit/manager/api/adapters/test_session_adapter.py @@ -39,7 +39,7 @@ def _create_session_data( cluster_size=1, domain_name="default", group_id=uuid4(), - user_uuid=uuid4(), + owner_id=uuid4(), occupying_slots={}, requested_slots={"cpu": Decimal("1"), "mem": Decimal("1073741824")}, use_host_network=False, @@ -49,7 +49,6 @@ def _create_session_data( num_queries=0, creation_id="test-creation-id", name=name, - access_key=None, scaling_group_name="default", target_sgroup_names=None, agent_ids=None, @@ -123,7 +122,7 @@ def test_domain_and_user_fields(self) -> None: data = _create_session_data() node = SessionAdapter._session_data_to_node(data) assert node.domain_name == "default" - assert node.user_id == data.user_uuid + assert node.user_id == data.owner_id assert node.project_id == data.group_id def test_network_host_network_false(self) -> None: diff --git a/tests/unit/manager/api/compute_sessions/test_handler.py b/tests/unit/manager/api/compute_sessions/test_handler.py index c66a3db2431..30a8c14a176 100644 --- a/tests/unit/manager/api/compute_sessions/test_handler.py +++ b/tests/unit/manager/api/compute_sessions/test_handler.py @@ -70,7 +70,7 @@ def create_session_data( cluster_size=1, domain_name="default", group_id=uuid4(), - user_uuid=uuid4(), + owner_id=uuid4(), occupying_slots=ResourceSlot({"cpu": Decimal("2.0"), "mem": Decimal("4294967296")}), requested_slots=ResourceSlot({"cpu": Decimal("4.0"), "mem": Decimal("8589934592")}), use_host_network=False, @@ -80,7 +80,6 @@ def create_session_data( num_queries=0, creation_id="test-creation-id", name=name, - access_key=None, agent_ids=["agent-001"], images=images or ["cr.backend.ai/stable/python:3.11"], tag=None, @@ -123,8 +122,8 @@ def create_kernel_info( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=uuid4(), - access_key="TESTKEY", + owner_id=uuid4(), + main_access_key="TESTKEY", domain_name="default", group_id=uuid4(), uid=None, diff --git a/tests/unit/manager/dependencies/agents/test_registry.py b/tests/unit/manager/dependencies/agents/test_registry.py index 7ee1f2de877..cf4223e0022 100644 --- a/tests/unit/manager/dependencies/agents/test_registry.py +++ b/tests/unit/manager/dependencies/agents/test_registry.py @@ -36,6 +36,7 @@ async def test_provide_agent_registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), debug=False, manager_public_key=MagicMock(), manager_secret_key=MagicMock(), diff --git a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py index f8e96ec47e7..687bd85c648 100644 --- a/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py +++ b/tests/unit/manager/repositories/scheduler/test_scheduling_history_recording.py @@ -316,8 +316,8 @@ async def test_enqueue_session_creates_scheduling_history( id=session_id, creation_id=creation_id, name=f"test-session-{uuid.uuid4().hex[:8]}", - access_key=test_access_key, - user_uuid=test_user_uuid, + main_access_key=test_access_key, + owner_id=test_user_uuid, group_id=test_group_id, domain_name=test_domain_name, scaling_group_name=test_scaling_group_name, @@ -354,8 +354,8 @@ async def test_enqueue_session_creates_scheduling_history( scaling_group=test_scaling_group_name, domain_name=test_domain_name, group_id=test_group_id, - user_uuid=test_user_uuid, - access_key=test_access_key, + owner_id=test_user_uuid, + main_access_key=test_access_key, image="python:3.8", architecture="x86_64", registry="docker.io", diff --git a/tests/unit/manager/repositories/session/test_session_repository.py b/tests/unit/manager/repositories/session/test_session_repository.py index 7681676ed30..d9842123213 100644 --- a/tests/unit/manager/repositories/session/test_session_repository.py +++ b/tests/unit/manager/repositories/session/test_session_repository.py @@ -222,7 +222,6 @@ async def session_with_kernel(self, db_with_cleanup: ExtendedAsyncSAEngine) -> S domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, tag=None, status=SessionStatus.RUNNING, status_info=None, @@ -254,7 +253,6 @@ async def session_with_kernel(self, db_with_cleanup: ExtendedAsyncSAEngine) -> S domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, cluster_mode=ClusterMode.SINGLE_NODE.value, cluster_size=1, cluster_role="main", @@ -375,8 +373,7 @@ async def test_search_sessions( assert session_data.name == "test-session" assert session_data.domain_name == session_with_kernel.domain_name assert session_data.group_id == session_with_kernel.group_id - assert session_data.user_uuid == session_with_kernel.user_id - assert session_data.access_key == session_with_kernel.access_key + assert session_data.owner_id == session_with_kernel.user_id async def test_search_sessions_empty_result( self, @@ -548,7 +545,6 @@ async def session_with_allocations( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, tag=None, status=SessionStatus.RUNNING, status_info=None, @@ -578,7 +574,6 @@ async def session_with_allocations( domain_name=domain_name, group_id=group_id, user_uuid=user_id, - access_key=access_key, cluster_mode=ClusterMode.SINGLE_NODE.value, cluster_size=1, cluster_role="main", diff --git a/tests/unit/manager/services/session/test_session_lifecycle_service.py b/tests/unit/manager/services/session/test_session_lifecycle_service.py index 5919d5171a2..58a4030df25 100644 --- a/tests/unit/manager/services/session/test_session_lifecycle_service.py +++ b/tests/unit/manager/services/session/test_session_lifecycle_service.py @@ -168,7 +168,7 @@ async def session_service( session_repository=mock_session_repository, scheduling_controller=mock_scheduling_controller, appproxy_client_pool=mock_appproxy_client_pool, - user_repository=MagicMock(), + user_repository=AsyncMock(), ) return SessionService(args) @@ -210,6 +210,7 @@ def _make_session_data( ) -> SessionData: return SessionData( id=session_id, + owner_id=user_id, creation_id="test-creation-id", name=name, session_type=session_type, @@ -220,8 +221,6 @@ def _make_session_data( agent_ids=["i-ubuntu"], domain_name="default", group_id=group_id, - user_uuid=user_id, - access_key=access_key, images=["cr.backend.ai/stable/python:latest"], tag=None, occupying_slots=ResourceSlot({"cpu": 1, "mem": 1024}), @@ -312,7 +311,6 @@ async def test_commit_success( action = CommitSessionAction( session_name="test-session", - owner_access_key=sample_access_key, filename=None, ) @@ -327,6 +325,7 @@ async def test_commit_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("Session not found") @@ -334,7 +333,6 @@ async def test_commit_session_not_found( action = CommitSessionAction( session_name="nonexistent", - owner_access_key=sample_access_key, filename=None, ) @@ -363,7 +361,6 @@ async def test_commit_custom_filename( action = CommitSessionAction( session_name="test-session", - owner_access_key=sample_access_key, filename="my-snapshot.tar.gz", ) @@ -398,7 +395,6 @@ async def test_success( action = GetCommitStatusAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_commit_status(action) @@ -410,6 +406,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("Session not found") @@ -417,7 +414,6 @@ async def test_session_not_found( action = GetCommitStatusAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -460,7 +456,6 @@ async def test_v1_query_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(1,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode=None, options=None, @@ -505,7 +500,6 @@ async def test_v2_batch_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="batch", options=None, @@ -544,7 +538,6 @@ async def test_v2_complete_mode( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="complete", options={}, @@ -575,7 +568,6 @@ async def test_v2_continue_without_run_id_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="continue", options=None, @@ -606,7 +598,6 @@ async def test_v2_invalid_mode_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="invalid_mode", options=None, @@ -637,7 +628,6 @@ async def test_v2_null_mode_raises( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode=None, options=None, @@ -678,7 +668,6 @@ async def test_null_code_defaults_to_empty_string( action = ExecuteSessionAction( session_name="test-session", api_version=(2,), - owner_access_key=sample_access_key, params=ExecuteSessionActionParams( mode="query", options=None, @@ -712,11 +701,11 @@ def delegated_session_action( self, sample_access_key: AccessKey, sample_user_id: UUID, - delegated_owner_access_key: AccessKey, + delegated_owner_id: UUID, ) -> CreateFromParamsAction: """ CreateFromParamsAction representing an admin (sample_user_id) creating - a session on behalf of another user via owner_access_key. + a session on behalf of another user via owner_id. """ return CreateFromParamsAction( params=CreateFromParamsActionParams( @@ -732,7 +721,7 @@ def delegated_session_action( tag="", priority=0, is_preemptible=True, - owner_access_key=delegated_owner_access_key, + owner_id=delegated_owner_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -784,7 +773,7 @@ async def test_image_resolve_failure_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -830,7 +819,7 @@ async def test_invalid_domain_group_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -888,7 +877,7 @@ async def test_create_distributed_session( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -951,7 +940,7 @@ async def test_reuse_if_exists_returns_existing( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1011,7 +1000,7 @@ async def test_quota_exceeded_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1053,6 +1042,11 @@ async def test_owner_access_key_uses_owner_user_scope( identity leaked into the session row, causing scaling group access checks and container UID/GID lookups to use the wrong user. """ + user_repo_mock = MagicMock() + user_repo_mock.get_user_by_uuid = AsyncMock( + return_value=MagicMock(main_access_key=str(delegated_owner_access_key)) + ) + session_service._user_repository = user_repo_mock new_session_id = str(uuid4()) mock_session_repository.query_userinfo = AsyncMock( return_value=SessionOwnerContext( @@ -1159,7 +1153,7 @@ async def test_create_from_template_success( tag=undefined, priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1205,7 +1199,7 @@ async def test_template_not_found_raises( tag="", priority=0, is_preemptible=True, - owner_access_key=sample_access_key, + owner_id=sample_user_id, enqueue_only=False, max_wait_seconds=0, starts_at=None, @@ -1266,7 +1260,7 @@ async def test_create_cluster_success( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1297,7 +1291,7 @@ async def test_template_not_found_raises( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1343,7 +1337,7 @@ async def test_too_many_sessions_converts_to_already_exists( domain_name="default", scaling_group_name="default", requester_access_key=sample_access_key, - owner_access_key=sample_access_key, + owner_id=sample_user_id, tag="", enqueue_only=False, keypair_resource_policy=None, @@ -1381,7 +1375,6 @@ async def test_prefix_matching_returns_sessions( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1401,7 +1394,6 @@ async def test_no_match_returns_empty( action = MatchSessionsAction( id_or_name_prefix="nonexistent", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1419,12 +1411,11 @@ async def test_owner_access_key_filtering( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) await session_service.match_sessions(action) - mock_session_repository.match_sessions.assert_called_once_with("test", sample_access_key) + mock_session_repository.match_sessions.assert_called_once_with("test", sample_user_id) # ==================== GetAbusingReport Tests ==================== @@ -1451,7 +1442,6 @@ async def test_valid_session_returns_report( action = GetAbusingReportAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_abusing_report(action) @@ -1462,6 +1452,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("not found") @@ -1469,7 +1460,6 @@ async def test_session_not_found( action = GetAbusingReportAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1503,7 +1493,6 @@ async def test_system_session_returns_ports( action = GetDirectAccessInfoAction( session_name="system-session", - owner_access_key=sample_access_key, ) result = await session_service.get_direct_access_info(action) @@ -1533,7 +1522,6 @@ async def test_interactive_session_returns_empty_dict( action = GetDirectAccessInfoAction( session_name="interactive-session", - owner_access_key=sample_access_key, ) result = await session_service.get_direct_access_info(action) @@ -1562,7 +1550,6 @@ async def test_agent_row_none_raises_kernel_not_ready( action = GetDirectAccessInfoAction( session_name="system-session", - owner_access_key=sample_access_key, ) with pytest.raises(KernelNotReady): @@ -1597,7 +1584,6 @@ async def test_session_with_dependencies( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) result = await session_service.get_dependency_graph(action) @@ -1626,7 +1612,6 @@ async def test_root_only_session( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) result = await session_service.get_dependency_graph(action) @@ -1637,6 +1622,7 @@ async def test_empty_session_id_raises_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: dep_graph: dict[str, Any] = {"session_id": "", "children": []} mock_session_repository.find_dependency_sessions = AsyncMock(return_value=dep_graph) @@ -1644,7 +1630,6 @@ async def test_empty_session_id_raises_not_found( action = GetDependencyGraphAction( root_session_name="root-session", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1871,7 +1856,6 @@ async def test_shutdown_success( action = ShutdownServiceAction( session_name="test-session", - owner_access_key=sample_access_key, service_name="jupyter", ) result = await session_service.shutdown_service(action) @@ -1884,6 +1868,7 @@ async def test_session_not_found( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.get_session_validated = AsyncMock( side_effect=SessionNotFound("not found") @@ -1891,7 +1876,6 @@ async def test_session_not_found( action = ShutdownServiceAction( session_name="nonexistent", - owner_access_key=sample_access_key, service_name="jupyter", ) diff --git a/tests/unit/manager/services/session/test_session_service.py b/tests/unit/manager/services/session/test_session_service.py index 20999c0e168..82827bf834d 100644 --- a/tests/unit/manager/services/session/test_session_service.py +++ b/tests/unit/manager/services/session/test_session_service.py @@ -245,8 +245,7 @@ def sample_session_data( agent_ids=["i-ubuntu"], domain_name="default", group_id=sample_group_id, - user_uuid=sample_user_id, - access_key=sample_access_key, + owner_id=sample_user_id, images=["cr.backend.ai/stable/python:latest"], tag=None, occupying_slots=ResourceSlot({"cpu": 1, "mem": 1024}), @@ -297,7 +296,6 @@ async def test_success( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -320,7 +318,6 @@ async def test_no_matches( action = MatchSessionsAction( id_or_name_prefix="nonexistent", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -349,8 +346,7 @@ async def test_multiple_matches( agent_ids=[], domain_name="default", group_id=sample_group_id, - user_uuid=sample_user_id, - access_key=sample_access_key, + owner_id=sample_user_id, images=["python:latest"], tag=None, occupying_slots=ResourceSlot({}), @@ -386,7 +382,6 @@ async def test_multiple_matches( action = MatchSessionsAction( id_or_name_prefix="test", - owner_access_key=sample_access_key, user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -423,7 +418,6 @@ async def test_success( action = GetStatusHistoryAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_status_history(action) @@ -444,7 +438,6 @@ async def test_session_not_found( action = GetStatusHistoryAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -465,7 +458,6 @@ async def test_empty_status_history( action = GetStatusHistoryAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_status_history(action) @@ -503,7 +495,6 @@ async def test_success_cancelled( session_name="test-session", forced=False, recursive=False, - owner_access_key=sample_access_key, ) result = await session_service.destroy_session(action) @@ -537,7 +528,6 @@ async def test_success_terminated( session_name="test-session", forced=False, recursive=False, - owner_access_key=sample_access_key, ) result = await session_service.destroy_session(action) @@ -572,7 +562,6 @@ async def test_force_terminate_directly_terminated( session_name="test-session", forced=True, recursive=False, - owner_access_key=sample_access_key, ) result = await session_service.destroy_session(action) @@ -607,7 +596,6 @@ async def test_recursive_destroy( session_name="test-session", forced=False, recursive=True, - owner_access_key=sample_access_key, ) result = await session_service.destroy_session(action) @@ -639,7 +627,6 @@ async def test_no_sessions_to_destroy( session_name="nonexistent", forced=False, recursive=False, - owner_access_key=sample_access_key, ) result = await session_service.destroy_session(action) @@ -676,7 +663,6 @@ async def test_success( action = CompleteAction( session_name="test-session", - owner_access_key=sample_access_key, code="print('Hello')", options=None, ) @@ -702,7 +688,6 @@ async def test_session_not_found( action = CompleteAction( session_name="nonexistent", - owner_access_key=sample_access_key, code="print('Hello')", options=None, ) @@ -768,7 +753,6 @@ async def test_success( action = GetSessionInfoAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_session_info(action) @@ -796,7 +780,6 @@ async def test_success_with_no_container_id( action = GetSessionInfoAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_session_info(action) @@ -816,7 +799,6 @@ async def test_session_not_found( action = GetSessionInfoAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -847,7 +829,6 @@ async def test_success( action = DownloadFilesAction( user_id=sample_user_id, session_name="test-session", - owner_access_key=sample_access_key, files=["test_file.txt"], ) result = await session_service.download_files(action) @@ -874,7 +855,6 @@ async def test_session_not_found( action = DownloadFilesAction( user_id=sample_user_id, session_name="nonexistent", - owner_access_key=sample_access_key, files=["test_file.txt"], ) @@ -897,7 +877,6 @@ async def test_too_many_files( action = DownloadFilesAction( user_id=sample_user_id, session_name="test-session", - owner_access_key=sample_access_key, files=["file1.txt", "file2.txt", "file3.txt", "file4.txt", "file5.txt", "file6.txt"], ) @@ -927,7 +906,6 @@ async def test_success( action = GetDirectAccessInfoAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.get_direct_access_info(action) @@ -950,7 +928,6 @@ async def test_session_not_found( action = GetDirectAccessInfoAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -978,7 +955,6 @@ async def test_success( action = RenameSessionAction( session_name="test-session", - owner_access_key=sample_access_key, new_name="new-session-name", ) result = await session_service.rename_session(action) @@ -1003,7 +979,6 @@ async def test_not_running_session( action = RenameSessionAction( session_name="test-session", - owner_access_key=sample_access_key, new_name="new-session-name", ) @@ -1033,7 +1008,6 @@ async def test_success( action = RestartSessionAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.restart_session(action) @@ -1057,7 +1031,6 @@ async def test_session_not_found( action = RestartSessionAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1086,7 +1059,6 @@ async def test_success( action = ShutdownServiceAction( session_name="test-session", - owner_access_key=sample_access_key, service_name="test-service", ) result = await session_service.shutdown_service(action) @@ -1110,7 +1082,6 @@ async def test_session_not_found( action = ShutdownServiceAction( session_name="nonexistent", - owner_access_key=sample_access_key, service_name="test-service", ) @@ -1157,7 +1128,6 @@ async def mock_next() -> MagicMock | None: action = UploadFilesAction( session_name="test-session", - owner_access_key=sample_access_key, reader=mock_reader, ) result = await session_service.upload_files(action) @@ -1182,7 +1152,6 @@ async def test_session_not_found( action = UploadFilesAction( session_name="nonexistent", - owner_access_key=sample_access_key, reader=mock_reader, ) @@ -1228,7 +1197,6 @@ async def test_success( action = ExecuteSessionAction( session_name="test-session", api_version=(4, 0), - owner_access_key=sample_access_key, params=params, ) result = await session_service.execute_session(action) @@ -1261,7 +1229,6 @@ async def test_session_not_found( action = ExecuteSessionAction( session_name="nonexistent", api_version=(4, 0), - owner_access_key=sample_access_key, params=params, ) @@ -1291,7 +1258,6 @@ async def test_success( action = InterruptSessionAction( session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.interrupt(action) @@ -1314,7 +1280,6 @@ async def test_session_not_found( action = InterruptSessionAction( session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1348,7 +1313,6 @@ async def test_success( user_id=sample_user_id, path="/home/work", session_name="test-session", - owner_access_key=sample_access_key, ) result = await session_service.list_files(action) @@ -1374,7 +1338,6 @@ async def test_session_not_found( user_id=sample_user_id, path="/home/work", session_name="nonexistent", - owner_access_key=sample_access_key, ) with pytest.raises(SessionNotFound): @@ -1407,7 +1370,6 @@ async def test_success( action = GetContainerLogsAction( session_name="test-session", - owner_access_key=sample_access_key, kernel_id=None, # Optional - get logs from main kernel ) result = await session_service.get_container_logs(action) @@ -1432,7 +1394,6 @@ async def test_session_not_found( action = GetContainerLogsAction( session_name="nonexistent", - owner_access_key=sample_access_key, kernel_id=None, ) @@ -1712,8 +1673,8 @@ def sample_kernel_info(self) -> KernelInfo: session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=user_id, - access_key="TESTKEY", + owner_id=user_id, + main_access_key="TESTKEY", domain_name="default", group_id=group_id, uid=1000, diff --git a/tests/unit/manager/sokovan/scheduler/conftest.py b/tests/unit/manager/sokovan/scheduler/conftest.py index 5f8e59f0c2a..5396824f2a1 100644 --- a/tests/unit/manager/sokovan/scheduler/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/conftest.py @@ -112,9 +112,9 @@ def basic_session_workload() -> SessionWorkload: """Basic SessionWorkload instance with default values.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -134,9 +134,9 @@ def batch_session_workload() -> SessionWorkload: """Batch SessionWorkload instance.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -156,9 +156,9 @@ def inference_session_workload() -> SessionWorkload: """Inference SessionWorkload instance.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -178,9 +178,9 @@ def minimal_resource_workload() -> SessionWorkload: """SessionWorkload with minimal resource requirements (1 CPU, 1 mem).""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -200,9 +200,9 @@ def small_resource_workload() -> SessionWorkload: """SessionWorkload with small resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -222,9 +222,9 @@ def medium_resource_workload() -> SessionWorkload: """SessionWorkload with medium resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -244,9 +244,9 @@ def large_resource_workload() -> SessionWorkload: """SessionWorkload with large resource requirements.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(100), "mem": Decimal(100)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -272,9 +272,9 @@ def test_domain_small_resource_workload(test_domain_name: str) -> SessionWorkloa """SessionWorkload with small resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -294,9 +294,9 @@ def test_domain_medium_resource_workload(test_domain_name: str) -> SessionWorklo """SessionWorkload with medium resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -316,9 +316,9 @@ def test_domain_large_resource_workload(test_domain_name: str) -> SessionWorkloa """SessionWorkload with large resources for domain testing.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(100), "mem": Decimal(100)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name=test_domain_name, scaling_group="default", @@ -338,9 +338,9 @@ def user1_minimal_workload() -> SessionWorkload: """Minimal workload for user1.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -366,9 +366,9 @@ def user_specific_small_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Small workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(2), "mem": Decimal(2)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -388,9 +388,9 @@ def user_specific_medium_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Medium workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(5), "mem": Decimal(5)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -410,9 +410,9 @@ def user_specific_minimal_workload(test_user_id: uuid.UUID) -> SessionWorkload: """Minimal workload for a specific user.""" return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=test_user_id, + owner_id=test_user_id, group_id=uuid4(), domain_name="default", scaling_group="default", @@ -433,9 +433,9 @@ def batch_session_past_start_time() -> SessionWorkload: past_time = datetime.now(tzutc()) - timedelta(hours=1) return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", @@ -456,9 +456,9 @@ def batch_session_future_start_time() -> SessionWorkload: future_time = datetime.now(tzutc()) + timedelta(hours=1) return SessionWorkload( session_id=SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal(1), "mem": Decimal(1)}), - user_uuid=uuid4(), + owner_id=uuid4(), group_id=uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py b/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py index 959b4f30e34..2377b8d0bc4 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/cleanup/test_force_terminated.py @@ -64,7 +64,7 @@ def handler( def _make_terminating_session_data(session_id: SessionId) -> TerminatingSessionData: return TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-access-key"), + main_access_key=AccessKey("test-access-key"), creation_id="test-creation-id", status=SessionStatus.TERMINATED, status_info="FORCE_TERMINATED", diff --git a/tests/unit/manager/sokovan/scheduler/handlers/conftest.py b/tests/unit/manager/sokovan/scheduler/handlers/conftest.py index 97cc68b2458..492967da6e2 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/conftest.py @@ -106,8 +106,7 @@ def _create_session( name=f"session-{sid}", domain_name="default", group_id=group_id, - user_uuid=user_uuid, - access_key=access_key, + owner_id=user_uuid, session_type=session_type, priority=0, created_at=now, @@ -160,8 +159,8 @@ def _create_session( session_type=session_type, ), user_permission=UserPermission( - user_uuid=user_uuid, - access_key=access_key, + owner_id=user_uuid, + main_access_key=None, domain_name="default", group_id=group_id, uid=None, @@ -261,8 +260,8 @@ def _create_kernel( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=user_uuid, - access_key="test-access-key", + owner_id=user_uuid, + main_access_key=None, domain_name="default", group_id=group_id, uid=None, @@ -545,7 +544,7 @@ def _create(sessions: list[SessionWithKernels]) -> ScheduleResult: ScheduledSessionData( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), reason="scheduled-successfully", ) for s in sessions @@ -564,7 +563,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForPullWithImages: SessionDataForPull( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), kernels=[ KernelBindingData( kernel_id=KernelId(k.id), @@ -609,7 +608,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForStartWithImages: SessionDataForStart( session_id=s.session_info.identity.id, creation_id=s.session_info.identity.creation_id, - access_key=AccessKey(s.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), session_type=s.session_info.identity.session_type, name=s.session_info.identity.name, cluster_mode=ClusterMode(s.session_info.resource.cluster_mode), @@ -624,7 +623,7 @@ def _create(sessions: list[SessionWithKernels]) -> SessionsForStartWithImages: ) for k in s.kernel_infos ], - user_uuid=s.session_info.metadata.user_uuid, + owner_id=s.session_info.metadata.owner_id, user_email="test@example.com", user_name="test-user", environ={}, @@ -660,7 +659,7 @@ def _create(sessions: list[SessionWithKernels]) -> list[TerminatingSessionData]: return [ TerminatingSessionData( session_id=s.session_info.identity.id, - access_key=AccessKey(s.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), creation_id=s.session_info.identity.creation_id, status=s.session_info.lifecycle.status, status_info="user-requested", diff --git a/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py b/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py index 787c969edde..01f0a1740c7 100644 --- a/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py +++ b/tests/unit/manager/sokovan/scheduler/handlers/test_lifecycle_handlers.py @@ -123,7 +123,7 @@ async def test_partial_scheduling_returns_skipped( ScheduledSessionData( session_id=first_session.session_info.identity.id, creation_id=first_session.session_info.identity.creation_id, - access_key=AccessKey(first_session.session_info.metadata.access_key), + main_access_key=AccessKey("test-access-key"), reason="scheduled-successfully", ) ] diff --git a/tests/unit/manager/sokovan/scheduler/launcher/conftest.py b/tests/unit/manager/sokovan/scheduler/launcher/conftest.py index 3ccf846a507..8b12a6933c8 100644 --- a/tests/unit/manager/sokovan/scheduler/launcher/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/launcher/conftest.py @@ -163,7 +163,7 @@ def _create_session_for_pull( return SessionDataForPull( session_id=session_id or SessionId(uuid4()), creation_id=str(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), kernels=kernels, ) @@ -234,10 +234,10 @@ def _create_session_for_start( return SessionDataForStart( session_id=session_id or SessionId(uuid4()), creation_id=str(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), session_type=SessionTypes.INTERACTIVE, name="test-session", - user_uuid=uuid4(), + owner_id=uuid4(), user_email="test@example.com", user_name="testuser", cluster_mode=cluster_mode, diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py index e3425dd9db1..199eaf1c085 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_drf.py @@ -126,9 +126,9 @@ async def test_single_user_workloads( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -136,9 +136,9 @@ async def test_single_user_workloads( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -162,9 +162,9 @@ async def test_multiple_users_different_dominant_shares( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # 30% dominant share + main_access_key=AccessKey("user2"), # 30% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -172,9 +172,9 @@ async def test_multiple_users_different_dominant_shares( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), # 5% dominant share (lowest) + main_access_key=AccessKey("user3"), # 5% dominant share (lowest) requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -182,9 +182,9 @@ async def test_multiple_users_different_dominant_shares( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), # 20% dominant share + main_access_key=AccessKey("user1"), # 20% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -198,9 +198,9 @@ async def test_multiple_users_different_dominant_shares( # Should be ordered by dominant share (ascending): user3 (5%), user1 (20%), user2 (30%) assert len(result) == 3 - assert result[0].access_key == AccessKey("user3") - assert result[1].access_key == AccessKey("user1") - assert result[2].access_key == AccessKey("user2") + assert result[0].main_access_key == AccessKey("user3") + assert result[1].main_access_key == AccessKey("user1") + assert result[2].main_access_key == AccessKey("user2") async def test_multiple_users_same_dominant_share( self, scaling_group: str, sequencer: DRFSequencer, empty_system_snapshot: SystemSnapshot @@ -209,9 +209,9 @@ async def test_multiple_users_same_dominant_share( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -219,9 +219,9 @@ async def test_multiple_users_same_dominant_share( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -229,9 +229,9 @@ async def test_multiple_users_same_dominant_share( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -256,9 +256,9 @@ async def test_new_user_gets_priority( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # 30% dominant share + main_access_key=AccessKey("user2"), # 30% dominant share requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -266,9 +266,9 @@ async def test_new_user_gets_priority( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("new_user"), # 0% dominant share (new user) + main_access_key=AccessKey("new_user"), # 0% dominant share (new user) requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -282,8 +282,8 @@ async def test_new_user_gets_priority( # New user with 0% dominant share should get priority assert len(result) == 2 - assert result[0].access_key == AccessKey("new_user") - assert result[1].access_key == AccessKey("user2") + assert result[0].main_access_key == AccessKey("new_user") + assert result[1].main_access_key == AccessKey("user2") async def test_dominant_share_calculation_with_zero_capacity( self, scaling_group: str, sequencer: DRFSequencer @@ -331,9 +331,9 @@ async def test_dominant_share_calculation_with_zero_capacity( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py index 0815f5e161b..2a9c48add60 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_fifo.py @@ -71,9 +71,9 @@ async def test_preserves_order( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -81,9 +81,9 @@ async def test_preserves_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -91,9 +91,9 @@ async def test_preserves_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("30"), mem=Decimal("30")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -161,9 +161,9 @@ async def test_ignores_system_snapshot( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), # User with more allocation + main_access_key=AccessKey("user2"), # User with more allocation requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -171,9 +171,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), # User with less allocation + main_access_key=AccessKey("user1"), # User with less allocation requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py index 41394684266..240915fb253 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/sequencers/test_lifo.py @@ -71,9 +71,9 @@ async def test_reverses_order( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -81,9 +81,9 @@ async def test_reverses_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("20"), mem=Decimal("20")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -91,9 +91,9 @@ async def test_reverses_order( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), + main_access_key=AccessKey("user3"), requested_slots=ResourceSlot(cpu=Decimal("30"), mem=Decimal("30")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -115,9 +115,9 @@ async def test_single_workload( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -183,9 +183,9 @@ async def test_ignores_system_snapshot( workloads = [ SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user2"), + main_access_key=AccessKey("user2"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -193,9 +193,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -203,9 +203,9 @@ async def test_ignores_system_snapshot( ), SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user3"), # New user + main_access_key=AccessKey("user3"), # New user requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py b/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py index ad2e6e75ae9..b257d9d937a 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/test_provisioner.py @@ -66,9 +66,9 @@ def _create_scheduling_data_with_strategy( # Create one pending session session = PendingSessionData( id=SessionId(uuid.uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), requested_slots=ResourceSlot({"cpu": Decimal("1"), "mem": Decimal("1024")}), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group_name="test-sg", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py index 230ece49647..46acad45823 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_concurrency.py @@ -35,9 +35,9 @@ def sftp_validator(self) -> ConcurrencyValidator: def workload(self) -> SessionWorkload: return SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -48,9 +48,9 @@ def workload(self) -> SessionWorkload: def sftp_workload(self) -> SessionWorkload: return SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py index 4884f49da5c..63c010a524a 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_dependencies.py @@ -31,9 +31,9 @@ def validator(self) -> DependenciesValidator: def test_passes_when_no_dependencies(self, validator: DependenciesValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -60,9 +60,9 @@ def test_passes_when_dependencies_satisfied(self, validator: DependenciesValidat dep_id = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -99,9 +99,9 @@ def test_fails_when_dependencies_not_satisfied(self, validator: DependenciesVali dep_id = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -143,9 +143,9 @@ def test_fails_when_multiple_dependencies_not_satisfied( dep_id2 = SessionId(uuid.uuid4()) workload = SessionWorkload( session_id=session_id, - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py index a494f3519e6..a5d9d95cac1 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_group_resource_limit.py @@ -35,9 +35,9 @@ def test_passes_when_under_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -73,9 +73,9 @@ def test_fails_when_exceeds_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -112,9 +112,9 @@ def test_passes_when_no_limit( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", @@ -153,9 +153,9 @@ def test_passes_when_no_current_occupancy( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=group_id, domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py index 1edd0b6540e..c2bb588536f 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_keypair_resource_limit.py @@ -31,9 +31,9 @@ def validator(self) -> KeypairResourceLimitValidator: def test_passes_when_under_limit(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -83,9 +83,9 @@ def test_passes_when_under_limit(self, validator: KeypairResourceLimitValidator) def test_fails_when_exceeds_limit(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -136,9 +136,9 @@ def test_fails_when_exceeds_limit(self, validator: KeypairResourceLimitValidator def test_passes_when_no_policy(self, validator: KeypairResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -181,9 +181,9 @@ def test_passes_when_no_current_occupancy( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("5"), mem=Decimal("5")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py index 6699e73697b..700c850c89f 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_pending_session_resource_limit.py @@ -31,9 +31,9 @@ def validator(self) -> PendingSessionResourceLimitValidator: def test_passes_when_under_limit(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("2"), mem=Decimal("2")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -82,9 +82,9 @@ def test_passes_when_under_limit(self, validator: PendingSessionResourceLimitVal def test_passes_when_no_limit(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("100"), mem=Decimal("100")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -131,9 +131,9 @@ def test_passes_when_no_limit(self, validator: PendingSessionResourceLimitValida def test_passes_when_no_policy(self, validator: PendingSessionResourceLimitValidator) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", @@ -173,9 +173,9 @@ def test_handles_multiple_pending_sessions( ) -> None: workload = SessionWorkload( session_id=SessionId(uuid.uuid4()), - access_key=AccessKey("user1"), + main_access_key=AccessKey("user1"), requested_slots=ResourceSlot(cpu=Decimal("1"), mem=Decimal("1")), - user_uuid=uuid.uuid4(), + owner_id=uuid.uuid4(), group_id=uuid.uuid4(), domain_name="default", scaling_group="default", diff --git a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py index 01de7ba1b13..03aff338cfd 100644 --- a/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py +++ b/tests/unit/manager/sokovan/scheduler/provisioner/validators/test_user_resource_limit.py @@ -37,7 +37,7 @@ def test_passes_when_under_limit( resource_occupancy=ResourceOccupancySnapshot( by_keypair={}, by_user={ - workload.user_uuid: [ + workload.owner_id: [ SlotQuantity("cpu", Decimal("3")), SlotQuantity("mem", Decimal("3")), ] @@ -49,7 +49,7 @@ def test_passes_when_under_limit( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) @@ -77,7 +77,7 @@ def test_fails_when_exceeds_limit( resource_occupancy=ResourceOccupancySnapshot( by_keypair={}, by_user={ - workload.user_uuid: [ + workload.owner_id: [ SlotQuantity("cpu", Decimal("8")), SlotQuantity("mem", Decimal("8")), ] @@ -89,7 +89,7 @@ def test_fails_when_exceeds_limit( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) @@ -151,7 +151,7 @@ def test_passes_when_no_current_occupancy( resource_policy=ResourcePolicySnapshot( keypair_policies={}, user_policies={ - workload.user_uuid: UserResourcePolicy( + workload.owner_id: UserResourcePolicy( name="default", total_resource_slots=ResourceSlot(cpu=Decimal("10"), mem=Decimal("10")), ) diff --git a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py index f959b16b790..1d8c11fd5f0 100644 --- a/tests/unit/manager/sokovan/scheduler/terminator/conftest.py +++ b/tests/unit/manager/sokovan/scheduler/terminator/conftest.py @@ -137,7 +137,7 @@ def _create_terminating_session_data( return TerminatingSessionData( session_id=session_id or SessionId(uuid4()), - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id=str(uuid4()), status=SessionStatus.TERMINATING, status_info=status_info, @@ -213,8 +213,8 @@ def _create_kernel_info( session_type=SessionTypes.INTERACTIVE, ), user_permission=UserPermission( - user_uuid=uuid4(), - access_key="test-access-key", + owner_id=uuid4(), + main_access_key="test-access-key", domain_name="default", group_id=uuid4(), uid=None, diff --git a/tests/unit/manager/sokovan/scheduler/test_scheduler.py b/tests/unit/manager/sokovan/scheduler/test_scheduler.py index f35e5bf9797..ddc293080b7 100644 --- a/tests/unit/manager/sokovan/scheduler/test_scheduler.py +++ b/tests/unit/manager/sokovan/scheduler/test_scheduler.py @@ -65,9 +65,9 @@ def create_session_workload( return SessionWorkload( session_id=session_id, - access_key=access_key, + main_access_key=access_key, requested_slots=requested_slots, - user_uuid=user_uuid, + owner_id=user_uuid, group_id=group_id, domain_name=domain_name, scaling_group=scaling_group, diff --git a/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py b/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py index d267c2d0a9b..21cb03ba5af 100644 --- a/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py +++ b/tests/unit/manager/sokovan/scheduler/test_terminate_sessions.py @@ -126,7 +126,7 @@ async def test_terminate_sessions_single_success( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="USER_REQUESTED", @@ -181,7 +181,7 @@ async def test_terminate_sessions_multiple_kernels( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="FORCED_TERMINATION", @@ -233,7 +233,7 @@ async def test_terminate_sessions_partial_failure( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="TEST_PARTIAL", @@ -310,7 +310,7 @@ async def test_terminate_sessions_concurrent_execution( sessions.append( TerminatingSessionData( session_id=session_id, - access_key=AccessKey(f"key-{i}"), + main_access_key=AccessKey(f"key-{i}"), creation_id=f"creation-{i}", status=SessionStatus.TERMINATING, status_info="BATCH_TERMINATION", @@ -364,7 +364,7 @@ async def test_terminate_sessions_empty_kernel_list( terminating_session = TerminatingSessionData( session_id=session_id, - access_key=AccessKey("test-key"), + main_access_key=AccessKey("test-key"), creation_id="test-creation", status=SessionStatus.TERMINATING, status_info="NO_KERNELS", diff --git a/tests/unit/manager/test_reconcile_agent_resources.py b/tests/unit/manager/test_reconcile_agent_resources.py index 14df353e04d..2fcf6e2a8db 100644 --- a/tests/unit/manager/test_reconcile_agent_resources.py +++ b/tests/unit/manager/test_reconcile_agent_resources.py @@ -104,6 +104,7 @@ async def registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), manager_public_key=PublicKey(b"GqK]ZYY#h*9jAQbGxSwkeZX3Y*%b+DiY$7ju6sh{"), manager_secret_key=SecretKey(b"37KX6]ac^&hcnSaVo=-%eVO9M]ENe8v=BOWF(Sw$"), ) @@ -498,6 +499,7 @@ async def registry( hook_plugin_ctx=MagicMock(), network_plugin_ctx=MagicMock(), scheduling_controller=MagicMock(), + user_repository=MagicMock(), manager_public_key=PublicKey(b"GqK]ZYY#h*9jAQbGxSwkeZX3Y*%b+DiY$7ju6sh{"), manager_secret_key=SecretKey(b"37KX6]ac^&hcnSaVo=-%eVO9M]ENe8v=BOWF(Sw$"), )