diff --git a/changes/9624.feature.md b/changes/9624.feature.md new file mode 100644 index 00000000000..10a261771f3 --- /dev/null +++ b/changes/9624.feature.md @@ -0,0 +1 @@ +Add RBAC validator infrastructure to Session actions following BEP-1048 patterns diff --git a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py index f2adf405c54..5ea816895fb 100644 --- a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import KernelConditions from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction @@ -26,13 +28,17 @@ async def load_kernels_by_ids( if not kernel_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_ids(kernel_ids)], ) action_result = await processor.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) kernel_map: dict[KernelId, KernelInfo] = {kernel.id: kernel for kernel in action_result.data} diff --git a/src/ai/backend/manager/api/gql/data_loader/session/loader.py b/src/ai/backend/manager/api/gql/data_loader/session/loader.py index 59e0f2a176e..8cd5708b59f 100644 --- a/src/ai/backend/manager/api/gql/data_loader/session/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/session/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.data.session.types import SessionData +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import SessionConditions from ai.backend.manager.services.session.actions.search import SearchSessionsAction @@ -17,13 +19,17 @@ async def load_sessions_by_ids( if not session_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[SessionConditions.by_ids(session_ids)], ) action_result = await processor.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) session_map: dict[SessionId, SessionData] = { diff --git a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py index 50bda150070..30a803e60aa 100644 --- a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py +++ b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor @@ -16,6 +17,7 @@ KernelV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.kernel import KernelRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import KernelConditions @@ -45,6 +47,10 @@ async def fetch_kernels( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -61,7 +67,7 @@ async def fetch_kernels( ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel_info) for kernel_info in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/gql/session/fetcher/session.py b/src/ai/backend/manager/api/gql/session/fetcher/session.py index 07b88f0e069..603405e00e2 100644 --- a/src/ai/backend/manager/api/gql/session/fetcher/session.py +++ b/src/ai/backend/manager/api/gql/session/fetcher/session.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor from ai.backend.manager.api.gql.session.types import ( @@ -15,6 +16,7 @@ SessionV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import SessionConditions, SessionOrders @@ -44,6 +46,10 @@ async def fetch_sessions( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> SessionV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -60,7 +66,7 @@ async def fetch_sessions( ) action_result = await info.context.processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) nodes = [SessionV2GQL.from_data(session_data) for session_data in action_result.data] diff --git a/src/ai/backend/manager/api/gql/session/types.py b/src/ai/backend/manager/api/gql/session/types.py index c33434660bb..26ad397e989 100644 --- a/src/ai/backend/manager/api/gql/session/types.py +++ b/src/ai/backend/manager/api/gql/session/types.py @@ -12,6 +12,7 @@ from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, UUIDFilter, encode_cursor from ai.backend.manager.api.gql.common.types import ( @@ -41,6 +42,7 @@ from ai.backend.manager.api.gql.types import GQLFilter, GQLOrderBy, StrawberryGQLContext from ai.backend.manager.api.gql.user.types.node import UserV2GQL from ai.backend.manager.data.session.types import SessionData, SessionStatus +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import ( BatchQuerier, NoPagination, @@ -468,13 +470,17 @@ async def images(self) -> ImageV2ConnectionGQL: description="Added in 26.3.0. The kernels belonging to this session." ) async def kernels(self, info: Info[StrawberryGQLContext]) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + session_id = SessionId(UUID(str(self.id))) querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_session_ids([session_id])], ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel) for kernel in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/rest/compute_sessions/handler.py b/src/ai/backend/manager/api/rest/compute_sessions/handler.py index fb1c612342c..adeae27a073 100644 --- a/src/ai/backend/manager/api/rest/compute_sessions/handler.py +++ b/src/ai/backend/manager/api/rest/compute_sessions/handler.py @@ -7,6 +7,7 @@ from typing import Final from ai.backend.common.api_handlers import APIResponse, BodyParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.compute_session import ( PaginationInfo, SearchComputeSessionsRequest, @@ -15,6 +16,7 @@ from ai.backend.common.types import SessionId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.dto.context import UserContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.services.session.actions.search import SearchSessionsAction from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction from ai.backend.manager.services.session.processors import SessionProcessors @@ -39,10 +41,14 @@ async def search_sessions( """Search compute sessions with nested container data.""" log.info("SEARCH_SESSIONS (ak:{})", ctx.access_key) + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + # Step 1: Search sessions session_querier = self._adapter.build_session_querier(body.parsed) session_result = await self._session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=session_querier) + SearchSessionsAction(querier=session_querier, user_id=user.user_id) ) # Step 2: Fetch kernels for found sessions @@ -51,7 +57,7 @@ async def search_sessions( if session_ids: kernel_querier = self._adapter.build_kernel_querier_for_sessions(session_ids) kernel_result = await self._session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=kernel_querier) + SearchKernelsAction(querier=kernel_querier, user_id=user.user_id) ) kernels_by_session = self._adapter.group_kernels_by_session(kernel_result.data) diff --git a/src/ai/backend/manager/api/rest/session/handler.py b/src/ai/backend/manager/api/rest/session/handler.py index 79b1b28ed8d..4ef40071f41 100644 --- a/src/ai/backend/manager/api/rest/session/handler.py +++ b/src/ai/backend/manager/api/rest/session/handler.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ai.backend.common.api_handlers import APIResponse, BaseResponseModel, BodyParam, QueryParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.session.request import ( CommitSessionRequest, CompleteRequest, @@ -91,6 +92,7 @@ from ai.backend.manager.errors.api import InvalidAPIParameters from ai.backend.manager.errors.auth import InsufficientPrivilege from ai.backend.manager.errors.resource import NoCurrentTaskContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.agent.actions.sync_agent_registry import ( SyncAgentRegistryAction, @@ -515,6 +517,9 @@ async def match_sessions( ) ) requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + user = current_user() + if user is None: + raise UserNotFound("User not found in context") log.info( "MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})", requester_access_key, @@ -525,6 +530,7 @@ async def match_sessions( MatchSessionsAction( id_or_name_prefix=params.id, owner_access_key=owner_access_key, + user_id=user.user_id, ) ) return APIResponse.build(HTTPStatus.OK, MatchSessionsResponse(matches=result.result)) diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 7a2570e738f..10e843b1df7 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -3,15 +3,23 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class CreateClusterAction(SessionAction): +class CreateClusterAction(SessionScopeAction): + """Create a new cluster session. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + session_name: str user_id: uuid.UUID user_role: UserRole @@ -37,6 +45,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateClusterActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index 39043296d5c..68a293cf4c9 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -6,11 +6,13 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -41,7 +43,13 @@ class CreateFromParamsActionParams: @dataclass -class CreateFromParamsAction(SessionAction): +class CreateFromParamsAction(SessionScopeAction): + """Create a new session from parameters. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + params: CreateFromParamsActionParams user_id: uuid.UUID user_role: UserRole @@ -58,6 +66,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromParamsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 9cdd47f704d..2b721d36d3a 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -6,12 +6,14 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.api.utils import Undefined +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -44,7 +46,13 @@ class CreateFromTemplateActionParams: @dataclass -class CreateFromTemplateAction(SessionAction): +class CreateFromTemplateAction(SessionScopeAction): + """Create a new session from template. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + params: CreateFromTemplateActionParams user_id: uuid.UUID user_role: UserRole @@ -61,6 +69,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromTemplateActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index 93bb8b25e2b..ca91ae8bab5 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -1,17 +1,27 @@ +import uuid from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Make this BatchAction @dataclass -class MatchSessionsAction(SessionAction): +class MatchSessionsAction(SessionScopeAction): + """Match sessions by ID or name prefix. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + id_or_name_prefix: str owner_access_key: AccessKey + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -22,6 +32,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class MatchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search.py b/src/ai/backend/manager/services/session/actions/search.py index 83b2e226b2e..8d0b177490d 100644 --- a/src/ai/backend/manager/services/session/actions/search.py +++ b/src/ai/backend/manager/services/session/actions/search.py @@ -1,18 +1,28 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchSessionsAction(SessionAction): +class SearchSessionsAction(SessionScopeAction): + """Search sessions within a scope. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + querier: BatchQuerier + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -23,6 +33,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search_kernel.py b/src/ai/backend/manager/services/session/actions/search_kernel.py index 5864e534b8d..0a3ace75f58 100644 --- a/src/ai/backend/manager/services/session/actions/search_kernel.py +++ b/src/ai/backend/manager/services/session/actions/search_kernel.py @@ -1,19 +1,28 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType +from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchKernelsAction(SessionAction): +class SearchKernelsAction(SessionScopeAction): + """Search kernels within a scope. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + querier: BatchQuerier + user_id: uuid.UUID @override @classmethod @@ -29,6 +38,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchKernelsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index bcc90407892..93c9c488e74 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -3,6 +3,7 @@ from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult @dataclass @@ -19,3 +20,22 @@ class SessionBatchAction(BaseBatchAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.SESSION + + +@dataclass +class SessionScopeAction(BaseScopeAction): + """Base class for session actions that operate within a scope. + + Used for operations like creating or searching sessions within a specific scope. + Subclasses must implement scope_type(), scope_id(), and target_element() methods. + """ + + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + +@dataclass +class SessionScopeActionResult(BaseScopeActionResult): + pass diff --git a/src/ai/backend/manager/services/session/processors.py b/src/ai/backend/manager/services/session/processors.py index 647568dd8ce..bf0c50e30cd 100644 --- a/src/ai/backend/manager/services/session/processors.py +++ b/src/ai/backend/manager/services/session/processors.py @@ -1,8 +1,9 @@ -from typing import override +from typing import cast, override from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.actions.validator.base import ActionValidator from ai.backend.manager.actions.validators import ActionValidators from ai.backend.manager.services.session.actions.check_and_transit_status import ( CheckAndTransitStatusAction, @@ -172,18 +173,17 @@ def __init__( action_monitors: list[ActionMonitor], validators: ActionValidators, ) -> None: + scope_validator = validators.rbac.scope + single_entity_validator = validators.rbac.single_entity + + # Actions without RBAC validation (internal/legacy) self.commit_session = ActionProcessor(service.commit_session, action_monitors) self.complete = ActionProcessor(service.complete, action_monitors) self.convert_session_to_image = ActionProcessor( service.convert_session_to_image, action_monitors ) - self.create_cluster = ActionProcessor(service.create_cluster, action_monitors) - self.create_from_params = ActionProcessor(service.create_from_params, action_monitors) - self.create_from_template = ActionProcessor(service.create_from_template, action_monitors) - self.destroy_session = ActionProcessor(service.destroy_session, action_monitors) self.download_file = ActionProcessor(service.download_file, action_monitors) self.download_files = ActionProcessor(service.download_files, action_monitors) - self.execute_session = ActionProcessor(service.execute_session, action_monitors) self.get_abusing_report = ActionProcessor(service.get_abusing_report, action_monitors) self.get_commit_status = ActionProcessor(service.get_commit_status, action_monitors) self.get_container_logs = ActionProcessor(service.get_container_logs, action_monitors) @@ -191,23 +191,70 @@ def __init__( self.get_direct_access_info = ActionProcessor( service.get_direct_access_info, action_monitors ) - self.get_session_info = ActionProcessor(service.get_session_info, action_monitors) self.get_status_history = ActionProcessor(service.get_status_history, action_monitors) self.interrupt = ActionProcessor(service.interrupt, action_monitors) self.list_files = ActionProcessor(service.list_files, action_monitors) - self.match_sessions = ActionProcessor(service.match_sessions, action_monitors) self.rename_session = ActionProcessor(service.rename_session, action_monitors) self.restart_session = ActionProcessor(service.restart_session, action_monitors) - self.search_kernels = ActionProcessor(service.search_kernels, action_monitors) - self.search_sessions = ActionProcessor(service.search, action_monitors) self.shutdown_service = ActionProcessor(service.shutdown_service, action_monitors) self.start_service = ActionProcessor(service.start_service, action_monitors) self.upload_files = ActionProcessor(service.upload_files, action_monitors) - self.modify_session = ActionProcessor(service.modify_session, action_monitors) self.check_and_transit_status = ActionProcessor( service.check_and_transit_status, action_monitors ) + # Scope actions with RBAC validation + self.create_cluster = ActionProcessor( + service.create_cluster, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_params = ActionProcessor( + service.create_from_params, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_template = ActionProcessor( + service.create_from_template, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.match_sessions = ActionProcessor( + service.match_sessions, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_kernels = ActionProcessor( + service.search_kernels, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_sessions = ActionProcessor( + service.search, action_monitors, validators=[cast(ActionValidator, scope_validator)] + ) + + # Single entity actions with RBAC validation + self.destroy_session = ActionProcessor( + service.destroy_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.execute_session = ActionProcessor( + service.execute_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.get_session_info = ActionProcessor( + service.get_session_info, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.modify_session = ActionProcessor( + service.modify_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + @override def supported_actions(self) -> list[ActionSpec]: return [ diff --git a/tests/unit/manager/api/compute_sessions/test_handler.py b/tests/unit/manager/api/compute_sessions/test_handler.py index f562e2caa26..c66a3db2431 100644 --- a/tests/unit/manager/api/compute_sessions/test_handler.py +++ b/tests/unit/manager/api/compute_sessions/test_handler.py @@ -500,12 +500,12 @@ async def test_search_sessions_calls_both_processors( ) -> None: """Handler should call both search_sessions and search_kernels.""" await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) mock_processors.session.search_sessions.wait_for_complete.assert_called_once() await mock_processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=MagicMock()) + SearchKernelsAction(querier=MagicMock(), user_id=uuid4()) ) mock_processors.session.search_kernels.wait_for_complete.assert_called_once() @@ -520,7 +520,7 @@ async def test_search_sessions_empty_result( ) result = await processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) assert result.data == [] @@ -534,10 +534,10 @@ async def test_session_result_has_correct_container_grouping( ) -> None: """Kernels should be correctly grouped by session ID.""" session_result = await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) kernel_result = await mock_processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=MagicMock()) + SearchKernelsAction(querier=MagicMock(), user_id=uuid4()) ) adapter = ComputeSessionsAdapter() @@ -559,7 +559,7 @@ async def test_pagination_info_is_correct( ) -> None: """Pagination info should reflect the session search result.""" session_result = await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) assert session_result.total_count == 2 diff --git a/tests/unit/manager/api/session/test_dataloader.py b/tests/unit/manager/api/session/test_dataloader.py index 456b7f2add0..2a9f4389dbf 100644 --- a/tests/unit/manager/api/session/test_dataloader.py +++ b/tests/unit/manager/api/session/test_dataloader.py @@ -5,11 +5,24 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.contexts.user import with_user +from ai.backend.common.data.user.types import UserData, UserRole from ai.backend.common.types import KernelId from ai.backend.manager.api.gql.data_loader.kernel.loader import load_kernels_by_ids from ai.backend.manager.data.kernel.types import KernelInfo +def _make_user_data() -> UserData: + return UserData( + user_id=uuid.uuid4(), + is_authorized=True, + is_admin=False, + is_superadmin=False, + role=UserRole.USER, + domain_name="default", + ) + + class TestLoadKernelsByIds: """Tests for load_kernels_by_ids function.""" @@ -47,7 +60,8 @@ async def test_returns_kernels_in_request_order(self) -> None: ) # When - result = await load_kernels_by_ids(mock_processor, [id1, id2, id3]) + with with_user(_make_user_data()): + result = await load_kernels_by_ids(mock_processor, [id1, id2, id3]) # Then assert result == [kernel1, kernel2, kernel3] @@ -60,7 +74,8 @@ async def test_returns_none_for_missing_ids(self) -> None: mock_processor = self.create_mock_processor([existing_kernel]) # When - result = await load_kernels_by_ids(mock_processor, [existing_id, missing_id]) + with with_user(_make_user_data()): + result = await load_kernels_by_ids(mock_processor, [existing_id, missing_id]) # Then assert result == [existing_kernel, None] diff --git a/tests/unit/manager/services/session/test_session_lifecycle_service.py b/tests/unit/manager/services/session/test_session_lifecycle_service.py index a7b2b6434e0..0dc2f898f65 100644 --- a/tests/unit/manager/services/session/test_session_lifecycle_service.py +++ b/tests/unit/manager/services/session/test_session_lifecycle_service.py @@ -1237,6 +1237,7 @@ async def test_prefix_matching_returns_sessions( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1249,12 +1250,14 @@ async def test_no_match_returns_empty( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.match_sessions = AsyncMock(return_value=[]) action = MatchSessionsAction( id_or_name_prefix="nonexistent", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1265,12 +1268,14 @@ async def test_owner_access_key_filtering( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.match_sessions = AsyncMock(return_value=[]) action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) await session_service.match_sessions(action) diff --git a/tests/unit/manager/services/session/test_session_service.py b/tests/unit/manager/services/session/test_session_service.py index a447b764d28..78680a21f19 100644 --- a/tests/unit/manager/services/session/test_session_service.py +++ b/tests/unit/manager/services/session/test_session_service.py @@ -289,6 +289,7 @@ async def test_success( mock_session_repository: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully matching sessions""" mock_session_repository.match_sessions = AsyncMock(return_value=[sample_session_data]) @@ -296,6 +297,7 @@ async def test_success( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -310,6 +312,7 @@ async def test_no_matches( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test matching sessions when none found""" mock_session_repository.match_sessions = AsyncMock(return_value=[]) @@ -317,6 +320,7 @@ async def test_no_matches( action = MatchSessionsAction( id_or_name_prefix="nonexistent", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -382,6 +386,7 @@ async def test_multiple_matches( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1600,6 +1605,7 @@ async def test_search_sessions( session_service: SessionService, mock_session_repository: MagicMock, sample_session_data: SessionData, + sample_user_id: UUID, ) -> None: """Test searching sessions with querier""" mock_session_repository.search = AsyncMock( @@ -1616,7 +1622,7 @@ async def test_search_sessions( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.data == [sample_session_data] @@ -1629,6 +1635,7 @@ async def test_search_sessions_empty_result( self, session_service: SessionService, mock_session_repository: MagicMock, + sample_user_id: UUID, ) -> None: """Test searching sessions when no results are found""" mock_session_repository.search = AsyncMock( @@ -1645,7 +1652,7 @@ async def test_search_sessions_empty_result( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.data == [] @@ -1656,6 +1663,7 @@ async def test_search_sessions_with_pagination( session_service: SessionService, mock_session_repository: MagicMock, sample_session_data: SessionData, + sample_user_id: UUID, ) -> None: """Test searching sessions with pagination""" mock_session_repository.search = AsyncMock( @@ -1672,7 +1680,7 @@ async def test_search_sessions_with_pagination( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.total_count == 25 @@ -1783,6 +1791,7 @@ async def test_search_kernels( session_service: SessionService, mock_session_repository: MagicMock, sample_kernel_info: KernelInfo, + sample_user_id: UUID, ) -> None: """Test searching kernels with querier""" mock_session_repository.search_kernels = AsyncMock( @@ -1799,7 +1808,7 @@ async def test_search_kernels( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.data == [sample_kernel_info] @@ -1812,6 +1821,7 @@ async def test_search_kernels_empty_result( self, session_service: SessionService, mock_session_repository: MagicMock, + sample_user_id: UUID, ) -> None: """Test searching kernels when no results are found""" mock_session_repository.search_kernels = AsyncMock( @@ -1828,7 +1838,7 @@ async def test_search_kernels_empty_result( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.data == [] @@ -1839,6 +1849,7 @@ async def test_search_kernels_with_pagination( session_service: SessionService, mock_session_repository: MagicMock, sample_kernel_info: KernelInfo, + sample_user_id: UUID, ) -> None: """Test searching kernels with pagination""" mock_session_repository.search_kernels = AsyncMock( @@ -1855,7 +1866,7 @@ async def test_search_kernels_with_pagination( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.total_count == 25