Skip to content

Commit 1f436d5

Browse files
fregataaclaude
andauthored
feat(BA-4865): Apply RBAC validator infrastructure to Session actions (#9624)
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 2c62d8a commit 1f436d5

20 files changed

Lines changed: 329 additions & 45 deletions

File tree

changes/9624.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add RBAC validator infrastructure to Session actions following BEP-1048 patterns

src/ai/backend/manager/api/gql/data_loader/kernel/loader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from collections.abc import Sequence
44

5+
from ai.backend.common.contexts.user import current_user
56
from ai.backend.common.types import KernelId
67
from ai.backend.manager.data.kernel.types import KernelInfo
8+
from ai.backend.manager.errors.user import UserNotFound
79
from ai.backend.manager.repositories.base import BatchQuerier, NoPagination
810
from ai.backend.manager.repositories.scheduler.options import KernelConditions
911
from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction
@@ -26,13 +28,17 @@ async def load_kernels_by_ids(
2628
if not kernel_ids:
2729
return []
2830

31+
user = current_user()
32+
if user is None:
33+
raise UserNotFound("User not found in context")
34+
2935
querier = BatchQuerier(
3036
pagination=NoPagination(),
3137
conditions=[KernelConditions.by_ids(kernel_ids)],
3238
)
3339

3440
action_result = await processor.search_kernels.wait_for_complete(
35-
SearchKernelsAction(querier=querier)
41+
SearchKernelsAction(querier=querier, user_id=user.user_id)
3642
)
3743

3844
kernel_map: dict[KernelId, KernelInfo] = {kernel.id: kernel for kernel in action_result.data}

src/ai/backend/manager/api/gql/data_loader/session/loader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from collections.abc import Sequence
44

5+
from ai.backend.common.contexts.user import current_user
56
from ai.backend.common.types import SessionId
67
from ai.backend.manager.data.session.types import SessionData
8+
from ai.backend.manager.errors.user import UserNotFound
79
from ai.backend.manager.repositories.base import BatchQuerier, NoPagination
810
from ai.backend.manager.repositories.scheduler.options import SessionConditions
911
from ai.backend.manager.services.session.actions.search import SearchSessionsAction
@@ -17,13 +19,17 @@ async def load_sessions_by_ids(
1719
if not session_ids:
1820
return []
1921

22+
user = current_user()
23+
if user is None:
24+
raise UserNotFound("User not found in context")
25+
2026
querier = BatchQuerier(
2127
pagination=NoPagination(),
2228
conditions=[SessionConditions.by_ids(session_ids)],
2329
)
2430

2531
action_result = await processor.search_sessions.wait_for_complete(
26-
SearchSessionsAction(querier=querier)
32+
SearchSessionsAction(querier=querier, user_id=user.user_id)
2733
)
2834

2935
session_map: dict[SessionId, SessionData] = {

src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import strawberry
66
from strawberry import Info
77

8+
from ai.backend.common.contexts.user import current_user
89
from ai.backend.common.types import KernelId
910
from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec
1011
from ai.backend.manager.api.gql.base import encode_cursor
@@ -16,6 +17,7 @@
1617
KernelV2OrderByGQL,
1718
)
1819
from ai.backend.manager.api.gql.types import StrawberryGQLContext
20+
from ai.backend.manager.errors.user import UserNotFound
1921
from ai.backend.manager.models.kernel import KernelRow
2022
from ai.backend.manager.repositories.base import QueryCondition
2123
from ai.backend.manager.repositories.scheduler.options import KernelConditions
@@ -45,6 +47,10 @@ async def fetch_kernels(
4547
offset: int | None = None,
4648
base_conditions: list[QueryCondition] | None = None,
4749
) -> KernelV2ConnectionGQL:
50+
user = current_user()
51+
if user is None:
52+
raise UserNotFound("User not found in context")
53+
4854
querier = info.context.gql_adapter.build_querier(
4955
PaginationOptions(
5056
first=first,
@@ -61,7 +67,7 @@ async def fetch_kernels(
6167
)
6268

6369
action_result = await info.context.processors.session.search_kernels.wait_for_complete(
64-
SearchKernelsAction(querier=querier)
70+
SearchKernelsAction(querier=querier, user_id=user.user_id)
6571
)
6672
nodes = [KernelV2GQL.from_kernel_info(kernel_info) for kernel_info in action_result.data]
6773
edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes]

src/ai/backend/manager/api/gql/session/fetcher/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import strawberry
66
from strawberry import Info
77

8+
from ai.backend.common.contexts.user import current_user
89
from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec
910
from ai.backend.manager.api.gql.base import encode_cursor
1011
from ai.backend.manager.api.gql.session.types import (
@@ -15,6 +16,7 @@
1516
SessionV2OrderByGQL,
1617
)
1718
from ai.backend.manager.api.gql.types import StrawberryGQLContext
19+
from ai.backend.manager.errors.user import UserNotFound
1820
from ai.backend.manager.models.session import SessionRow
1921
from ai.backend.manager.repositories.base import QueryCondition
2022
from ai.backend.manager.repositories.scheduler.options import SessionConditions, SessionOrders
@@ -44,6 +46,10 @@ async def fetch_sessions(
4446
offset: int | None = None,
4547
base_conditions: list[QueryCondition] | None = None,
4648
) -> SessionV2ConnectionGQL:
49+
user = current_user()
50+
if user is None:
51+
raise UserNotFound("User not found in context")
52+
4753
querier = info.context.gql_adapter.build_querier(
4854
PaginationOptions(
4955
first=first,
@@ -60,7 +66,7 @@ async def fetch_sessions(
6066
)
6167

6268
action_result = await info.context.processors.session.search_sessions.wait_for_complete(
63-
SearchSessionsAction(querier=querier)
69+
SearchSessionsAction(querier=querier, user_id=user.user_id)
6470
)
6571

6672
nodes = [SessionV2GQL.from_data(session_data) for session_data in action_result.data]

src/ai/backend/manager/api/gql/session/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from strawberry import ID, Info
1313
from strawberry.relay import Connection, Edge, Node, NodeID
1414

15+
from ai.backend.common.contexts.user import current_user
1516
from ai.backend.common.types import SessionId
1617
from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, UUIDFilter, encode_cursor
1718
from ai.backend.manager.api.gql.common.types import (
@@ -41,6 +42,7 @@
4142
from ai.backend.manager.api.gql.types import GQLFilter, GQLOrderBy, StrawberryGQLContext
4243
from ai.backend.manager.api.gql.user.types.node import UserV2GQL
4344
from ai.backend.manager.data.session.types import SessionData, SessionStatus
45+
from ai.backend.manager.errors.user import UserNotFound
4446
from ai.backend.manager.repositories.base import (
4547
BatchQuerier,
4648
NoPagination,
@@ -468,13 +470,17 @@ async def images(self) -> ImageV2ConnectionGQL:
468470
description="Added in 26.3.0. The kernels belonging to this session."
469471
)
470472
async def kernels(self, info: Info[StrawberryGQLContext]) -> KernelV2ConnectionGQL:
473+
user = current_user()
474+
if user is None:
475+
raise UserNotFound("User not found in context")
476+
471477
session_id = SessionId(UUID(str(self.id)))
472478
querier = BatchQuerier(
473479
pagination=NoPagination(),
474480
conditions=[KernelConditions.by_session_ids([session_id])],
475481
)
476482
action_result = await info.context.processors.session.search_kernels.wait_for_complete(
477-
SearchKernelsAction(querier=querier)
483+
SearchKernelsAction(querier=querier, user_id=user.user_id)
478484
)
479485
nodes = [KernelV2GQL.from_kernel_info(kernel) for kernel in action_result.data]
480486
edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes]

src/ai/backend/manager/api/rest/compute_sessions/handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Final
88

99
from ai.backend.common.api_handlers import APIResponse, BodyParam
10+
from ai.backend.common.contexts.user import current_user
1011
from ai.backend.common.dto.manager.compute_session import (
1112
PaginationInfo,
1213
SearchComputeSessionsRequest,
@@ -15,6 +16,7 @@
1516
from ai.backend.common.types import SessionId
1617
from ai.backend.logging import BraceStyleAdapter
1718
from ai.backend.manager.dto.context import UserContext
19+
from ai.backend.manager.errors.user import UserNotFound
1820
from ai.backend.manager.services.session.actions.search import SearchSessionsAction
1921
from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction
2022
from ai.backend.manager.services.session.processors import SessionProcessors
@@ -39,10 +41,14 @@ async def search_sessions(
3941
"""Search compute sessions with nested container data."""
4042
log.info("SEARCH_SESSIONS (ak:{})", ctx.access_key)
4143

44+
user = current_user()
45+
if user is None:
46+
raise UserNotFound("User not found in context")
47+
4248
# Step 1: Search sessions
4349
session_querier = self._adapter.build_session_querier(body.parsed)
4450
session_result = await self._session.search_sessions.wait_for_complete(
45-
SearchSessionsAction(querier=session_querier)
51+
SearchSessionsAction(querier=session_querier, user_id=user.user_id)
4652
)
4753

4854
# Step 2: Fetch kernels for found sessions
@@ -51,7 +57,7 @@ async def search_sessions(
5157
if session_ids:
5258
kernel_querier = self._adapter.build_kernel_querier_for_sessions(session_ids)
5359
kernel_result = await self._session.search_kernels.wait_for_complete(
54-
SearchKernelsAction(querier=kernel_querier)
60+
SearchKernelsAction(querier=kernel_querier, user_id=user.user_id)
5561
)
5662
kernels_by_session = self._adapter.group_kernels_by_session(kernel_result.data)
5763

src/ai/backend/manager/api/rest/session/handler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pydantic import BaseModel
2121

2222
from ai.backend.common.api_handlers import APIResponse, BaseResponseModel, BodyParam, QueryParam
23+
from ai.backend.common.contexts.user import current_user
2324
from ai.backend.common.dto.manager.session.request import (
2425
CommitSessionRequest,
2526
CompleteRequest,
@@ -91,6 +92,7 @@
9192
from ai.backend.manager.errors.api import InvalidAPIParameters
9293
from ai.backend.manager.errors.auth import InsufficientPrivilege
9394
from ai.backend.manager.errors.resource import NoCurrentTaskContext
95+
from ai.backend.manager.errors.user import UserNotFound
9496
from ai.backend.manager.models.user import UserRole
9597
from ai.backend.manager.services.agent.actions.sync_agent_registry import (
9698
SyncAgentRegistryAction,
@@ -515,6 +517,9 @@ async def match_sessions(
515517
)
516518
)
517519
requester_access_key, owner_access_key = scope.requester_access_key, scope.owner_access_key
520+
user = current_user()
521+
if user is None:
522+
raise UserNotFound("User not found in context")
518523
log.info(
519524
"MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})",
520525
requester_access_key,
@@ -525,6 +530,7 @@ async def match_sessions(
525530
MatchSessionsAction(
526531
id_or_name_prefix=params.id,
527532
owner_access_key=owner_access_key,
533+
user_id=user.user_id,
528534
)
529535
)
530536
return APIResponse.build(HTTPStatus.OK, MatchSessionsResponse(matches=result.result))

src/ai/backend/manager/services/session/actions/create_cluster.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,23 @@
33
from dataclasses import dataclass
44
from typing import Any, override
55

6+
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
67
from ai.backend.common.types import AccessKey, SessionTypes
78
from ai.backend.manager.actions.action import BaseActionResult
89
from ai.backend.manager.actions.types import ActionOperationType
10+
from ai.backend.manager.data.permission.types import RBACElementRef
911
from ai.backend.manager.models.user import UserRole
10-
from ai.backend.manager.services.session.base import SessionAction
12+
from ai.backend.manager.services.session.base import SessionScopeAction
1113

1214

1315
@dataclass
14-
class CreateClusterAction(SessionAction):
16+
class CreateClusterAction(SessionScopeAction):
17+
"""Create a new cluster session.
18+
19+
RBAC validation checks if the user has CREATE permission in USER scope.
20+
Scope is always USER scope with user_id.
21+
"""
22+
1523
session_name: str
1624
user_id: uuid.UUID
1725
user_role: UserRole
@@ -37,6 +45,21 @@ def entity_id(self) -> str | None:
3745
def operation_type(cls) -> ActionOperationType:
3846
return ActionOperationType.CREATE
3947

48+
@override
49+
def scope_type(self) -> ScopeType:
50+
return ScopeType.USER
51+
52+
@override
53+
def scope_id(self) -> str:
54+
return str(self.user_id)
55+
56+
@override
57+
def target_element(self) -> RBACElementRef:
58+
return RBACElementRef(
59+
element_type=RBACElementType.USER,
60+
element_id=str(self.user_id),
61+
)
62+
4063

4164
@dataclass
4265
class CreateClusterActionResult(BaseActionResult):

src/ai/backend/manager/services/session/actions/create_from_params.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77
import yarl
88

9+
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
910
from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes
1011
from ai.backend.manager.actions.action import BaseActionResult
1112
from ai.backend.manager.actions.types import ActionOperationType
13+
from ai.backend.manager.data.permission.types import RBACElementRef
1214
from ai.backend.manager.models.user import UserRole
13-
from ai.backend.manager.services.session.base import SessionAction
15+
from ai.backend.manager.services.session.base import SessionScopeAction
1416

1517

1618
# TODO: Idea: Refactor this type using pydantic and utilize as API model
@@ -41,7 +43,13 @@ class CreateFromParamsActionParams:
4143

4244

4345
@dataclass
44-
class CreateFromParamsAction(SessionAction):
46+
class CreateFromParamsAction(SessionScopeAction):
47+
"""Create a new session from parameters.
48+
49+
RBAC validation checks if the user has CREATE permission in USER scope.
50+
Scope is always USER scope with user_id.
51+
"""
52+
4553
params: CreateFromParamsActionParams
4654
user_id: uuid.UUID
4755
user_role: UserRole
@@ -58,6 +66,21 @@ def entity_id(self) -> str | None:
5866
def operation_type(cls) -> ActionOperationType:
5967
return ActionOperationType.CREATE
6068

69+
@override
70+
def scope_type(self) -> ScopeType:
71+
return ScopeType.USER
72+
73+
@override
74+
def scope_id(self) -> str:
75+
return str(self.user_id)
76+
77+
@override
78+
def target_element(self) -> RBACElementRef:
79+
return RBACElementRef(
80+
element_type=RBACElementType.USER,
81+
element_id=str(self.user_id),
82+
)
83+
6184

6285
@dataclass
6386
class CreateFromParamsActionResult(BaseActionResult):

0 commit comments

Comments
 (0)