Skip to content

Commit ae734ed

Browse files
fregataaclaude
authored andcommitted
feat(BA-5681): auto-sync user-scope entries on role assign/unassign (#10990)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 71e1b53 commit ae734ed

5 files changed

Lines changed: 608 additions & 6 deletions

File tree

changes/10990.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Auto-sync user-scope membership entries in `association_scopes_entities` when roles are assigned or revoked

src/ai/backend/manager/repositories/group/db_source/db_source.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,9 @@ async def bind_user_to_project(self, user_id: UUID, project_id: UUID) -> None:
858858
"""Add a user to a project (business association + RBAC scope binding).
859859
860860
Skips if the user is already a member of the project.
861+
862+
TODO: Remove once association_groups_users is fully migrated to
863+
association_scopes_entities.
861864
"""
862865
async with self._db.begin_session() as session:
863866
already_bound = await session.scalar(
@@ -878,7 +881,11 @@ async def bind_user_to_project(self, user_id: UUID, project_id: UUID) -> None:
878881
await execute_rbac_scope_binder(session, RBACScopeBinder(pairs=[pair]))
879882

880883
async def unbind_user_from_project(self, user_id: UUID, project_id: UUID) -> None:
881-
"""Remove a user from a project (business association + RBAC scope binding)."""
884+
"""Remove a user from a project (business association + RBAC scope binding).
885+
886+
TODO: Remove once association_groups_users is fully migrated to
887+
association_scopes_entities.
888+
"""
882889
async with self._db.begin_session() as session:
883890
await session.execute(
884891
sa.delete(AssocGroupUserRow).where(

src/ai/backend/manager/repositories/group/repository.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,19 @@ async def bind_user_to_project(self, user_id: UUID, project_id: UUID) -> None:
144144
"""Add a user to a project (business association + RBAC scope binding).
145145
146146
Skips if the user is already a member of the project.
147+
148+
TODO: Remove once association_groups_users is fully migrated to
149+
association_scopes_entities.
147150
"""
148151
await self._db_source.bind_user_to_project(user_id, project_id)
149152

150153
@group_repository_resilience.apply()
151154
async def unbind_user_from_project(self, user_id: UUID, project_id: UUID) -> None:
152-
"""Remove a user from a project (business association + RBAC scope binding)."""
155+
"""Remove a user from a project (business association + RBAC scope binding).
156+
157+
TODO: Remove once association_groups_users is fully migrated to
158+
association_scopes_entities.
159+
"""
153160
await self._db_source.unbind_user_from_project(user_id, project_id)
154161

155162
@group_repository_resilience.apply()

src/ai/backend/manager/repositories/permission_controller/db_source/db_source.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import logging
22
import uuid
3-
from collections.abc import Iterable, Sequence
3+
from collections.abc import Collection, Iterable, Sequence
44
from dataclasses import dataclass, field
55
from typing import Any, cast
66

77
import sqlalchemy as sa
8+
from sqlalchemy.dialects.postgresql import insert as pg_insert
89
from sqlalchemy.ext.asyncio import AsyncSession as SASession
910
from sqlalchemy.orm import contains_eager, selectinload
1011

@@ -106,6 +107,83 @@ class PermissionDBSource:
106107
def __init__(self, db: ExtendedAsyncSAEngine) -> None:
107108
self._db = db
108109

110+
@staticmethod
111+
async def _sync_user_scopes_on_assign(
112+
db_session: SASession,
113+
user_ids: Collection[uuid.UUID],
114+
) -> None:
115+
"""Ensure user-scope membership entries exist for all assigned roles.
116+
117+
For each user, finds every scope bound to any of their assigned roles
118+
and inserts the corresponding user-scope entries. Executed as a single
119+
``INSERT … SELECT`` so the role lookup and insert share the same snapshot.
120+
"""
121+
if not user_ids:
122+
return
123+
ase = AssociationScopesEntitiesRow
124+
source = (
125+
sa.select(
126+
ase.scope_type,
127+
ase.scope_id,
128+
sa.literal(EntityType.USER.value).label("entity_type"),
129+
sa.cast(UserRoleRow.user_id, sa.String).label("entity_id"),
130+
sa.literal(RelationType.AUTO.value).label("relation_type"),
131+
)
132+
.join(
133+
UserRoleRow,
134+
sa.cast(UserRoleRow.role_id, sa.String) == ase.entity_id,
135+
)
136+
.where(
137+
ase.entity_type == EntityType.ROLE,
138+
UserRoleRow.user_id.in_(user_ids),
139+
)
140+
)
141+
await db_session.execute(
142+
pg_insert(ase)
143+
.from_select(
144+
["scope_type", "scope_id", "entity_type", "entity_id", "relation_type"],
145+
source,
146+
)
147+
.on_conflict_do_nothing()
148+
)
149+
150+
@staticmethod
151+
async def _sync_user_scopes_on_revoke(
152+
db_session: SASession,
153+
user_ids: Collection[uuid.UUID],
154+
) -> None:
155+
"""Remove user-scope entries no longer covered by any assigned role.
156+
157+
Deletes user-scope rows for *user_ids* when no assigned role binds
158+
the user to that scope. Executed as a single ``DELETE`` statement
159+
so the coverage check and deletion share the same snapshot.
160+
"""
161+
if not user_ids:
162+
return
163+
ase = AssociationScopesEntitiesRow
164+
str_user_ids = [str(uid) for uid in user_ids]
165+
ase_remaining = sa.orm.aliased(ase, flat=True)
166+
await db_session.execute(
167+
sa.delete(ase).where(
168+
ase.entity_type == EntityType.USER,
169+
ase.entity_id.in_(str_user_ids),
170+
~sa.exists(
171+
sa.select(sa.literal(1))
172+
.select_from(ase_remaining)
173+
.join(
174+
UserRoleRow,
175+
sa.cast(UserRoleRow.role_id, sa.String) == ase_remaining.entity_id,
176+
)
177+
.where(
178+
ase_remaining.entity_type == EntityType.ROLE,
179+
ase_remaining.scope_type == ase.scope_type,
180+
ase_remaining.scope_id == ase.scope_id,
181+
sa.cast(UserRoleRow.user_id, sa.String) == ase.entity_id,
182+
)
183+
),
184+
)
185+
)
186+
109187
# ------------------------------------------------------------------ role CRUD
110188

111189
async def create_role(self, input_data: CreateRoleInput) -> RoleRow:
@@ -276,6 +354,7 @@ async def assign_role(self, data: UserRoleAssignmentInput) -> UserRoleRow:
276354
)
277355
)
278356
result = await execute_creator(db_session, creator)
357+
await self._sync_user_scopes_on_assign(db_session, [data.user_id])
279358
return result.row
280359

281360
async def revoke_role(self, data: UserRoleRevocationInput) -> RoleRevocationResult:
@@ -299,8 +378,13 @@ async def revoke_role(self, data: UserRoleRevocationInput) -> RoleRevocationResu
299378
await db_session.delete(user_role_row)
300379
await db_session.flush()
301380

302-
# Single query: find projects this role belongs to and count
303-
# remaining user-role mappings per project
381+
await self._sync_user_scopes_on_revoke(db_session, [data.user_id])
382+
383+
# Used by PermissionControllerService.revoke_role() to decide whether
384+
# to call GroupDBSource.unbind_user_from_project().
385+
# TODO: remove this query when unbind_user_from_project() is retired
386+
# (i.e. association_groups_users is fully migrated to
387+
# association_scopes_entities).
304388
ase = AssociationScopesEntitiesRow
305389
project_subq = (
306390
sa.select(ase.scope_id).where(
@@ -1041,7 +1125,10 @@ async def bulk_assign_role(
10411125
self, bulk_creator: BulkCreator[UserRoleRow]
10421126
) -> BulkCreatorResultWithFailures[UserRoleRow]:
10431127
async with self._db.begin_session() as db_session:
1044-
return await execute_bulk_creator_partial(db_session, bulk_creator)
1128+
result = await execute_bulk_creator_partial(db_session, bulk_creator)
1129+
all_user_ids = [row.user_id for row in result.successes]
1130+
await self._sync_user_scopes_on_assign(db_session, all_user_ids)
1131+
return result
10451132

10461133
async def bulk_revoke_role(
10471134
self, data: BulkUserRoleRevocationInput
@@ -1081,5 +1168,7 @@ async def bulk_revoke_role(
10811168
str(e),
10821169
)
10831170
failures.append(BulkRoleRevocationFailure(user_id=user_id, message=str(e)))
1171+
revoked_user_ids = [s.user_id for s in successes]
1172+
await self._sync_user_scopes_on_revoke(db_session, revoked_user_ids)
10841173

10851174
return BulkRoleRevocationResultData(successes=successes, failures=failures)

0 commit comments

Comments
 (0)