22import uuid
33from collections .abc import Collection , Iterable , Sequence
44from dataclasses import dataclass , field
5- from typing import Any , cast
5+ from typing import Any
66
77import sqlalchemy as sa
88from sqlalchemy .dialects .postgresql import insert as pg_insert
2929from ai .backend .manager .data .permission .role import (
3030 AssignedUserData ,
3131 AssignedUserListResult ,
32+ BulkPermissionCheckInput ,
3233 BulkRoleRevocationFailure ,
3334 BulkRoleRevocationResultData ,
3435 BulkUserRoleRevocationInput ,
@@ -101,6 +102,15 @@ class CreateRoleInput:
101102 scope_refs : Sequence [RBACElementRef ] = field (default_factory = list )
102103
103104
105+ @dataclass (frozen = True )
106+ class _ScopeChainQueryParams :
107+ user_id : uuid .UUID
108+ target_element_type : RBACElementType
109+ entity_ids : list [str ]
110+ operation : OperationType
111+ permission_entity_type : EntityType | None = None
112+
113+
104114class PermissionDBSource :
105115 _db : ExtendedAsyncSAEngine
106116
@@ -515,17 +525,6 @@ async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]:
515525 result = await db_session .scalars (stmt )
516526 return list (result .all ())
517527
518- async def get_entity_mapped_scopes (
519- self , target_object_id : ObjectId
520- ) -> list [AssociationScopesEntitiesRow ]:
521- async with self ._db .begin_readonly_session_read_committed () as db_session :
522- stmt = sa .select (AssociationScopesEntitiesRow ).where (
523- AssociationScopesEntitiesRow .entity_id == target_object_id .entity_id ,
524- AssociationScopesEntitiesRow .entity_type == target_object_id .entity_type .value ,
525- )
526- result = await db_session .scalars (stmt )
527- return list (result .all ())
528-
529528 async def check_scope_permission_exist (
530529 self ,
531530 user_id : uuid .UUID ,
@@ -556,42 +555,6 @@ async def check_scope_permission_exist(
556555 result = await db_session .scalar (role_query )
557556 return result or False
558557
559- def _make_query_statement_for_object_permission (
560- self ,
561- user_id : uuid .UUID ,
562- object_ids : Iterable [ObjectId ],
563- ) -> sa .sql .Select [Any ]:
564- object_id_for_cond = [obj_id .entity_id for obj_id in object_ids ]
565- return (
566- sa .select (RoleRow )
567- .select_from (
568- sa .join (RoleRow , UserRoleRow , RoleRow .id == UserRoleRow .role_id )
569- .join (PermissionRow , RoleRow .id == PermissionRow .role_id )
570- .join (
571- AssociationScopesEntitiesRow ,
572- sa .and_ (
573- PermissionRow .scope_id == AssociationScopesEntitiesRow .scope_id ,
574- PermissionRow .scope_type == AssociationScopesEntitiesRow .scope_type ,
575- ),
576- )
577- .join (ObjectPermissionRow , RoleRow .id == ObjectPermissionRow .role_id )
578- )
579- .where (
580- sa .and_ (
581- RoleRow .status == RoleStatus .ACTIVE ,
582- UserRoleRow .user_id == user_id ,
583- sa .or_ (
584- PermissionRow .scope_type == ScopeType .GLOBAL ,
585- AssociationScopesEntitiesRow .entity_id .in_ (object_id_for_cond ),
586- ObjectPermissionRow .entity_id .in_ (object_id_for_cond ),
587- ),
588- )
589- )
590- .options (
591- contains_eager (RoleRow .object_permission_rows ),
592- )
593- )
594-
595558 def _make_query_statement_for_object_permissions (
596559 self ,
597560 user_id : uuid .UUID ,
@@ -639,20 +602,6 @@ def _make_query_statement_for_object_permissions(
639602 )
640603 )
641604
642- async def check_object_permission_exist (
643- self ,
644- user_id : uuid .UUID ,
645- object_id : ObjectId ,
646- operation : OperationType ,
647- ) -> bool :
648- role_query = self ._make_query_statement_for_object_permissions (
649- user_id , [object_id ], operation
650- )
651- async with self ._db .begin_readonly_session_read_committed () as db_session :
652- result = await db_session .scalars (role_query )
653- role_rows = cast (list [RoleRow ], result .unique ().all ())
654- return len (role_rows ) > 0
655-
656605 async def check_batch_object_permission_exist (
657606 self ,
658607 user_id : uuid .UUID ,
@@ -943,36 +892,35 @@ async def search_element_associations_in_scope(
943892
944893 @staticmethod
945894 def _build_scope_chain_cte (
946- target_element_ref : RBACElementRef ,
895+ target_entity_type : EntityType ,
896+ entity_ids : list [str ],
947897 ) -> sa .CTE :
948- """Build a recursive CTE that walks the scope chain upward via AUTO edges only .
898+ """Build a recursive CTE that walks the scope chain upward via AUTO edges.
949899
950- Starting from the target entity, traverses association_scopes_entities
951- following only AUTO edges to find all ancestor scopes. REF edges
952- terminate the chain — scopes beyond a REF edge are unreachable.
953-
954- Uses UNION (not UNION ALL) to prevent infinite recursion on cycles.
900+ Carries entity_id through the recursion so each result row can be
901+ traced back to its originating entity.
955902 """
956903 ase = AssociationScopesEntitiesRow .__table__
957- target_entity_type = target_element_ref .element_type .to_entity_type ()
958904
959- # Base case: direct AUTO scope entries for the target entity .
905+ # Base case: direct AUTO scope entries for target entities .
960906 scope_chain_base = sa .select (
907+ ase .c .entity_id ,
961908 ase .c .scope_type ,
962909 ase .c .scope_id ,
963910 ).where (
964911 sa .and_ (
965912 ase .c .entity_type == target_entity_type ,
966- ase .c .entity_id == target_element_ref . element_id ,
913+ ase .c .entity_id . in_ ( entity_ids ) ,
967914 ase .c .relation_type == RelationType .AUTO ,
968915 )
969916 )
970917 scope_chain_cte = scope_chain_base .cte ("scope_chain" , recursive = True )
971918
972- # Recursive case: walk parent scopes upward, following AUTO edges only .
919+ # Recursive case: walk parent scopes upward, carrying entity_id .
973920 parent = ase .alias ("parent" )
974921 scope_chain_recursive = (
975922 sa .select (
923+ scope_chain_cte .c .entity_id ,
976924 parent .c .scope_type ,
977925 parent .c .scope_id ,
978926 )
@@ -991,21 +939,30 @@ def _build_scope_chain_cte(
991939 )
992940 return scope_chain_cte .union (scope_chain_recursive )
993941
994- def _build_scope_chain_permission_query (
942+ async def _check_permissions_via_scope_chain (
995943 self ,
996- user_id : uuid .UUID ,
997- target_element_ref : RBACElementRef ,
998- target_entity_type : EntityType ,
999- operation : OperationType ,
1000- ) -> sa .sql .Select [Any ]:
1001- """Build a query that checks permissions via CTE scope chain traversal."""
944+ params : _ScopeChainQueryParams ,
945+ ) -> set [str ]:
946+ """Core scope chain permission check shared by single and batch methods.
947+
948+ Two-layer check:
949+ 1. Scope chain traversal — walks AUTO edges upward via recursive CTE.
950+ 2. Self-scope direct match — permission scoped to the target entity itself.
951+
952+ Returns the set of entity IDs that have the requested permission.
953+ """
954+ association_entity_type = params .target_element_type .to_entity_type ()
955+ permission_entity_type = params .permission_entity_type or association_entity_type
956+ target_scope_type = params .target_element_type .to_scope_type ()
957+
1002958 permissions = PermissionRow .__table__
1003959 user_roles = UserRoleRow .__table__
1004960 roles = RoleRow .__table__
1005961
1006- scope_chain_cte = self ._build_scope_chain_cte (target_element_ref )
1007- return (
1008- sa .select (sa .literal (1 ))
962+ # Layer 1: scope chain traversal.
963+ scope_chain_cte = self ._build_scope_chain_cte (association_entity_type , params .entity_ids )
964+ scope_chain_query = (
965+ sa .select (scope_chain_cte .c .entity_id )
1009966 .select_from (
1010967 scope_chain_cte .join (
1011968 permissions ,
@@ -1025,30 +982,17 @@ def _build_scope_chain_permission_query(
1025982 )
1026983 .where (
1027984 sa .and_ (
1028- user_roles .c .user_id == user_id ,
985+ user_roles .c .user_id == params . user_id ,
1029986 roles .c .status == RoleStatus .ACTIVE ,
1030- permissions .c .entity_type == target_entity_type ,
1031- permissions .c .operation == operation ,
987+ permissions .c .entity_type == permission_entity_type ,
988+ permissions .c .operation == params . operation ,
1032989 )
1033990 )
1034- .limit (1 )
1035991 )
1036992
1037- def _build_self_scope_permission_query (
1038- self ,
1039- user_id : uuid .UUID ,
1040- target_element_ref : RBACElementRef ,
1041- target_entity_type : EntityType ,
1042- target_scope_type : ScopeType ,
1043- operation : OperationType ,
1044- ) -> sa .sql .Select [Any ]:
1045- """Build a query that checks permissions scoped to the target entity itself."""
1046- permissions = PermissionRow .__table__
1047- user_roles = UserRoleRow .__table__
1048- roles = RoleRow .__table__
1049-
1050- return (
1051- sa .select (sa .literal (1 ))
993+ # Layer 2: self-scope direct match.
994+ self_scope_query = (
995+ sa .select (permissions .c .scope_id .label ("entity_id" ))
1052996 .select_from (
1053997 permissions .join (
1054998 roles ,
@@ -1060,66 +1004,59 @@ def _build_self_scope_permission_query(
10601004 )
10611005 .where (
10621006 sa .and_ (
1063- user_roles .c .user_id == user_id ,
1007+ user_roles .c .user_id == params . user_id ,
10641008 roles .c .status == RoleStatus .ACTIVE ,
10651009 permissions .c .scope_type == target_scope_type ,
1066- permissions .c .scope_id == target_element_ref . element_id ,
1067- permissions .c .entity_type == target_entity_type ,
1068- permissions .c .operation == operation ,
1010+ permissions .c .scope_id . in_ ( params . entity_ids ) ,
1011+ permissions .c .entity_type == permission_entity_type ,
1012+ permissions .c .operation == params . operation ,
10691013 )
10701014 )
1071- .limit (1 )
10721015 )
10731016
1017+ combined_query = sa .union (scope_chain_query , self_scope_query )
1018+
1019+ granted : set [str ] = set ()
1020+ async with self ._db .begin_readonly_session_read_committed () as db_session :
1021+ rows = await db_session .execute (combined_query )
1022+ for row in rows :
1023+ granted .add (row .entity_id )
1024+
1025+ return granted
1026+
10741027 async def check_permission_with_scope_chain (
10751028 self ,
10761029 data : ScopeChainPermissionCheckInput ,
10771030 ) -> bool :
1078- """CTE-based permission check that traverses the scope chain.
1079-
1080- Two-layer check:
1081- 1. Self-scope direct match — permission scoped to the target entity itself.
1082- 2. Scope chain traversal — walks AUTO edges upward via CTE.
1083-
1084- Args:
1085- data: Permission check input containing user_id, target_element_ref,
1086- operation, and optional permission_entity_type override.
1087- When permission_entity_type is provided, it is used as the
1088- entity_type filter for permission matching instead of deriving
1089- it from target_element_ref. This enables cross-scope entity type
1090- checks (e.g., checking MODEL_DEPLOYMENT:READ permission at
1091- PROJECT scope).
1092- """
1093- target_entity_type = (
1094- data .permission_entity_type or data .target_element_ref .element_type .to_entity_type ()
1095- )
1096- target_scope_type = data .target_element_ref .element_type .to_scope_type ()
1097-
1098- combined_query = sa .select (
1099- sa .or_ (
1100- sa .exists (
1101- self ._build_scope_chain_permission_query (
1102- data .user_id ,
1103- data .target_element_ref ,
1104- target_entity_type ,
1105- data .operation ,
1106- )
1107- ),
1108- sa .exists (
1109- self ._build_self_scope_permission_query (
1110- data .user_id ,
1111- data .target_element_ref ,
1112- target_entity_type ,
1113- target_scope_type ,
1114- data .operation ,
1115- )
1116- ),
1031+ """CTE-based permission check for a single entity."""
1032+ granted = await self ._check_permissions_via_scope_chain (
1033+ _ScopeChainQueryParams (
1034+ user_id = data .user_id ,
1035+ target_element_type = data .target_element_ref .element_type ,
1036+ entity_ids = [data .target_element_ref .element_id ],
1037+ operation = data .operation ,
1038+ permission_entity_type = data .permission_entity_type ,
11171039 )
11181040 )
1041+ return data .target_element_ref .element_id in granted
11191042
1120- async with self ._db .begin_readonly_session_read_committed () as db_session :
1121- result = await db_session .scalar (combined_query )
1122- return result or False
1043+ async def check_bulk_permission_with_scope_chain (
1044+ self ,
1045+ data : BulkPermissionCheckInput ,
1046+ ) -> dict [str , bool ]:
1047+ """Batch CTE-based permission check for multiple entities."""
1048+ if not data .target_entity_ids :
1049+ return {}
1050+
1051+ granted = await self ._check_permissions_via_scope_chain (
1052+ _ScopeChainQueryParams (
1053+ user_id = data .user_id ,
1054+ target_element_type = data .target_element_type ,
1055+ entity_ids = data .target_entity_ids ,
1056+ operation = data .operation ,
1057+ )
1058+ )
1059+ return {eid : eid in granted for eid in data .target_entity_ids }
11231060
11241061 async def bulk_assign_role (
11251062 self , bulk_creator : BulkCreator [UserRoleRow ]
0 commit comments