Skip to content

Commit e20625f

Browse files
fregataaclaude
andauthored
feat(BA-5776): implement bulk permission check query (#11189)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 75b1ecd commit e20625f

6 files changed

Lines changed: 903 additions & 150 deletions

File tree

changes/11189.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add bulk scope-chain permission check query for validating multiple entities in a single DB round-trip.

src/ai/backend/manager/data/permission/role.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from .permission import ScopedPermissionCreateInput
1515
from .status import RoleStatus
16-
from .types import EntityType, OperationType, RBACElementRef, RoleSource
16+
from .types import EntityType, OperationType, RBACElementRef, RBACElementType, RoleSource
1717

1818

1919
@dataclass(frozen=True)
@@ -94,6 +94,14 @@ class ScopeChainPermissionCheckInput:
9494
permission_entity_type: EntityType | None
9595

9696

97+
@dataclass(frozen=True)
98+
class BulkPermissionCheckInput:
99+
user_id: uuid.UUID
100+
target_element_type: RBACElementType
101+
target_entity_ids: list[str]
102+
operation: OperationType
103+
104+
97105
@dataclass(frozen=True)
98106
class UserRoleAssignmentInput:
99107
"""

src/ai/backend/manager/repositories/permission_controller/db_source/db_source.py

Lines changed: 86 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33
from collections.abc import Collection, Iterable, Sequence
44
from dataclasses import dataclass, field
5-
from typing import Any, cast
5+
from typing import Any
66

77
import sqlalchemy as sa
88
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -29,6 +29,7 @@
2929
from ai.backend.manager.data.permission.role import (
3030
AssignedUserData,
3131
AssignedUserListResult,
32+
BulkPermissionCheckInput,
3233
BulkRoleRevocationFailure,
3334
BulkRoleRevocationResultData,
3435
BulkUserRoleRevocationInput,
@@ -101,6 +102,15 @@ class CreateRoleInput:
101102
scope_refs: Sequence[RBACElementRef] = field(default_factory=list)
102103

103104

105+
@dataclass(frozen=True)
106+
class _ScopeChainQueryParams:
107+
user_id: uuid.UUID
108+
target_element_type: RBACElementType
109+
entity_ids: list[str]
110+
operation: OperationType
111+
permission_entity_type: EntityType | None = None
112+
113+
104114
class PermissionDBSource:
105115
_db: ExtendedAsyncSAEngine
106116

@@ -515,17 +525,6 @@ async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]:
515525
result = await db_session.scalars(stmt)
516526
return list(result.all())
517527

518-
async def get_entity_mapped_scopes(
519-
self, target_object_id: ObjectId
520-
) -> list[AssociationScopesEntitiesRow]:
521-
async with self._db.begin_readonly_session_read_committed() as db_session:
522-
stmt = sa.select(AssociationScopesEntitiesRow).where(
523-
AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id,
524-
AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value,
525-
)
526-
result = await db_session.scalars(stmt)
527-
return list(result.all())
528-
529528
async def check_scope_permission_exist(
530529
self,
531530
user_id: uuid.UUID,
@@ -556,42 +555,6 @@ async def check_scope_permission_exist(
556555
result = await db_session.scalar(role_query)
557556
return result or False
558557

559-
def _make_query_statement_for_object_permission(
560-
self,
561-
user_id: uuid.UUID,
562-
object_ids: Iterable[ObjectId],
563-
) -> sa.sql.Select[Any]:
564-
object_id_for_cond = [obj_id.entity_id for obj_id in object_ids]
565-
return (
566-
sa.select(RoleRow)
567-
.select_from(
568-
sa.join(RoleRow, UserRoleRow, RoleRow.id == UserRoleRow.role_id)
569-
.join(PermissionRow, RoleRow.id == PermissionRow.role_id)
570-
.join(
571-
AssociationScopesEntitiesRow,
572-
sa.and_(
573-
PermissionRow.scope_id == AssociationScopesEntitiesRow.scope_id,
574-
PermissionRow.scope_type == AssociationScopesEntitiesRow.scope_type,
575-
),
576-
)
577-
.join(ObjectPermissionRow, RoleRow.id == ObjectPermissionRow.role_id)
578-
)
579-
.where(
580-
sa.and_(
581-
RoleRow.status == RoleStatus.ACTIVE,
582-
UserRoleRow.user_id == user_id,
583-
sa.or_(
584-
PermissionRow.scope_type == ScopeType.GLOBAL,
585-
AssociationScopesEntitiesRow.entity_id.in_(object_id_for_cond),
586-
ObjectPermissionRow.entity_id.in_(object_id_for_cond),
587-
),
588-
)
589-
)
590-
.options(
591-
contains_eager(RoleRow.object_permission_rows),
592-
)
593-
)
594-
595558
def _make_query_statement_for_object_permissions(
596559
self,
597560
user_id: uuid.UUID,
@@ -639,20 +602,6 @@ def _make_query_statement_for_object_permissions(
639602
)
640603
)
641604

642-
async def check_object_permission_exist(
643-
self,
644-
user_id: uuid.UUID,
645-
object_id: ObjectId,
646-
operation: OperationType,
647-
) -> bool:
648-
role_query = self._make_query_statement_for_object_permissions(
649-
user_id, [object_id], operation
650-
)
651-
async with self._db.begin_readonly_session_read_committed() as db_session:
652-
result = await db_session.scalars(role_query)
653-
role_rows = cast(list[RoleRow], result.unique().all())
654-
return len(role_rows) > 0
655-
656605
async def check_batch_object_permission_exist(
657606
self,
658607
user_id: uuid.UUID,
@@ -943,36 +892,35 @@ async def search_element_associations_in_scope(
943892

944893
@staticmethod
945894
def _build_scope_chain_cte(
946-
target_element_ref: RBACElementRef,
895+
target_entity_type: EntityType,
896+
entity_ids: list[str],
947897
) -> sa.CTE:
948-
"""Build a recursive CTE that walks the scope chain upward via AUTO edges only.
898+
"""Build a recursive CTE that walks the scope chain upward via AUTO edges.
949899
950-
Starting from the target entity, traverses association_scopes_entities
951-
following only AUTO edges to find all ancestor scopes. REF edges
952-
terminate the chain — scopes beyond a REF edge are unreachable.
953-
954-
Uses UNION (not UNION ALL) to prevent infinite recursion on cycles.
900+
Carries entity_id through the recursion so each result row can be
901+
traced back to its originating entity.
955902
"""
956903
ase = AssociationScopesEntitiesRow.__table__
957-
target_entity_type = target_element_ref.element_type.to_entity_type()
958904

959-
# Base case: direct AUTO scope entries for the target entity.
905+
# Base case: direct AUTO scope entries for target entities.
960906
scope_chain_base = sa.select(
907+
ase.c.entity_id,
961908
ase.c.scope_type,
962909
ase.c.scope_id,
963910
).where(
964911
sa.and_(
965912
ase.c.entity_type == target_entity_type,
966-
ase.c.entity_id == target_element_ref.element_id,
913+
ase.c.entity_id.in_(entity_ids),
967914
ase.c.relation_type == RelationType.AUTO,
968915
)
969916
)
970917
scope_chain_cte = scope_chain_base.cte("scope_chain", recursive=True)
971918

972-
# Recursive case: walk parent scopes upward, following AUTO edges only.
919+
# Recursive case: walk parent scopes upward, carrying entity_id.
973920
parent = ase.alias("parent")
974921
scope_chain_recursive = (
975922
sa.select(
923+
scope_chain_cte.c.entity_id,
976924
parent.c.scope_type,
977925
parent.c.scope_id,
978926
)
@@ -991,21 +939,30 @@ def _build_scope_chain_cte(
991939
)
992940
return scope_chain_cte.union(scope_chain_recursive)
993941

994-
def _build_scope_chain_permission_query(
942+
async def _check_permissions_via_scope_chain(
995943
self,
996-
user_id: uuid.UUID,
997-
target_element_ref: RBACElementRef,
998-
target_entity_type: EntityType,
999-
operation: OperationType,
1000-
) -> sa.sql.Select[Any]:
1001-
"""Build a query that checks permissions via CTE scope chain traversal."""
944+
params: _ScopeChainQueryParams,
945+
) -> set[str]:
946+
"""Core scope chain permission check shared by single and batch methods.
947+
948+
Two-layer check:
949+
1. Scope chain traversal — walks AUTO edges upward via recursive CTE.
950+
2. Self-scope direct match — permission scoped to the target entity itself.
951+
952+
Returns the set of entity IDs that have the requested permission.
953+
"""
954+
association_entity_type = params.target_element_type.to_entity_type()
955+
permission_entity_type = params.permission_entity_type or association_entity_type
956+
target_scope_type = params.target_element_type.to_scope_type()
957+
1002958
permissions = PermissionRow.__table__
1003959
user_roles = UserRoleRow.__table__
1004960
roles = RoleRow.__table__
1005961

1006-
scope_chain_cte = self._build_scope_chain_cte(target_element_ref)
1007-
return (
1008-
sa.select(sa.literal(1))
962+
# Layer 1: scope chain traversal.
963+
scope_chain_cte = self._build_scope_chain_cte(association_entity_type, params.entity_ids)
964+
scope_chain_query = (
965+
sa.select(scope_chain_cte.c.entity_id)
1009966
.select_from(
1010967
scope_chain_cte.join(
1011968
permissions,
@@ -1025,30 +982,17 @@ def _build_scope_chain_permission_query(
1025982
)
1026983
.where(
1027984
sa.and_(
1028-
user_roles.c.user_id == user_id,
985+
user_roles.c.user_id == params.user_id,
1029986
roles.c.status == RoleStatus.ACTIVE,
1030-
permissions.c.entity_type == target_entity_type,
1031-
permissions.c.operation == operation,
987+
permissions.c.entity_type == permission_entity_type,
988+
permissions.c.operation == params.operation,
1032989
)
1033990
)
1034-
.limit(1)
1035991
)
1036992

1037-
def _build_self_scope_permission_query(
1038-
self,
1039-
user_id: uuid.UUID,
1040-
target_element_ref: RBACElementRef,
1041-
target_entity_type: EntityType,
1042-
target_scope_type: ScopeType,
1043-
operation: OperationType,
1044-
) -> sa.sql.Select[Any]:
1045-
"""Build a query that checks permissions scoped to the target entity itself."""
1046-
permissions = PermissionRow.__table__
1047-
user_roles = UserRoleRow.__table__
1048-
roles = RoleRow.__table__
1049-
1050-
return (
1051-
sa.select(sa.literal(1))
993+
# Layer 2: self-scope direct match.
994+
self_scope_query = (
995+
sa.select(permissions.c.scope_id.label("entity_id"))
1052996
.select_from(
1053997
permissions.join(
1054998
roles,
@@ -1060,66 +1004,59 @@ def _build_self_scope_permission_query(
10601004
)
10611005
.where(
10621006
sa.and_(
1063-
user_roles.c.user_id == user_id,
1007+
user_roles.c.user_id == params.user_id,
10641008
roles.c.status == RoleStatus.ACTIVE,
10651009
permissions.c.scope_type == target_scope_type,
1066-
permissions.c.scope_id == target_element_ref.element_id,
1067-
permissions.c.entity_type == target_entity_type,
1068-
permissions.c.operation == operation,
1010+
permissions.c.scope_id.in_(params.entity_ids),
1011+
permissions.c.entity_type == permission_entity_type,
1012+
permissions.c.operation == params.operation,
10691013
)
10701014
)
1071-
.limit(1)
10721015
)
10731016

1017+
combined_query = sa.union(scope_chain_query, self_scope_query)
1018+
1019+
granted: set[str] = set()
1020+
async with self._db.begin_readonly_session_read_committed() as db_session:
1021+
rows = await db_session.execute(combined_query)
1022+
for row in rows:
1023+
granted.add(row.entity_id)
1024+
1025+
return granted
1026+
10741027
async def check_permission_with_scope_chain(
10751028
self,
10761029
data: ScopeChainPermissionCheckInput,
10771030
) -> bool:
1078-
"""CTE-based permission check that traverses the scope chain.
1079-
1080-
Two-layer check:
1081-
1. Self-scope direct match — permission scoped to the target entity itself.
1082-
2. Scope chain traversal — walks AUTO edges upward via CTE.
1083-
1084-
Args:
1085-
data: Permission check input containing user_id, target_element_ref,
1086-
operation, and optional permission_entity_type override.
1087-
When permission_entity_type is provided, it is used as the
1088-
entity_type filter for permission matching instead of deriving
1089-
it from target_element_ref. This enables cross-scope entity type
1090-
checks (e.g., checking MODEL_DEPLOYMENT:READ permission at
1091-
PROJECT scope).
1092-
"""
1093-
target_entity_type = (
1094-
data.permission_entity_type or data.target_element_ref.element_type.to_entity_type()
1095-
)
1096-
target_scope_type = data.target_element_ref.element_type.to_scope_type()
1097-
1098-
combined_query = sa.select(
1099-
sa.or_(
1100-
sa.exists(
1101-
self._build_scope_chain_permission_query(
1102-
data.user_id,
1103-
data.target_element_ref,
1104-
target_entity_type,
1105-
data.operation,
1106-
)
1107-
),
1108-
sa.exists(
1109-
self._build_self_scope_permission_query(
1110-
data.user_id,
1111-
data.target_element_ref,
1112-
target_entity_type,
1113-
target_scope_type,
1114-
data.operation,
1115-
)
1116-
),
1031+
"""CTE-based permission check for a single entity."""
1032+
granted = await self._check_permissions_via_scope_chain(
1033+
_ScopeChainQueryParams(
1034+
user_id=data.user_id,
1035+
target_element_type=data.target_element_ref.element_type,
1036+
entity_ids=[data.target_element_ref.element_id],
1037+
operation=data.operation,
1038+
permission_entity_type=data.permission_entity_type,
11171039
)
11181040
)
1041+
return data.target_element_ref.element_id in granted
11191042

1120-
async with self._db.begin_readonly_session_read_committed() as db_session:
1121-
result = await db_session.scalar(combined_query)
1122-
return result or False
1043+
async def check_bulk_permission_with_scope_chain(
1044+
self,
1045+
data: BulkPermissionCheckInput,
1046+
) -> dict[str, bool]:
1047+
"""Batch CTE-based permission check for multiple entities."""
1048+
if not data.target_entity_ids:
1049+
return {}
1050+
1051+
granted = await self._check_permissions_via_scope_chain(
1052+
_ScopeChainQueryParams(
1053+
user_id=data.user_id,
1054+
target_element_type=data.target_element_type,
1055+
entity_ids=data.target_entity_ids,
1056+
operation=data.operation,
1057+
)
1058+
)
1059+
return {eid: eid in granted for eid in data.target_entity_ids}
11231060

11241061
async def bulk_assign_role(
11251062
self, bulk_creator: BulkCreator[UserRoleRow]

src/ai/backend/manager/repositories/permission_controller/repository.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ai.backend.manager.data.permission.role import (
2020
AssignedUserListResult,
2121
BatchEntityPermissionCheckInput,
22+
BulkPermissionCheckInput,
2223
BulkRoleAssignmentFailure,
2324
BulkRoleAssignmentResultData,
2425
BulkRoleRevocationResultData,
@@ -336,3 +337,15 @@ async def check_permission_with_scope_chain(
336337
scope. REF edges are not traversed.
337338
"""
338339
return await self._db_source.check_permission_with_scope_chain(data)
340+
341+
@permission_controller_repository_resilience.apply()
342+
async def check_bulk_permission_with_scope_chain(
343+
self,
344+
data: BulkPermissionCheckInput,
345+
) -> dict[str, bool]:
346+
"""Batch permission check that traverses the scope chain via AUTO edges.
347+
348+
Same semantics as check_permission_with_scope_chain but for multiple
349+
entities of the same RBACElementType in a single query.
350+
"""
351+
return await self._db_source.check_bulk_permission_with_scope_chain(data)

0 commit comments

Comments
 (0)