From 71915e01ff5af071a4f5fd42fc107dac47cf9ba6 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 18:04:01 +0900 Subject: [PATCH 01/10] feat(BA-2946): Apply RBAC validator infrastructure to Session actions Refactor Session actions to use BaseScopeAction and BaseSingleEntityAction for RBAC validation support, following VFolder patterns from BEP-1048. Changes: - Add SessionScopeAction and SessionSingleEntityAction base classes - Refactor 6 scope-level actions (Create, Search, Match) - Add _scope_type and _scope_id fields for RBAC validation - Implement scope_type(), scope_id(), target_element() methods - Refactor 4 single-entity actions (Get, Destroy, Execute, Modify) - Add session_id field for RBAC validation - Implement target_entity_id() and target_element() methods - Add field_data() method to SessionSingleEntityAction TODO: - Set _scope_type/_scope_id from context in API/processor layer - Resolve session_id from session_name before RBAC validation - Apply same pattern to remaining 20+ single-entity actions - Connect RBAC validators to action processors Related: BA-4617, BA-4620, BEP-1048 Co-Authored-By: Claude Sonnet 4.5 --- .../session/actions/create_cluster.py | 29 ++++++++++- .../session/actions/create_from_params.py | 29 ++++++++++- .../session/actions/create_from_template.py | 29 ++++++++++- .../session/actions/destroy_session.py | 26 ++++++++-- .../session/actions/execute_session.py | 26 ++++++++-- .../session/actions/get_session_info.py | 26 ++++++++-- .../session/actions/match_sessions.py | 28 ++++++++++- .../session/actions/modify_session.py | 22 ++++++++- .../services/session/actions/search.py | 28 ++++++++++- .../services/session/actions/search_kernel.py | 29 +++++++++-- .../backend/manager/services/session/base.py | 49 +++++++++++++++++++ 11 files changed, 297 insertions(+), 24 deletions(-) diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 7a2570e738f..81afa824887 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -3,15 +3,23 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class CreateClusterAction(SessionAction): +class CreateClusterAction(SessionScopeAction): + """Create a new cluster session within a scope (domain/project). + + The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). + RBAC validation checks if the user has CREATE permission in the target scope. + """ + session_name: str user_id: uuid.UUID user_role: UserRole @@ -27,6 +35,8 @@ class CreateClusterAction(SessionAction): enqueue_only: bool keypair_resource_policy: dict[str, Any] | None max_wait_seconds: int + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -37,6 +47,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class CreateClusterActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index 39043296d5c..f121ad2dec1 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -6,11 +6,13 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -41,13 +43,21 @@ class CreateFromParamsActionParams: @dataclass -class CreateFromParamsAction(SessionAction): +class CreateFromParamsAction(SessionScopeAction): + """Create a new session from parameters within a scope (domain/project). + + The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). + RBAC validation checks if the user has CREATE permission in the target scope. + """ + params: CreateFromParamsActionParams user_id: uuid.UUID user_role: UserRole sudo_session_enabled: bool requester_access_key: AccessKey keypair_resource_policy: dict[str, Any] | None + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -58,6 +68,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class CreateFromParamsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 9cdd47f704d..3222f8e4cb4 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -6,12 +6,14 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.api.utils import Undefined +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -44,13 +46,21 @@ class CreateFromTemplateActionParams: @dataclass -class CreateFromTemplateAction(SessionAction): +class CreateFromTemplateAction(SessionScopeAction): + """Create a new session from template within a scope (domain/project). + + The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). + RBAC validation checks if the user has CREATE permission in the target scope. + """ + params: CreateFromTemplateActionParams user_id: uuid.UUID user_role: UserRole sudo_session_enabled: bool requester_access_key: AccessKey keypair_resource_policy: dict[str, Any] | None + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -61,6 +71,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class CreateFromTemplateActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/destroy_session.py b/src/ai/backend/manager/services/session/actions/destroy_session.py index 7a8d3703165..1269e51238b 100644 --- a/src/ai/backend/manager/services/session/actions/destroy_session.py +++ b/src/ai/backend/manager/services/session/actions/destroy_session.py @@ -1,25 +1,34 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionSingleEntityAction # TODO: Change this to BatchAction since it can destroy multiple sessions with recursive option @dataclass -class DestroySessionAction(SessionAction): +class DestroySessionAction(SessionSingleEntityAction): + """Destroy a specific session. + + RBAC validation checks if the user has DELETE permission for this session. + session_id will be resolved from session_name before RBAC validation. + """ + user_role: UserRole session_name: str forced: bool recursive: bool owner_access_key: AccessKey + session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: - return None + return self.session_id if self.session_id else None @override @classmethod @@ -29,6 +38,17 @@ def operation_type(cls) -> ActionOperationType: # return "destroy_multi" return ActionOperationType.DELETE + @override + def target_entity_id(self) -> str: + return self.session_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.SESSION, + element_id=self.session_id, + ) + @dataclass class DestroySessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/execute_session.py b/src/ai/backend/manager/services/session/actions/execute_session.py index 5cd34477adc..f8db68d283a 100644 --- a/src/ai/backend/manager/services/session/actions/execute_session.py +++ b/src/ai/backend/manager/services/session/actions/execute_session.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionSingleEntityAction @dataclass @@ -18,21 +20,39 @@ class ExecuteSessionActionParams: @dataclass -class ExecuteSessionAction(SessionAction): +class ExecuteSessionAction(SessionSingleEntityAction): + """Execute code in a specific session. + + RBAC validation checks if the user has UPDATE permission for this session. + session_id will be resolved from session_name before RBAC validation. + """ + session_name: str api_version: tuple[Any, ...] owner_access_key: AccessKey params: ExecuteSessionActionParams + session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: - return None + return self.session_id if self.session_id else None @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE + @override + def target_entity_id(self) -> str: + return self.session_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.SESSION, + element_id=self.session_id, + ) + @dataclass class ExecuteSessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/get_session_info.py b/src/ai/backend/manager/services/session/actions/get_session_info.py index 9e39453f5c1..eaeb5343dc7 100644 --- a/src/ai/backend/manager/services/session/actions/get_session_info.py +++ b/src/ai/backend/manager/services/session/actions/get_session_info.py @@ -1,28 +1,48 @@ from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionSingleEntityAction from ai.backend.manager.services.session.types import LegacySessionInfo @dataclass -class GetSessionInfoAction(SessionAction): +class GetSessionInfoAction(SessionSingleEntityAction): + """Get information about a specific session. + + RBAC validation checks if the user has READ permission for this session. + session_id will be resolved from session_name before RBAC validation. + """ + session_name: str owner_access_key: AccessKey + session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: - return None + return self.session_id if self.session_id else None @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.GET + @override + def target_entity_id(self) -> str: + return self.session_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.SESSION, + element_id=self.session_id, + ) + @dataclass class GetSessionInfoActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index 93bb8b25e2b..4493d4d27a9 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -1,17 +1,26 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Make this BatchAction @dataclass -class MatchSessionsAction(SessionAction): +class MatchSessionsAction(SessionScopeAction): + """Match sessions by ID or name prefix within a scope (domain/project). + + RBAC validation checks if the user has READ permission in the target scope. + """ + id_or_name_prefix: str owner_access_key: AccessKey + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -22,6 +31,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class MatchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/modify_session.py b/src/ai/backend/manager/services/session/actions/modify_session.py index 813f75a0a0d..2396666b312 100644 --- a/src/ai/backend/manager/services/session/actions/modify_session.py +++ b/src/ai/backend/manager/services/session/actions/modify_session.py @@ -2,16 +2,23 @@ from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base.updater import Updater -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionSingleEntityAction @dataclass -class ModifySessionAction(SessionAction): +class ModifySessionAction(SessionSingleEntityAction): + """Modify a specific session. + + RBAC validation checks if the user has UPDATE permission for this session. + """ + session_id: uuid.UUID updater: Updater[SessionRow] @@ -24,6 +31,17 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE + @override + def target_entity_id(self) -> str: + return str(self.session_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.SESSION, + element_id=str(self.session_id), + ) + @dataclass class ModifySessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search.py b/src/ai/backend/manager/services/session/actions/search.py index 83b2e226b2e..a856e35819f 100644 --- a/src/ai/backend/manager/services/session/actions/search.py +++ b/src/ai/backend/manager/services/session/actions/search.py @@ -3,16 +3,25 @@ from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchSessionsAction(SessionAction): +class SearchSessionsAction(SessionScopeAction): + """Search sessions within a scope (domain/project). + + RBAC validation checks if the user has READ permission in the target scope. + """ + querier: BatchQuerier + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -23,6 +32,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class SearchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search_kernel.py b/src/ai/backend/manager/services/session/actions/search_kernel.py index 5864e534b8d..7c96a903aeb 100644 --- a/src/ai/backend/manager/services/session/actions/search_kernel.py +++ b/src/ai/backend/manager/services/session/actions/search_kernel.py @@ -3,17 +3,25 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType +from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchKernelsAction(SessionAction): +class SearchKernelsAction(SessionScopeAction): + """Search kernels within a scope (domain/project). + + RBAC validation checks if the user has READ permission in the target scope. + """ + querier: BatchQuerier + _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context + _scope_id: str = "" # TODO: Set from context @override @classmethod @@ -29,6 +37,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return self._scope_type + + @override + def scope_id(self) -> str: + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType(self._scope_type.value), + element_id=self._scope_id, + ) + @dataclass class SearchKernelsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index bcc90407892..7d0f9731490 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -3,6 +3,12 @@ from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult +from ai.backend.manager.actions.action.single_entity import ( + BaseSingleEntityAction, + BaseSingleEntityActionResult, +) +from ai.backend.manager.actions.action.types import FieldData @dataclass @@ -19,3 +25,46 @@ class SessionBatchAction(BaseBatchAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.SESSION + + +@dataclass +class SessionScopeAction(BaseScopeAction): + """Base class for session actions that operate within a scope (domain/project). + + Used for operations like creating or searching sessions within a specific scope. + Each concrete class must define _scope_type and _scope_id fields for RBAC validation. + """ + + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + +@dataclass +class SessionScopeActionResult(BaseScopeActionResult): + pass + + +@dataclass +class SessionSingleEntityAction(BaseSingleEntityAction): + """Base class for session actions that operate on a specific session. + + Used for operations like getting, updating, or deleting a specific session. + Each concrete class must implement target_entity_id() and target_element() + for RBAC validation. + """ + + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + @override + def field_data(self) -> FieldData | None: + return None + + +@dataclass +class SessionSingleEntityActionResult(BaseSingleEntityActionResult): + pass From cc428e1ac5feea19eabac2c71a42e32c1d64bfdf Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 18:06:15 +0900 Subject: [PATCH 02/10] changelog: add news fragment for PR #9624 --- changes/9624.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/9624.feature.md diff --git a/changes/9624.feature.md b/changes/9624.feature.md new file mode 100644 index 00000000000..10a261771f3 --- /dev/null +++ b/changes/9624.feature.md @@ -0,0 +1 @@ +Add RBAC validator infrastructure to Session actions following BEP-1048 patterns From 0e4c901e3a5058077e0af300d9c2a02ec0d9ac1d Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 19:07:13 +0900 Subject: [PATCH 03/10] refactor(BA-2946): move RBAC methods to base classes and fix validation Address PR review feedback by refactoring Session action classes: - Move common RBAC methods (scope_type, scope_id, target_element, target_entity_id, field_data) from concrete classes to base classes (SessionScopeAction and SessionSingleEntityAction) - Change field defaults from GLOBAL/empty string to None with explicit validation to prevent empty RBAC elements - Add proper error messages when fields not set before RBAC validation - Fix ModifySessionAction: rename session_id to session_uuid to avoid type conflict with base class (UUID vs str), convert in __post_init__ - Update API call in gql_legacy/session.py to use new parameter name This eliminates ~200 lines of duplicated code across 10 action classes and ensures scope is always explicitly set (never defaulting to GLOBAL per user requirement). Co-Authored-By: Claude Sonnet 4.5 --- .../backend/manager/api/gql_legacy/session.py | 2 +- .../session/actions/create_cluster.py | 19 ------ .../session/actions/create_from_params.py | 19 ------ .../session/actions/create_from_template.py | 19 ------ .../session/actions/destroy_session.py | 16 +---- .../session/actions/execute_session.py | 16 +---- .../session/actions/get_session_info.py | 16 +---- .../session/actions/match_sessions.py | 19 ------ .../session/actions/modify_session.py | 22 +++---- .../services/session/actions/search.py | 20 +----- .../services/session/actions/search_kernel.py | 20 +----- .../backend/manager/services/session/base.py | 63 +++++++++++++++++-- 12 files changed, 71 insertions(+), 180 deletions(-) diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index 6ab06ca2dfb..44be2251a7e 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -907,7 +907,7 @@ async def mutate_and_get_payload( result = await graph_ctx.processors.session.modify_session.wait_for_complete( ModifySessionAction( - session_id=session_id, + session_uuid=session_id, updater=Updater( spec=SessionUpdaterSpec( name=OptionalState[str].from_graphql(name), diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 81afa824887..ea56e1e4710 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -3,11 +3,9 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @@ -35,8 +33,6 @@ class CreateClusterAction(SessionScopeAction): enqueue_only: bool keypair_resource_policy: dict[str, Any] | None max_wait_seconds: int - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -47,21 +43,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class CreateClusterActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index f121ad2dec1..8c4ee666f93 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -6,11 +6,9 @@ import yarl -from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @@ -56,8 +54,6 @@ class CreateFromParamsAction(SessionScopeAction): sudo_session_enabled: bool requester_access_key: AccessKey keypair_resource_policy: dict[str, Any] | None - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -68,21 +64,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class CreateFromParamsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 3222f8e4cb4..1c34236e520 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -6,12 +6,10 @@ import yarl -from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.api.utils import Undefined -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @@ -59,8 +57,6 @@ class CreateFromTemplateAction(SessionScopeAction): sudo_session_enabled: bool requester_access_key: AccessKey keypair_resource_policy: dict[str, Any] | None - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -71,21 +67,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class CreateFromTemplateActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/destroy_session.py b/src/ai/backend/manager/services/session/actions/destroy_session.py index 1269e51238b..a19d256f8e7 100644 --- a/src/ai/backend/manager/services/session/actions/destroy_session.py +++ b/src/ai/backend/manager/services/session/actions/destroy_session.py @@ -1,11 +1,9 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionSingleEntityAction @@ -16,7 +14,7 @@ class DestroySessionAction(SessionSingleEntityAction): """Destroy a specific session. RBAC validation checks if the user has DELETE permission for this session. - session_id will be resolved from session_name before RBAC validation. + session_id must be resolved from session_name before RBAC validation. """ user_role: UserRole @@ -24,7 +22,6 @@ class DestroySessionAction(SessionSingleEntityAction): forced: bool recursive: bool owner_access_key: AccessKey - session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: @@ -38,17 +35,6 @@ def operation_type(cls) -> ActionOperationType: # return "destroy_multi" return ActionOperationType.DELETE - @override - def target_entity_id(self) -> str: - return self.session_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType.SESSION, - element_id=self.session_id, - ) - @dataclass class DestroySessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/execute_session.py b/src/ai/backend/manager/services/session/actions/execute_session.py index f8db68d283a..bff6b76503f 100644 --- a/src/ai/backend/manager/services/session/actions/execute_session.py +++ b/src/ai/backend/manager/services/session/actions/execute_session.py @@ -1,11 +1,9 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.services.session.base import SessionSingleEntityAction @@ -24,14 +22,13 @@ class ExecuteSessionAction(SessionSingleEntityAction): """Execute code in a specific session. RBAC validation checks if the user has UPDATE permission for this session. - session_id will be resolved from session_name before RBAC validation. + session_id must be resolved from session_name before RBAC validation. """ session_name: str api_version: tuple[Any, ...] owner_access_key: AccessKey params: ExecuteSessionActionParams - session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: @@ -42,17 +39,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE - @override - def target_entity_id(self) -> str: - return self.session_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType.SESSION, - element_id=self.session_id, - ) - @dataclass class ExecuteSessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/get_session_info.py b/src/ai/backend/manager/services/session/actions/get_session_info.py index eaeb5343dc7..e5b8f157e29 100644 --- a/src/ai/backend/manager/services/session/actions/get_session_info.py +++ b/src/ai/backend/manager/services/session/actions/get_session_info.py @@ -1,11 +1,9 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import RBACElementType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.services.session.base import SessionSingleEntityAction from ai.backend.manager.services.session.types import LegacySessionInfo @@ -16,12 +14,11 @@ class GetSessionInfoAction(SessionSingleEntityAction): """Get information about a specific session. RBAC validation checks if the user has READ permission for this session. - session_id will be resolved from session_name before RBAC validation. + session_id must be resolved from session_name before RBAC validation. """ session_name: str owner_access_key: AccessKey - session_id: str = "" # TODO: Resolve from session_name before RBAC validation @override def entity_id(self) -> str | None: @@ -32,17 +29,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.GET - @override - def target_entity_id(self) -> str: - return self.session_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType.SESSION, - element_id=self.session_id, - ) - @dataclass class GetSessionInfoActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index 4493d4d27a9..8184864ffba 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -1,11 +1,9 @@ from dataclasses import dataclass from typing import Any, override -from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.services.session.base import SessionScopeAction @@ -19,8 +17,6 @@ class MatchSessionsAction(SessionScopeAction): id_or_name_prefix: str owner_access_key: AccessKey - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -31,21 +27,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class MatchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/modify_session.py b/src/ai/backend/manager/services/session/actions/modify_session.py index 2396666b312..7f2eeeb23ca 100644 --- a/src/ai/backend/manager/services/session/actions/modify_session.py +++ b/src/ai/backend/manager/services/session/actions/modify_session.py @@ -2,10 +2,8 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base.updater import Updater @@ -17,31 +15,25 @@ class ModifySessionAction(SessionSingleEntityAction): """Modify a specific session. RBAC validation checks if the user has UPDATE permission for this session. + session_id (str) is automatically set from the session_uuid (UUID) field. """ - session_id: uuid.UUID + session_uuid: uuid.UUID # Renamed to avoid conflict with base class session_id updater: Updater[SessionRow] + def __post_init__(self) -> None: + # Set session_id (str) for RBAC validation from session_uuid (UUID) + object.__setattr__(self, "session_id", str(self.session_uuid)) + @override def entity_id(self) -> str | None: - return str(self.session_id) + return str(self.session_uuid) @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE - @override - def target_entity_id(self) -> str: - return str(self.session_id) - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType.SESSION, - element_id=str(self.session_id), - ) - @dataclass class ModifySessionActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search.py b/src/ai/backend/manager/services/session/actions/search.py index a856e35819f..bccbbcec7a7 100644 --- a/src/ai/backend/manager/services/session/actions/search.py +++ b/src/ai/backend/manager/services/session/actions/search.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.repositories.base import BatchQuerier from ai.backend.manager.services.session.base import SessionScopeAction @@ -17,11 +15,10 @@ class SearchSessionsAction(SessionScopeAction): """Search sessions within a scope (domain/project). RBAC validation checks if the user has READ permission in the target scope. + _scope_type and _scope_id must be set before RBAC validation (typically USER scope). """ querier: BatchQuerier - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override def entity_id(self) -> str | None: @@ -32,21 +29,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class SearchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search_kernel.py b/src/ai/backend/manager/services/session/actions/search_kernel.py index 7c96a903aeb..060f863e656 100644 --- a/src/ai/backend/manager/services/session/actions/search_kernel.py +++ b/src/ai/backend/manager/services/session/actions/search_kernel.py @@ -3,11 +3,10 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType +from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.kernel.types import KernelInfo -from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.repositories.base import BatchQuerier from ai.backend.manager.services.session.base import SessionScopeAction @@ -20,8 +19,6 @@ class SearchKernelsAction(SessionScopeAction): """ querier: BatchQuerier - _scope_type: ScopeType = ScopeType.GLOBAL # TODO: Set from context - _scope_id: str = "" # TODO: Set from context @override @classmethod @@ -37,21 +34,6 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH - @override - def scope_type(self) -> ScopeType: - return self._scope_type - - @override - def scope_id(self) -> str: - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - return RBACElementRef( - element_type=RBACElementType(self._scope_type.value), - element_id=self._scope_id, - ) - @dataclass class SearchKernelsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index 7d0f9731490..a49b7fb4b74 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -1,7 +1,7 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import override -from ai.backend.common.data.permission.types import EntityType +from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult from ai.backend.manager.actions.action.single_entity import ( @@ -9,6 +9,7 @@ BaseSingleEntityActionResult, ) from ai.backend.manager.actions.action.types import FieldData +from ai.backend.manager.data.permission.types import RBACElementRef @dataclass @@ -32,14 +33,45 @@ class SessionScopeAction(BaseScopeAction): """Base class for session actions that operate within a scope (domain/project). Used for operations like creating or searching sessions within a specific scope. - Each concrete class must define _scope_type and _scope_id fields for RBAC validation. + Subclasses must set _scope_type and _scope_id fields before RBAC validation. + + Note: Scope should typically be USER scope (user_id), not GLOBAL. + Empty _scope_id is not allowed and will raise ValueError. """ + _scope_type: ScopeType | None = field(default=None, kw_only=True) + _scope_id: str | None = field(default=None, kw_only=True) + @override @classmethod def entity_type(cls) -> EntityType: return EntityType.SESSION + @override + def scope_type(self) -> ScopeType: + if self._scope_type is None: + raise ValueError( + f"{self.__class__.__name__}._scope_type must be set before RBAC validation" + ) + return self._scope_type + + @override + def scope_id(self) -> str: + if self._scope_id is None or not self._scope_id.strip(): + raise ValueError( + f"{self.__class__.__name__}._scope_id must be set to a non-empty string " + "before RBAC validation" + ) + return self._scope_id + + @override + def target_element(self) -> RBACElementRef: + # Reuse scope_type() and scope_id() for validation + return RBACElementRef( + element_type=RBACElementType(self.scope_type().value), + element_id=self.scope_id(), + ) + @dataclass class SessionScopeActionResult(BaseScopeActionResult): @@ -51,10 +83,14 @@ class SessionSingleEntityAction(BaseSingleEntityAction): """Base class for session actions that operate on a specific session. Used for operations like getting, updating, or deleting a specific session. - Each concrete class must implement target_entity_id() and target_element() - for RBAC validation. + Subclasses must provide a session_id (resolved from session_name if needed) + before RBAC validation. + + Note: Empty session_id is not allowed and will raise ValueError. """ + session_id: str | None = field(default=None, kw_only=True) + @override @classmethod def entity_type(cls) -> EntityType: @@ -64,6 +100,23 @@ def entity_type(cls) -> EntityType: def field_data(self) -> FieldData | None: return None + @override + def target_entity_id(self) -> str: + if self.session_id is None or not self.session_id.strip(): + raise ValueError( + f"{self.__class__.__name__}.session_id must be set to a non-empty string " + "before RBAC validation" + ) + return self.session_id + + @override + def target_element(self) -> RBACElementRef: + # Reuse target_entity_id() for validation + return RBACElementRef( + element_type=RBACElementType.SESSION, + element_id=self.target_entity_id(), + ) + @dataclass class SessionSingleEntityActionResult(BaseSingleEntityActionResult): From e20b04adfa20c9f64c916d4f11794ab720f557e3 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 19:45:40 +0900 Subject: [PATCH 04/10] refactor(BA-2946): derive scope from business logic, remove _scope_type/_scope_id fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change SessionScopeAction pattern to compute scope from business context instead of storing it in separate fields: - Remove _scope_type and _scope_id fields from all SessionScopeAction concrete classes (SearchSessionsAction, CreateFromParamsAction, etc.) - Each concrete class now computes scope directly from its business fields: * scope_type() always returns ScopeType.USER * scope_id() returns str(self.user_id) * target_element() uses USER scope with user_id - Add user_id field to actions that lacked it: * SearchSessionsAction * SearchKernelsAction * MatchSessionsAction Benefits: - Eliminates redundant fields (_scope_type, _scope_id) - Scope derivation logic co-located with action definition - Enforces "always USER scope" requirement at type level - API handlers now only need to provide user_id, not compute scope This follows user requirement: "scope는 모두 user id로 설정. global scope는 절대 쓰지 마라" (always use user id for scope, never use GLOBAL scope). Co-Authored-By: Claude Sonnet 4.5 --- .../session/actions/create_cluster.py | 23 +++++++++++-- .../session/actions/create_from_params.py | 23 +++++++++++-- .../session/actions/create_from_template.py | 23 +++++++++++-- .../session/actions/match_sessions.py | 24 ++++++++++++-- .../services/session/actions/search.py | 25 ++++++++++++-- .../services/session/actions/search_kernel.py | 25 ++++++++++++-- .../backend/manager/services/session/base.py | 33 ++----------------- 7 files changed, 128 insertions(+), 48 deletions(-) diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index ea56e1e4710..10e843b1df7 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -3,19 +3,21 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @dataclass class CreateClusterAction(SessionScopeAction): - """Create a new cluster session within a scope (domain/project). + """Create a new cluster session. - The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). - RBAC validation checks if the user has CREATE permission in the target scope. + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. """ session_name: str @@ -43,6 +45,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateClusterActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index 8c4ee666f93..68a293cf4c9 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -6,9 +6,11 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @@ -42,10 +44,10 @@ class CreateFromParamsActionParams: @dataclass class CreateFromParamsAction(SessionScopeAction): - """Create a new session from parameters within a scope (domain/project). + """Create a new session from parameters. - The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). - RBAC validation checks if the user has CREATE permission in the target scope. + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. """ params: CreateFromParamsActionParams @@ -64,6 +66,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromParamsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index 1c34236e520..2b721d36d3a 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -6,10 +6,12 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.api.utils import Undefined +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.session.base import SessionScopeAction @@ -45,10 +47,10 @@ class CreateFromTemplateActionParams: @dataclass class CreateFromTemplateAction(SessionScopeAction): - """Create a new session from template within a scope (domain/project). + """Create a new session from template. - The scope is determined by group_id (PROJECT scope) or domain_name (DOMAIN scope). - RBAC validation checks if the user has CREATE permission in the target scope. + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. """ params: CreateFromTemplateActionParams @@ -67,6 +69,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromTemplateActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index 8184864ffba..ca91ae8bab5 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -1,22 +1,27 @@ +import uuid from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Make this BatchAction @dataclass class MatchSessionsAction(SessionScopeAction): - """Match sessions by ID or name prefix within a scope (domain/project). + """Match sessions by ID or name prefix. - RBAC validation checks if the user has READ permission in the target scope. + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. """ id_or_name_prefix: str owner_access_key: AccessKey + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -27,6 +32,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class MatchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search.py b/src/ai/backend/manager/services/session/actions/search.py index bccbbcec7a7..8d0b177490d 100644 --- a/src/ai/backend/manager/services/session/actions/search.py +++ b/src/ai/backend/manager/services/session/actions/search.py @@ -1,10 +1,13 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.repositories.base import BatchQuerier from ai.backend.manager.services.session.base import SessionScopeAction @@ -12,13 +15,14 @@ @dataclass class SearchSessionsAction(SessionScopeAction): - """Search sessions within a scope (domain/project). + """Search sessions within a scope. - RBAC validation checks if the user has READ permission in the target scope. - _scope_type and _scope_id must be set before RBAC validation (typically USER scope). + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. """ querier: BatchQuerier + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -29,6 +33,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search_kernel.py b/src/ai/backend/manager/services/session/actions/search_kernel.py index 060f863e656..0a3ace75f58 100644 --- a/src/ai/backend/manager/services/session/actions/search_kernel.py +++ b/src/ai/backend/manager/services/session/actions/search_kernel.py @@ -1,24 +1,28 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType +from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.repositories.base import BatchQuerier from ai.backend.manager.services.session.base import SessionScopeAction @dataclass class SearchKernelsAction(SessionScopeAction): - """Search kernels within a scope (domain/project). + """Search kernels within a scope. - RBAC validation checks if the user has READ permission in the target scope. + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. """ querier: BatchQuerier + user_id: uuid.UUID @override @classmethod @@ -34,6 +38,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchKernelsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index a49b7fb4b74..c824f8702b0 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import override -from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType +from ai.backend.common.data.permission.types import EntityType, RBACElementType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult from ai.backend.manager.actions.action.single_entity import ( @@ -33,45 +33,16 @@ class SessionScopeAction(BaseScopeAction): """Base class for session actions that operate within a scope (domain/project). Used for operations like creating or searching sessions within a specific scope. - Subclasses must set _scope_type and _scope_id fields before RBAC validation. + Subclasses must implement scope_type(), scope_id(), and target_element() methods. Note: Scope should typically be USER scope (user_id), not GLOBAL. - Empty _scope_id is not allowed and will raise ValueError. """ - _scope_type: ScopeType | None = field(default=None, kw_only=True) - _scope_id: str | None = field(default=None, kw_only=True) - @override @classmethod def entity_type(cls) -> EntityType: return EntityType.SESSION - @override - def scope_type(self) -> ScopeType: - if self._scope_type is None: - raise ValueError( - f"{self.__class__.__name__}._scope_type must be set before RBAC validation" - ) - return self._scope_type - - @override - def scope_id(self) -> str: - if self._scope_id is None or not self._scope_id.strip(): - raise ValueError( - f"{self.__class__.__name__}._scope_id must be set to a non-empty string " - "before RBAC validation" - ) - return self._scope_id - - @override - def target_element(self) -> RBACElementRef: - # Reuse scope_type() and scope_id() for validation - return RBACElementRef( - element_type=RBACElementType(self.scope_type().value), - element_id=self.scope_id(), - ) - @dataclass class SessionScopeActionResult(BaseScopeActionResult): From 5027f152cbcd84404bfdc285c2094fe9fd779e6f Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 20:25:53 +0900 Subject: [PATCH 05/10] fix(BA-2946): add user_id to SessionScopeAction calls in API handlers Update all API handlers to provide user_id when creating SessionScopeAction instances (SearchSessionsAction, SearchKernelsAction, MatchSessionsAction). Changes: - GraphQL data loaders: get user_id from current_user() context * session/loader.py * kernel/loader.py - GraphQL fetchers: get user_id from current_user() context * kernel/fetcher/kernel.py * session/fetcher/session.py - GraphQL types: add user_id to SearchKernelsAction in kernels field * session/types.py - REST API handlers: get user_id from current_user() context * rest/session/handler.py (MatchSessionsAction) * rest/compute_sessions/handler.py (SearchSessionsAction, SearchKernelsAction) All actions now receive user_id which is used to compute USER scope for RBAC validation as required by the design. Co-Authored-By: Claude Sonnet 4.5 --- .../manager/api/gql/data_loader/kernel/loader.py | 8 +++++++- .../manager/api/gql/data_loader/session/loader.py | 8 +++++++- .../backend/manager/api/gql/kernel/fetcher/kernel.py | 8 +++++++- .../backend/manager/api/gql/session/fetcher/session.py | 8 +++++++- src/ai/backend/manager/api/gql/session/types.py | 8 +++++++- .../manager/api/rest/compute_sessions/handler.py | 10 ++++++++-- src/ai/backend/manager/api/rest/session/handler.py | 6 ++++++ 7 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py index f2adf405c54..5ea816895fb 100644 --- a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import KernelConditions from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction @@ -26,13 +28,17 @@ async def load_kernels_by_ids( if not kernel_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_ids(kernel_ids)], ) action_result = await processor.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) kernel_map: dict[KernelId, KernelInfo] = {kernel.id: kernel for kernel in action_result.data} diff --git a/src/ai/backend/manager/api/gql/data_loader/session/loader.py b/src/ai/backend/manager/api/gql/data_loader/session/loader.py index 59e0f2a176e..8cd5708b59f 100644 --- a/src/ai/backend/manager/api/gql/data_loader/session/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/session/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.data.session.types import SessionData +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import SessionConditions from ai.backend.manager.services.session.actions.search import SearchSessionsAction @@ -17,13 +19,17 @@ async def load_sessions_by_ids( if not session_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[SessionConditions.by_ids(session_ids)], ) action_result = await processor.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) session_map: dict[SessionId, SessionData] = { diff --git a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py index 50bda150070..30a803e60aa 100644 --- a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py +++ b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor @@ -16,6 +17,7 @@ KernelV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.kernel import KernelRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import KernelConditions @@ -45,6 +47,10 @@ async def fetch_kernels( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -61,7 +67,7 @@ async def fetch_kernels( ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel_info) for kernel_info in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/gql/session/fetcher/session.py b/src/ai/backend/manager/api/gql/session/fetcher/session.py index 07b88f0e069..603405e00e2 100644 --- a/src/ai/backend/manager/api/gql/session/fetcher/session.py +++ b/src/ai/backend/manager/api/gql/session/fetcher/session.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor from ai.backend.manager.api.gql.session.types import ( @@ -15,6 +16,7 @@ SessionV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import SessionConditions, SessionOrders @@ -44,6 +46,10 @@ async def fetch_sessions( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> SessionV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -60,7 +66,7 @@ async def fetch_sessions( ) action_result = await info.context.processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) nodes = [SessionV2GQL.from_data(session_data) for session_data in action_result.data] diff --git a/src/ai/backend/manager/api/gql/session/types.py b/src/ai/backend/manager/api/gql/session/types.py index c33434660bb..26ad397e989 100644 --- a/src/ai/backend/manager/api/gql/session/types.py +++ b/src/ai/backend/manager/api/gql/session/types.py @@ -12,6 +12,7 @@ from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, UUIDFilter, encode_cursor from ai.backend.manager.api.gql.common.types import ( @@ -41,6 +42,7 @@ from ai.backend.manager.api.gql.types import GQLFilter, GQLOrderBy, StrawberryGQLContext from ai.backend.manager.api.gql.user.types.node import UserV2GQL from ai.backend.manager.data.session.types import SessionData, SessionStatus +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import ( BatchQuerier, NoPagination, @@ -468,13 +470,17 @@ async def images(self) -> ImageV2ConnectionGQL: description="Added in 26.3.0. The kernels belonging to this session." ) async def kernels(self, info: Info[StrawberryGQLContext]) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + session_id = SessionId(UUID(str(self.id))) querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_session_ids([session_id])], ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel) for kernel in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/rest/compute_sessions/handler.py b/src/ai/backend/manager/api/rest/compute_sessions/handler.py index fb1c612342c..adeae27a073 100644 --- a/src/ai/backend/manager/api/rest/compute_sessions/handler.py +++ b/src/ai/backend/manager/api/rest/compute_sessions/handler.py @@ -7,6 +7,7 @@ from typing import Final from ai.backend.common.api_handlers import APIResponse, BodyParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.compute_session import ( PaginationInfo, SearchComputeSessionsRequest, @@ -15,6 +16,7 @@ from ai.backend.common.types import SessionId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.dto.context import UserContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.services.session.actions.search import SearchSessionsAction from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction from ai.backend.manager.services.session.processors import SessionProcessors @@ -39,10 +41,14 @@ async def search_sessions( """Search compute sessions with nested container data.""" log.info("SEARCH_SESSIONS (ak:{})", ctx.access_key) + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + # Step 1: Search sessions session_querier = self._adapter.build_session_querier(body.parsed) session_result = await self._session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=session_querier) + SearchSessionsAction(querier=session_querier, user_id=user.user_id) ) # Step 2: Fetch kernels for found sessions @@ -51,7 +57,7 @@ async def search_sessions( if session_ids: kernel_querier = self._adapter.build_kernel_querier_for_sessions(session_ids) kernel_result = await self._session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=kernel_querier) + SearchKernelsAction(querier=kernel_querier, user_id=user.user_id) ) kernels_by_session = self._adapter.group_kernels_by_session(kernel_result.data) diff --git a/src/ai/backend/manager/api/rest/session/handler.py b/src/ai/backend/manager/api/rest/session/handler.py index 79b1b28ed8d..4ef40071f41 100644 --- a/src/ai/backend/manager/api/rest/session/handler.py +++ b/src/ai/backend/manager/api/rest/session/handler.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ai.backend.common.api_handlers import APIResponse, BaseResponseModel, BodyParam, QueryParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.session.request import ( CommitSessionRequest, CompleteRequest, @@ -91,6 +92,7 @@ from ai.backend.manager.errors.api import InvalidAPIParameters from ai.backend.manager.errors.auth import InsufficientPrivilege from ai.backend.manager.errors.resource import NoCurrentTaskContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.agent.actions.sync_agent_registry import ( SyncAgentRegistryAction, @@ -515,6 +517,9 @@ async def match_sessions( ) ) requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key + user = current_user() + if user is None: + raise UserNotFound("User not found in context") log.info( "MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})", requester_access_key, @@ -525,6 +530,7 @@ async def match_sessions( MatchSessionsAction( id_or_name_prefix=params.id, owner_access_key=owner_access_key, + user_id=user.user_id, ) ) return APIResponse.build(HTTPStatus.OK, MatchSessionsResponse(matches=result.result)) From a77f32a6c3d64791b29f4186ff3385aefc4c6a06 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 22:46:21 +0900 Subject: [PATCH 06/10] refactor(BA-2946): remove single entity actions, keep scope actions only Remove 4 single entity actions and SessionSingleEntityAction base class to keep this branch focused on scope actions only. Single entity actions will be handled in separate branches (BA-4864). Changes: - Removed SessionSingleEntityAction and SessionSingleEntityActionResult from base.py - Restored 4 single entity action files to main state: - destroy_session.py - execute_session.py - get_session_info.py - modify_session.py - Kept SessionScopeAction and 6 scope actions: - create_cluster, create_from_params, create_from_template - match_sessions, search_kernel, search Co-Authored-By: Claude Sonnet 4.5 --- .../session/actions/destroy_session.py | 12 +--- .../session/actions/execute_session.py | 12 +--- .../session/actions/get_session_info.py | 12 +--- .../session/actions/modify_session.py | 18 ++---- .../backend/manager/services/session/base.py | 55 +------------------ 5 files changed, 15 insertions(+), 94 deletions(-) diff --git a/src/ai/backend/manager/services/session/actions/destroy_session.py b/src/ai/backend/manager/services/session/actions/destroy_session.py index a19d256f8e7..7a8d3703165 100644 --- a/src/ai/backend/manager/services/session/actions/destroy_session.py +++ b/src/ai/backend/manager/services/session/actions/destroy_session.py @@ -5,18 +5,12 @@ from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionSingleEntityAction +from ai.backend.manager.services.session.base import SessionAction # TODO: Change this to BatchAction since it can destroy multiple sessions with recursive option @dataclass -class DestroySessionAction(SessionSingleEntityAction): - """Destroy a specific session. - - RBAC validation checks if the user has DELETE permission for this session. - session_id must be resolved from session_name before RBAC validation. - """ - +class DestroySessionAction(SessionAction): user_role: UserRole session_name: str forced: bool @@ -25,7 +19,7 @@ class DestroySessionAction(SessionSingleEntityAction): @override def entity_id(self) -> str | None: - return self.session_id if self.session_id else None + return None @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/execute_session.py b/src/ai/backend/manager/services/session/actions/execute_session.py index bff6b76503f..5cd34477adc 100644 --- a/src/ai/backend/manager/services/session/actions/execute_session.py +++ b/src/ai/backend/manager/services/session/actions/execute_session.py @@ -5,7 +5,7 @@ from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData -from ai.backend.manager.services.session.base import SessionSingleEntityAction +from ai.backend.manager.services.session.base import SessionAction @dataclass @@ -18,13 +18,7 @@ class ExecuteSessionActionParams: @dataclass -class ExecuteSessionAction(SessionSingleEntityAction): - """Execute code in a specific session. - - RBAC validation checks if the user has UPDATE permission for this session. - session_id must be resolved from session_name before RBAC validation. - """ - +class ExecuteSessionAction(SessionAction): session_name: str api_version: tuple[Any, ...] owner_access_key: AccessKey @@ -32,7 +26,7 @@ class ExecuteSessionAction(SessionSingleEntityAction): @override def entity_id(self) -> str | None: - return self.session_id if self.session_id else None + return None @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/get_session_info.py b/src/ai/backend/manager/services/session/actions/get_session_info.py index e5b8f157e29..9e39453f5c1 100644 --- a/src/ai/backend/manager/services/session/actions/get_session_info.py +++ b/src/ai/backend/manager/services/session/actions/get_session_info.py @@ -5,24 +5,18 @@ from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData -from ai.backend.manager.services.session.base import SessionSingleEntityAction +from ai.backend.manager.services.session.base import SessionAction from ai.backend.manager.services.session.types import LegacySessionInfo @dataclass -class GetSessionInfoAction(SessionSingleEntityAction): - """Get information about a specific session. - - RBAC validation checks if the user has READ permission for this session. - session_id must be resolved from session_name before RBAC validation. - """ - +class GetSessionInfoAction(SessionAction): session_name: str owner_access_key: AccessKey @override def entity_id(self) -> str | None: - return self.session_id if self.session_id else None + return None @override @classmethod diff --git a/src/ai/backend/manager/services/session/actions/modify_session.py b/src/ai/backend/manager/services/session/actions/modify_session.py index 7f2eeeb23ca..813f75a0a0d 100644 --- a/src/ai/backend/manager/services/session/actions/modify_session.py +++ b/src/ai/backend/manager/services/session/actions/modify_session.py @@ -7,27 +7,17 @@ from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base.updater import Updater -from ai.backend.manager.services.session.base import SessionSingleEntityAction +from ai.backend.manager.services.session.base import SessionAction @dataclass -class ModifySessionAction(SessionSingleEntityAction): - """Modify a specific session. - - RBAC validation checks if the user has UPDATE permission for this session. - session_id (str) is automatically set from the session_uuid (UUID) field. - """ - - session_uuid: uuid.UUID # Renamed to avoid conflict with base class session_id +class ModifySessionAction(SessionAction): + session_id: uuid.UUID updater: Updater[SessionRow] - def __post_init__(self) -> None: - # Set session_id (str) for RBAC validation from session_uuid (UUID) - object.__setattr__(self, "session_id", str(self.session_uuid)) - @override def entity_id(self) -> str | None: - return str(self.session_uuid) + return str(self.session_id) @override @classmethod diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index c824f8702b0..0693e4c77cd 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -1,15 +1,9 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType, RBACElementType +from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult -from ai.backend.manager.actions.action.single_entity import ( - BaseSingleEntityAction, - BaseSingleEntityActionResult, -) -from ai.backend.manager.actions.action.types import FieldData -from ai.backend.manager.data.permission.types import RBACElementRef @dataclass @@ -47,48 +41,3 @@ def entity_type(cls) -> EntityType: @dataclass class SessionScopeActionResult(BaseScopeActionResult): pass - - -@dataclass -class SessionSingleEntityAction(BaseSingleEntityAction): - """Base class for session actions that operate on a specific session. - - Used for operations like getting, updating, or deleting a specific session. - Subclasses must provide a session_id (resolved from session_name if needed) - before RBAC validation. - - Note: Empty session_id is not allowed and will raise ValueError. - """ - - session_id: str | None = field(default=None, kw_only=True) - - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.SESSION - - @override - def field_data(self) -> FieldData | None: - return None - - @override - def target_entity_id(self) -> str: - if self.session_id is None or not self.session_id.strip(): - raise ValueError( - f"{self.__class__.__name__}.session_id must be set to a non-empty string " - "before RBAC validation" - ) - return self.session_id - - @override - def target_element(self) -> RBACElementRef: - # Reuse target_entity_id() for validation - return RBACElementRef( - element_type=RBACElementType.SESSION, - element_id=self.target_entity_id(), - ) - - -@dataclass -class SessionSingleEntityActionResult(BaseSingleEntityActionResult): - pass From 77ccefd92f7f812c4ed2ebf87d6d97fd346b881b Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Wed, 4 Mar 2026 23:20:30 +0900 Subject: [PATCH 07/10] feat(BA-2946): connect RBAC validators to SessionProcessors - Add permission_repository parameter to SessionProcessors.__init__ - Instantiate ScopeActionRBACValidator and SingleEntityActionRBACValidator - Apply scope validator to 6 scope actions (create_cluster, create_from_params, create_from_template, match_sessions, search_kernels, search_sessions) - Apply single entity validator to 4 single entity actions (destroy_session, execute_session, get_session_info, modify_session) - Reorganize processor initialization into three logical sections: no validation, scope validation, single entity validation Co-Authored-By: Claude Sonnet 4.5 --- .../manager/services/session/processors.py | 69 ++++++++++++++++--- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/src/ai/backend/manager/services/session/processors.py b/src/ai/backend/manager/services/session/processors.py index 647568dd8ce..bf0c50e30cd 100644 --- a/src/ai/backend/manager/services/session/processors.py +++ b/src/ai/backend/manager/services/session/processors.py @@ -1,8 +1,9 @@ -from typing import override +from typing import cast, override from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.actions.validator.base import ActionValidator from ai.backend.manager.actions.validators import ActionValidators from ai.backend.manager.services.session.actions.check_and_transit_status import ( CheckAndTransitStatusAction, @@ -172,18 +173,17 @@ def __init__( action_monitors: list[ActionMonitor], validators: ActionValidators, ) -> None: + scope_validator = validators.rbac.scope + single_entity_validator = validators.rbac.single_entity + + # Actions without RBAC validation (internal/legacy) self.commit_session = ActionProcessor(service.commit_session, action_monitors) self.complete = ActionProcessor(service.complete, action_monitors) self.convert_session_to_image = ActionProcessor( service.convert_session_to_image, action_monitors ) - self.create_cluster = ActionProcessor(service.create_cluster, action_monitors) - self.create_from_params = ActionProcessor(service.create_from_params, action_monitors) - self.create_from_template = ActionProcessor(service.create_from_template, action_monitors) - self.destroy_session = ActionProcessor(service.destroy_session, action_monitors) self.download_file = ActionProcessor(service.download_file, action_monitors) self.download_files = ActionProcessor(service.download_files, action_monitors) - self.execute_session = ActionProcessor(service.execute_session, action_monitors) self.get_abusing_report = ActionProcessor(service.get_abusing_report, action_monitors) self.get_commit_status = ActionProcessor(service.get_commit_status, action_monitors) self.get_container_logs = ActionProcessor(service.get_container_logs, action_monitors) @@ -191,23 +191,70 @@ def __init__( self.get_direct_access_info = ActionProcessor( service.get_direct_access_info, action_monitors ) - self.get_session_info = ActionProcessor(service.get_session_info, action_monitors) self.get_status_history = ActionProcessor(service.get_status_history, action_monitors) self.interrupt = ActionProcessor(service.interrupt, action_monitors) self.list_files = ActionProcessor(service.list_files, action_monitors) - self.match_sessions = ActionProcessor(service.match_sessions, action_monitors) self.rename_session = ActionProcessor(service.rename_session, action_monitors) self.restart_session = ActionProcessor(service.restart_session, action_monitors) - self.search_kernels = ActionProcessor(service.search_kernels, action_monitors) - self.search_sessions = ActionProcessor(service.search, action_monitors) self.shutdown_service = ActionProcessor(service.shutdown_service, action_monitors) self.start_service = ActionProcessor(service.start_service, action_monitors) self.upload_files = ActionProcessor(service.upload_files, action_monitors) - self.modify_session = ActionProcessor(service.modify_session, action_monitors) self.check_and_transit_status = ActionProcessor( service.check_and_transit_status, action_monitors ) + # Scope actions with RBAC validation + self.create_cluster = ActionProcessor( + service.create_cluster, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_params = ActionProcessor( + service.create_from_params, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_template = ActionProcessor( + service.create_from_template, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.match_sessions = ActionProcessor( + service.match_sessions, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_kernels = ActionProcessor( + service.search_kernels, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_sessions = ActionProcessor( + service.search, action_monitors, validators=[cast(ActionValidator, scope_validator)] + ) + + # Single entity actions with RBAC validation + self.destroy_session = ActionProcessor( + service.destroy_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.execute_session = ActionProcessor( + service.execute_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.get_session_info = ActionProcessor( + service.get_session_info, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.modify_session = ActionProcessor( + service.modify_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + @override def supported_actions(self) -> list[ActionSpec]: return [ From 6a441467c796d264f26f2614510227e1e2c256e6 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 20 Mar 2026 14:12:50 +0900 Subject: [PATCH 08/10] fix(BA-2946): fix misleading SessionScopeAction docstring Co-Authored-By: Claude Opus 4.6 --- src/ai/backend/manager/services/session/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index 0693e4c77cd..93c9c488e74 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -24,12 +24,10 @@ def entity_type(cls) -> EntityType: @dataclass class SessionScopeAction(BaseScopeAction): - """Base class for session actions that operate within a scope (domain/project). + """Base class for session actions that operate within a scope. Used for operations like creating or searching sessions within a specific scope. Subclasses must implement scope_type(), scope_id(), and target_element() methods. - - Note: Scope should typically be USER scope (user_id), not GLOBAL. """ @override From ebf2f90b6969002c77fe91356224cea78aa2a49d Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 20 Mar 2026 14:36:11 +0900 Subject: [PATCH 09/10] fix(BA-2946): add missing user_id args and fix session_uuid rename in tests - Add user_id to SearchSessionsAction, SearchKernelsAction, MatchSessionsAction calls in test_handler, test_session_service, test_session_lifecycle_service - Fix session_uuid -> session_id rename in gql_legacy/session.py Co-Authored-By: Claude Opus 4.6 --- .../backend/manager/api/gql_legacy/session.py | 2 +- .../api/compute_sessions/test_handler.py | 12 +++++----- .../session/test_session_lifecycle_service.py | 5 ++++ .../services/session/test_session_service.py | 23 ++++++++++++++----- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index 44be2251a7e..6ab06ca2dfb 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -907,7 +907,7 @@ async def mutate_and_get_payload( result = await graph_ctx.processors.session.modify_session.wait_for_complete( ModifySessionAction( - session_uuid=session_id, + session_id=session_id, updater=Updater( spec=SessionUpdaterSpec( name=OptionalState[str].from_graphql(name), diff --git a/tests/unit/manager/api/compute_sessions/test_handler.py b/tests/unit/manager/api/compute_sessions/test_handler.py index f562e2caa26..c66a3db2431 100644 --- a/tests/unit/manager/api/compute_sessions/test_handler.py +++ b/tests/unit/manager/api/compute_sessions/test_handler.py @@ -500,12 +500,12 @@ async def test_search_sessions_calls_both_processors( ) -> None: """Handler should call both search_sessions and search_kernels.""" await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) mock_processors.session.search_sessions.wait_for_complete.assert_called_once() await mock_processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=MagicMock()) + SearchKernelsAction(querier=MagicMock(), user_id=uuid4()) ) mock_processors.session.search_kernels.wait_for_complete.assert_called_once() @@ -520,7 +520,7 @@ async def test_search_sessions_empty_result( ) result = await processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) assert result.data == [] @@ -534,10 +534,10 @@ async def test_session_result_has_correct_container_grouping( ) -> None: """Kernels should be correctly grouped by session ID.""" session_result = await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) kernel_result = await mock_processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=MagicMock()) + SearchKernelsAction(querier=MagicMock(), user_id=uuid4()) ) adapter = ComputeSessionsAdapter() @@ -559,7 +559,7 @@ async def test_pagination_info_is_correct( ) -> None: """Pagination info should reflect the session search result.""" session_result = await mock_processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=MagicMock()) + SearchSessionsAction(querier=MagicMock(), user_id=uuid4()) ) assert session_result.total_count == 2 diff --git a/tests/unit/manager/services/session/test_session_lifecycle_service.py b/tests/unit/manager/services/session/test_session_lifecycle_service.py index a7b2b6434e0..0dc2f898f65 100644 --- a/tests/unit/manager/services/session/test_session_lifecycle_service.py +++ b/tests/unit/manager/services/session/test_session_lifecycle_service.py @@ -1237,6 +1237,7 @@ async def test_prefix_matching_returns_sessions( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1249,12 +1250,14 @@ async def test_no_match_returns_empty( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.match_sessions = AsyncMock(return_value=[]) action = MatchSessionsAction( id_or_name_prefix="nonexistent", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1265,12 +1268,14 @@ async def test_owner_access_key_filtering( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: mock_session_repository.match_sessions = AsyncMock(return_value=[]) action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) await session_service.match_sessions(action) diff --git a/tests/unit/manager/services/session/test_session_service.py b/tests/unit/manager/services/session/test_session_service.py index a447b764d28..78680a21f19 100644 --- a/tests/unit/manager/services/session/test_session_service.py +++ b/tests/unit/manager/services/session/test_session_service.py @@ -289,6 +289,7 @@ async def test_success( mock_session_repository: MagicMock, sample_session_data: SessionData, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test successfully matching sessions""" mock_session_repository.match_sessions = AsyncMock(return_value=[sample_session_data]) @@ -296,6 +297,7 @@ async def test_success( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -310,6 +312,7 @@ async def test_no_matches( session_service: SessionService, mock_session_repository: MagicMock, sample_access_key: AccessKey, + sample_user_id: UUID, ) -> None: """Test matching sessions when none found""" mock_session_repository.match_sessions = AsyncMock(return_value=[]) @@ -317,6 +320,7 @@ async def test_no_matches( action = MatchSessionsAction( id_or_name_prefix="nonexistent", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -382,6 +386,7 @@ async def test_multiple_matches( action = MatchSessionsAction( id_or_name_prefix="test", owner_access_key=sample_access_key, + user_id=sample_user_id, ) result = await session_service.match_sessions(action) @@ -1600,6 +1605,7 @@ async def test_search_sessions( session_service: SessionService, mock_session_repository: MagicMock, sample_session_data: SessionData, + sample_user_id: UUID, ) -> None: """Test searching sessions with querier""" mock_session_repository.search = AsyncMock( @@ -1616,7 +1622,7 @@ async def test_search_sessions( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.data == [sample_session_data] @@ -1629,6 +1635,7 @@ async def test_search_sessions_empty_result( self, session_service: SessionService, mock_session_repository: MagicMock, + sample_user_id: UUID, ) -> None: """Test searching sessions when no results are found""" mock_session_repository.search = AsyncMock( @@ -1645,7 +1652,7 @@ async def test_search_sessions_empty_result( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.data == [] @@ -1656,6 +1663,7 @@ async def test_search_sessions_with_pagination( session_service: SessionService, mock_session_repository: MagicMock, sample_session_data: SessionData, + sample_user_id: UUID, ) -> None: """Test searching sessions with pagination""" mock_session_repository.search = AsyncMock( @@ -1672,7 +1680,7 @@ async def test_search_sessions_with_pagination( conditions=[], orders=[], ) - action = SearchSessionsAction(querier=querier) + action = SearchSessionsAction(querier=querier, user_id=sample_user_id) result = await session_service.search(action) assert result.total_count == 25 @@ -1783,6 +1791,7 @@ async def test_search_kernels( session_service: SessionService, mock_session_repository: MagicMock, sample_kernel_info: KernelInfo, + sample_user_id: UUID, ) -> None: """Test searching kernels with querier""" mock_session_repository.search_kernels = AsyncMock( @@ -1799,7 +1808,7 @@ async def test_search_kernels( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.data == [sample_kernel_info] @@ -1812,6 +1821,7 @@ async def test_search_kernels_empty_result( self, session_service: SessionService, mock_session_repository: MagicMock, + sample_user_id: UUID, ) -> None: """Test searching kernels when no results are found""" mock_session_repository.search_kernels = AsyncMock( @@ -1828,7 +1838,7 @@ async def test_search_kernels_empty_result( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.data == [] @@ -1839,6 +1849,7 @@ async def test_search_kernels_with_pagination( session_service: SessionService, mock_session_repository: MagicMock, sample_kernel_info: KernelInfo, + sample_user_id: UUID, ) -> None: """Test searching kernels with pagination""" mock_session_repository.search_kernels = AsyncMock( @@ -1855,7 +1866,7 @@ async def test_search_kernels_with_pagination( conditions=[], orders=[], ) - action = SearchKernelsAction(querier=querier) + action = SearchKernelsAction(querier=querier, user_id=sample_user_id) result = await session_service.search_kernels(action) assert result.total_count == 25 From ddab65a1496449c6cea3eb373129eebfced92d88 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 20 Mar 2026 14:50:10 +0900 Subject: [PATCH 10/10] fix(BA-2946): set up user context in dataloader tests for RBAC Co-Authored-By: Claude Opus 4.6 --- .../manager/api/session/test_dataloader.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/unit/manager/api/session/test_dataloader.py b/tests/unit/manager/api/session/test_dataloader.py index 456b7f2add0..2a9f4389dbf 100644 --- a/tests/unit/manager/api/session/test_dataloader.py +++ b/tests/unit/manager/api/session/test_dataloader.py @@ -5,11 +5,24 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.contexts.user import with_user +from ai.backend.common.data.user.types import UserData, UserRole from ai.backend.common.types import KernelId from ai.backend.manager.api.gql.data_loader.kernel.loader import load_kernels_by_ids from ai.backend.manager.data.kernel.types import KernelInfo +def _make_user_data() -> UserData: + return UserData( + user_id=uuid.uuid4(), + is_authorized=True, + is_admin=False, + is_superadmin=False, + role=UserRole.USER, + domain_name="default", + ) + + class TestLoadKernelsByIds: """Tests for load_kernels_by_ids function.""" @@ -47,7 +60,8 @@ async def test_returns_kernels_in_request_order(self) -> None: ) # When - result = await load_kernels_by_ids(mock_processor, [id1, id2, id3]) + with with_user(_make_user_data()): + result = await load_kernels_by_ids(mock_processor, [id1, id2, id3]) # Then assert result == [kernel1, kernel2, kernel3] @@ -60,7 +74,8 @@ async def test_returns_none_for_missing_ids(self) -> None: mock_processor = self.create_mock_processor([existing_kernel]) # When - result = await load_kernels_by_ids(mock_processor, [existing_id, missing_id]) + with with_user(_make_user_data()): + result = await load_kernels_by_ids(mock_processor, [existing_id, missing_id]) # Then assert result == [existing_kernel, None]