Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions ansible_base/jwt_consumer/common/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
import uuid
from datetime import datetime
from typing import Optional
Expand Down Expand Up @@ -286,18 +287,27 @@ def process_rbac_permissions(self):

# Claims hash mismatch - fetch from gateway
logger.info(f"Claims hash mismatch for user {user_ansible_id}. JWT: {jwt_claims_hash}, Local: {local_claims_hash}. Fetching from gateway.")
reconcile_start = time.monotonic()
try:
fetch_start = time.monotonic()
gateway_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id)
# Extract claims structure from gateway response
fetch_elapsed = time.monotonic() - fetch_start

objects = gateway_claims.get('objects', {})
object_roles = gateway_claims.get('object_roles', {})
global_roles = gateway_claims.get('global_roles', [])

# Process the RBAC permissions with the gateway claims
save_start = time.monotonic()
save_user_claims(self.user, objects, object_roles, global_roles)
save_elapsed = time.monotonic() - save_start

# Update cache with the new hash
self.cache.cache_claims_hash(user_ansible_id, jwt_claims_hash)

total_elapsed = time.monotonic() - reconcile_start
logger.info(
f"Claims reconciliation for {user_ansible_id}: "
f"fetch={fetch_elapsed:.3f}s, save={save_elapsed:.3f}s, total={total_elapsed:.3f}s"
)
except GatewayLockedException:
if self.token.get('user_data', {}).get("is_superuser", False) is False:
self.log_and_raise(
Expand Down
88 changes: 83 additions & 5 deletions ansible_base/rbac/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from uuid import UUID

from django.conf import settings
from django.db import transaction
from django.db.utils import IntegrityError

from ansible_base.rbac.models import ObjectRole, RoleDefinition, RoleEvaluation, RoleEvaluationUUID
from ansible_base.rbac.permission_registry import permission_registry
Expand Down Expand Up @@ -121,6 +123,56 @@ def get_parent_teams_of_teams(org_team_mapping: dict) -> dict[int, list[int]]:
return team_team_parents


def _is_stale_objectrole_fk(exc):
    """Tell whether *exc* wraps a foreign-key violation that mentions the ObjectRole table.

    The driver-level error is read from ``exc.__cause__``. The result is True
    only when that error carries SQLSTATE ``23503`` (foreign_key_violation) and
    its text contains ``dab_rbac_objectrole``. Every other situation (no
    chained cause, a different SQLSTATE such as a unique or check violation,
    or an FK violation that names another table) yields False so the caller
    can let the error propagate.

    Backends whose errors expose neither ``sqlstate`` nor ``pgcode`` (SQLite,
    for instance) always yield False. The race this guards against needs
    truly concurrent committed transactions, which SQLite's global write lock
    rules out, so an FK error there points at a real bug.
    """
    driver_error = exc.__cause__
    if driver_error is None:
        return False

    # psycopg3 publishes the code as .sqlstate, psycopg2 as .pgcode
    code = getattr(driver_error, 'sqlstate', None)
    if not code:
        code = getattr(driver_error, 'pgcode', None)

    return code == '23503' and 'dab_rbac_objectrole' in str(driver_error)


def _safe_m2m_add(team, to_add):
    """Attach ObjectRole IDs to ``team.member_roles``, tolerating concurrent deletions.

    The IDs in *to_add* were gathered earlier by the caller, so another
    transaction may have deleted some of those ObjectRole rows in the
    meantime. Each write runs inside its own ``transaction.atomic()``
    savepoint so that an FK violation can be caught without poisoning the
    outer transaction. After a first stale-FK failure the IDs are re-checked
    against the ObjectRole table and the add is retried once with the
    survivors. A second stale-FK failure is logged and swallowed. Any other
    IntegrityError is re-raised.
    """

    def _try_add(role_ids):
        # True when the rows were written, False on a stale ObjectRole FK.
        try:
            with transaction.atomic():
                team.member_roles.add(*role_ids)
        except IntegrityError as exc:
            if not _is_stale_objectrole_fk(exc):
                raise
            return False
        return True

    if _try_add(to_add):
        return

    # Keep only the IDs whose ObjectRole row still exists.
    still_present = set(ObjectRole.objects.filter(id__in=to_add).values_list('id', flat=True))
    if still_present and not _try_add(still_present):
        logger.warning('Persistent IntegrityError adding member_roles for team %s, will be corrected on next recompute', team.id)


def compute_team_member_roles():
"""
Fills in the ObjectRole.provides_teams relationship for all teams.
Expand Down Expand Up @@ -154,11 +206,39 @@ def compute_team_member_roles():
to_add = expected_ids - existing_ids
to_remove = existing_ids - expected_ids
if to_add:
team.member_roles.add(*to_add)
_safe_m2m_add(team, to_add)
if to_remove:
team.member_roles.remove(*to_remove)


def _safe_bulk_create_evaluations(model, evaluations, ignore_conflicts):
    """Bulk-insert evaluation rows for *model*, tolerating concurrent ObjectRole deletions.

    ``ignore_conflicts`` is passed straight through to ``bulk_create``. It
    only covers unique-constraint conflicts, so an FK violation on ``role_id``
    caused by a concurrently deleted ObjectRole still raises IntegrityError.
    Each insert runs inside a ``transaction.atomic()`` savepoint to keep the
    outer transaction usable. After a first stale-FK failure, rows whose
    ``role_id`` no longer exists are dropped and the insert is retried once.
    A second stale-FK failure is logged and swallowed. Any other
    IntegrityError is re-raised. An empty *evaluations* is a no-op.
    """
    if not evaluations:
        return

    def _try_insert(rows):
        # True when the rows were written, False on a stale ObjectRole FK.
        try:
            with transaction.atomic():
                model.objects.bulk_create(rows, ignore_conflicts=ignore_conflicts)
        except IntegrityError as exc:
            if not _is_stale_objectrole_fk(exc):
                raise
            return False
        return True

    if _try_insert(evaluations):
        return

    # Keep only the rows whose ObjectRole still exists.
    referenced_ids = {row.role_id for row in evaluations}
    live_ids = set(ObjectRole.objects.filter(id__in=referenced_ids).values_list('id', flat=True))
    survivors = [row for row in evaluations if row.role_id in live_ids]
    if survivors and not _try_insert(survivors):
        logger.warning('Persistent IntegrityError in bulk_create for %s, will be corrected on next recompute', model.__name__)


def compute_object_role_permissions(object_roles=None, types_prefetch=None):
"""
Assumes the ObjectRole.provides_teams relationship is correct.
Expand Down Expand Up @@ -194,10 +274,8 @@ def compute_object_role_permissions(object_roles=None, types_prefetch=None):
to_add_uuid.append(evaluation)
else:
raise RuntimeError(f'Could not find a place in cache for {evaluation}')
if to_add_int:
RoleEvaluation.objects.bulk_create(to_add_int, ignore_conflicts=settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS)
if to_add_uuid:
RoleEvaluationUUID.objects.bulk_create(to_add_uuid, ignore_conflicts=settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS)
_safe_bulk_create_evaluations(RoleEvaluation, to_add_int, settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS)
_safe_bulk_create_evaluations(RoleEvaluationUUID, to_add_uuid, settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS)

if to_delete:
logger.info(f'Deleting {len(to_delete)} object-permission records')
Expand Down
122 changes: 72 additions & 50 deletions ansible_base/rbac/claims.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import hashlib
import json
import logging
import time
from collections import defaultdict
from typing import Optional, Tuple, Union

from django.apps import apps
from django.conf import settings
from django.db import transaction
from django.db.models import F, Model, OuterRef, QuerySet
from django.db.utils import IntegrityError

Expand Down Expand Up @@ -310,61 +312,81 @@ def get_or_create_resource(objects: dict, content_type: str, data: dict) -> Tupl


def save_user_claims(user: Model, objects: dict, object_roles: dict, global_roles: list) -> None:
"""Apply RBAC permissions from gateway claims data.

Wrapped in:
- no_reverse_sync(): Suppresses post_save signals from stub Org/Team
creation that would otherwise make HTTP calls back to Gateway.
- transaction.atomic(): Ensures partial failures roll back cleanly
instead of leaving the user in an inconsistent permission state.
- defer_role_evaluation(): Batches all RoleEvaluation cache updates
into a single pass instead of N individual compute cycles.
"""
Apply RBAC permissions from claims data
"""
role_diff = RoleUserAssignment.objects.filter(user=user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES)

for system_role_name in global_roles:
logger.debug(f"Processing system role {system_role_name} for {user.username}")
rd = get_role_definition(system_role_name)
if rd:
if rd.name in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES:
assignment = rd.give_global_permission(user)
role_diff = role_diff.exclude(pk=assignment.pk)
logger.info(f"Granted user {user.username} global role {system_role_name}")
from ansible_base.rbac.triggers import defer_role_evaluation
from ansible_base.resource_registry.signals.handlers import no_reverse_sync

start = time.monotonic()
grants = 0
removals = 0

with no_reverse_sync(), transaction.atomic(), defer_role_evaluation():
role_diff = RoleUserAssignment.objects.filter(user=user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES)

for system_role_name in global_roles:
logger.debug(f"Processing system role {system_role_name} for {user.username}")
rd = get_role_definition(system_role_name)
if rd:
if rd.name in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES:
assignment = rd.give_global_permission(user)
role_diff = role_diff.exclude(pk=assignment.pk)
grants += 1
logger.info(f"Granted user {user.username} global role {system_role_name}")
else:
logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it is not a JWT managed role")
else:
logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it is not a JWT managed role")
else:
logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it does not exist")
continue

for object_role_name in object_roles.keys():
rd = get_role_definition(object_role_name)
if rd is None:
logger.error(f"Unable to grant {user.username} object role {object_role_name} because it does not exist")
continue
elif rd.name not in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES:
logger.error(f"Unable to grant {user.username} object role {object_role_name} because it is not a JWT managed role")
continue

object_type = object_roles[object_role_name]['content_type']
object_indexes = object_roles[object_role_name]['objects']
logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it does not exist")
continue

for index in object_indexes:
object_data = objects[object_type][index]
try:
resource, obj = get_or_create_resource(objects, object_type, object_data)
except IntegrityError as e:
logger.warning(
f"Got integrity error ({e}) on {object_data}. Skipping {object_type} assignment. "
"Please make sure the sync task is running to prevent this warning in the future."
)
for object_role_name in object_roles.keys():
rd = get_role_definition(object_role_name)
if rd is None:
logger.error(f"Unable to grant {user.username} object role {object_role_name} because it does not exist")
continue
elif rd.name not in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES:
logger.error(f"Unable to grant {user.username} object role {object_role_name} because it is not a JWT managed role")
continue

object_type = object_roles[object_role_name]['content_type']
object_indexes = object_roles[object_role_name]['objects']

for index in object_indexes:
object_data = objects[object_type][index]
try:
resource, obj = get_or_create_resource(objects, object_type, object_data)
except IntegrityError as e:
logger.warning(
f"Got integrity error ({e}) on {object_data}. Skipping {object_type} assignment. "
"Please make sure the sync task is running to prevent this warning in the future."
)
continue

if resource is not None:
assignment = rd.give_permission(user, obj)
role_diff = role_diff.exclude(pk=assignment.pk)
grants += 1
logger.info(f"Granted user {user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}")

for role_assignment in role_diff:
rd = role_assignment.role_definition
content_object = role_assignment.content_object
if content_object:
rd.remove_permission(user, content_object)
else:
rd.remove_global_permission(user)
removals += 1

if resource is not None:
assignment = rd.give_permission(user, obj)
role_diff = role_diff.exclude(pk=assignment.pk)
logger.info(f"Granted user {user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}")

# Remove all permissions not authorized by the JWT
for role_assignment in role_diff:
rd = role_assignment.role_definition
content_object = role_assignment.content_object
if content_object:
rd.remove_permission(user, content_object)
else:
rd.remove_global_permission(user)
elapsed = time.monotonic() - start
logger.info(f"save_user_claims for {user.username}: {grants} grants, {removals} removals in {elapsed:.3f}s")


# ---- for claims hashing ----
Expand Down
57 changes: 57 additions & 0 deletions ansible_base/rbac/triggers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import threading
from contextlib import contextmanager
from typing import Union
from uuid import UUID

Expand All @@ -25,6 +26,57 @@
dab_post_migrate = Signal()


class _DeferredEvaluationState(threading.local):
    """Per-thread bookkeeping for deferred RoleEvaluation cache updates.

    Subclassing ``threading.local`` gives every thread its own independent
    copy of the attributes below (``__init__`` runs once per thread), so a
    request that enables deferral cannot cause another thread's updates to be
    captured or skipped.

    NOTE(review): under async/ASGI execution a ``contextvars.ContextVar`` may
    be the more appropriate scope; confirm against the deployment model.
    """

    def __init__(self):
        super().__init__()
        # True while this thread is inside at least one defer_role_evaluation().
        self.enabled = False
        # Object roles collected for recomputation by the active context.
        self.object_roles = set()
        # Whether any collected update also asked for a team-membership recompute.
        self.needs_team_recompute = False


_deferred_evaluation = _DeferredEvaluationState()


@contextmanager
def defer_role_evaluation():
    """Defer RoleEvaluation cache updates until the outermost context exits.

    While the context is active, ``_deferred_evaluation.enabled`` is True for
    the current thread and callers that honour the flag accumulate their work
    in ``object_roles`` / ``needs_team_recompute`` instead of recomputing
    immediately. The accumulated work is then handled as follows:

    - Nested use: when an inner context exits, its batch is merged into the
      enclosing context's batch. Nothing is recomputed yet.
    - Outermost exit without an exception: ``compute_team_member_roles()`` is
      called if a team recompute was requested, then
      ``compute_object_role_permissions()`` is called once with every
      collected object role.
    - Outermost exit with an exception: the batch is discarded and no
      recompute is attempted, so the original exception is not masked by a
      secondary failure and no work is done that an enclosing transaction is
      about to roll back.

    The previous state is always restored, whether or not the body raised.
    """
    state = _deferred_evaluation
    outer_enabled = state.enabled
    outer_roles = state.object_roles
    outer_teams = state.needs_team_recompute

    # Start a fresh batch for this nesting level.
    state.enabled = True
    state.object_roles = set()
    state.needs_team_recompute = False

    succeeded = False
    try:
        yield
        succeeded = True
    finally:
        collected_roles = state.object_roles
        collected_teams = state.needs_team_recompute

        # Restore the enclosing level before doing anything that might raise.
        state.enabled = outer_enabled
        state.object_roles = outer_roles
        state.needs_team_recompute = outer_teams

        if outer_enabled:
            # Inner context: hand the batch to the parent, which flushes later.
            outer_roles.update(collected_roles)
            state.needs_team_recompute = outer_teams or collected_teams
        elif succeeded:
            # Outermost context finished cleanly: flush everything in one pass.
            if collected_teams:
                compute_team_member_roles()
            if collected_roles:
                compute_object_role_permissions(object_roles=collected_roles)

Comment on lines +29 to +78
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Make deferred evaluation state request-scoped and flush only on the outermost successful exit.

_deferred_evaluation is process-global, so one thread can flip enabled while another thread is in update_after_assignment(), causing cross-request batching and skipped recomputes. The current nesting logic also computes on inner exit instead of merging the inner batch back into the outer context, so the advertised nested behavior is broken and exception paths still do work that the surrounding transaction may roll back. Use thread-local/context-local state here and only run compute_* when the outermost context exits without an error.

Suggested fix
+import threading
 from contextlib import contextmanager
 from typing import Union
 from uuid import UUID
@@
-class _DeferredEvaluationState:
+class _DeferredEvaluationState(threading.local):
     """Tracks whether RoleEvaluation cache updates should be deferred.
@@
 @contextmanager
 def defer_role_evaluation():
@@
     previous_enabled = _deferred_evaluation.enabled
     previous_roles = _deferred_evaluation.object_roles
     previous_teams = _deferred_evaluation.needs_team_recompute
@@
+    completed = False
     try:
         yield
+        completed = True
     finally:
         deferred_roles = _deferred_evaluation.object_roles
         deferred_teams = _deferred_evaluation.needs_team_recompute
 
         _deferred_evaluation.enabled = previous_enabled
         _deferred_evaluation.object_roles = previous_roles
         _deferred_evaluation.needs_team_recompute = previous_teams
 
-        if deferred_teams:
-            compute_team_member_roles()
-        if deferred_roles:
-            compute_object_role_permissions(object_roles=deferred_roles)
+        if previous_enabled:
+            previous_roles.update(deferred_roles)
+            _deferred_evaluation.needs_team_recompute = previous_teams or deferred_teams
+        elif completed:
+            if deferred_teams:
+                compute_team_member_roles()
+            if deferred_roles:
+                compute_object_role_permissions(object_roles=deferred_roles)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ansible_base/rbac/triggers.py` around lines 29 - 78, The current
_deferred_evaluation is process-global and flushes work on every context exit,
causing cross-request races and incorrect nested behavior; change
defer_role_evaluation() to use request-scoped (thread-local or contextvar) state
with a nesting counter (e.g., add a depth attribute on the state), have inner
contexts merge their object_roles and needs_team_recompute into the parent state
instead of running computes, and only when the outermost context exits
successfully (no exception) run compute_team_member_roles() and
compute_object_role_permissions(object_roles=...) and then clear the state;
update references to _deferred_evaluation in defer_role_evaluation() (and any
callers like update_after_assignment) to use the new thread-local/context-local
state and ensure computes are skipped if an exception occurred in the context.


def team_ancestor_roles(team):
"""
Return a queryset of all roles that directly or indirectly grant any form of permission to a team.
Expand Down Expand Up @@ -84,6 +136,11 @@ def needed_updates_on_assignment(role_definition, actor, object_role, created=Fa

def update_after_assignment(update_teams, to_update):
"Call this with the output of needed_updates_on_assignment"
if _deferred_evaluation.enabled:
_deferred_evaluation.object_roles.update(to_update)
_deferred_evaluation.needs_team_recompute |= update_teams
return

if update_teams:
compute_team_member_roles()

Expand Down
Loading
Loading