diff --git a/ansible_base/jwt_consumer/common/auth.py b/ansible_base/jwt_consumer/common/auth.py index e9cd72ee3..dd78971b4 100644 --- a/ansible_base/jwt_consumer/common/auth.py +++ b/ansible_base/jwt_consumer/common/auth.py @@ -1,4 +1,5 @@ import logging +import time import uuid from datetime import datetime from typing import Optional @@ -286,18 +287,27 @@ def process_rbac_permissions(self): # Claims hash mismatch - fetch from gateway logger.info(f"Claims hash mismatch for user {user_ansible_id}. JWT: {jwt_claims_hash}, Local: {local_claims_hash}. Fetching from gateway.") + reconcile_start = time.monotonic() try: + fetch_start = time.monotonic() gateway_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id) - # Extract claims structure from gateway response + fetch_elapsed = time.monotonic() - fetch_start + objects = gateway_claims.get('objects', {}) object_roles = gateway_claims.get('object_roles', {}) global_roles = gateway_claims.get('global_roles', []) - # Process the RBAC permissions with the gateway claims + save_start = time.monotonic() save_user_claims(self.user, objects, object_roles, global_roles) + save_elapsed = time.monotonic() - save_start - # Update cache with the new hash self.cache.cache_claims_hash(user_ansible_id, jwt_claims_hash) + + total_elapsed = time.monotonic() - reconcile_start + logger.info( + f"Claims reconciliation for {user_ansible_id}: " + f"fetch={fetch_elapsed:.3f}s, save={save_elapsed:.3f}s, total={total_elapsed:.3f}s" + ) except GatewayLockedException: if self.token.get('user_data', {}).get("is_superuser", False) is False: self.log_and_raise( diff --git a/ansible_base/rbac/caching.py b/ansible_base/rbac/caching.py index 99701decf..c118024ba 100644 --- a/ansible_base/rbac/caching.py +++ b/ansible_base/rbac/caching.py @@ -4,6 +4,8 @@ from uuid import UUID from django.conf import settings +from django.db import transaction +from django.db.utils import IntegrityError from ansible_base.rbac.models import ObjectRole, 
RoleDefinition, RoleEvaluation, RoleEvaluationUUID from ansible_base.rbac.permission_registry import permission_registry @@ -121,6 +123,56 @@ def get_parent_teams_of_teams(org_team_mapping: dict) -> dict[int, list[int]]: return team_team_parents +def _is_stale_objectrole_fk(exc): + """Return True if *exc* is specifically an FK violation from a stale ObjectRole reference. + + PostgreSQL (psycopg2/psycopg3): checks SQLSTATE 23503 (foreign_key_violation) + and that the referenced table is ``dab_rbac_objectrole``. Other + IntegrityError sub-types (unique-constraint, check-constraint, etc.) or FK + violations referencing a different table return False so they propagate + normally. + + Other backends (SQLite): returns False. The TOCTOU race requires truly + concurrent committed transactions, which SQLite's global write lock + prevents, so FK errors there indicate a real bug. + """ + cause = exc.__cause__ + if cause is None: + return False + # psycopg3 exposes .sqlstate, psycopg2 exposes .pgcode + sqlstate = getattr(cause, 'sqlstate', None) or getattr(cause, 'pgcode', None) + if sqlstate != '23503': + return False + return 'dab_rbac_objectrole' in str(cause) + + +def _safe_m2m_add(team, to_add): + """Add ObjectRole IDs to team.member_roles, handling concurrent deletions. + + A TOCTOU race exists: ObjectRole IDs collected during the read phase of + compute_team_member_roles() may be deleted by a concurrent transaction + (e.g. org deletion cascading through rbac_post_delete_remove_object_roles) + before this write. The savepoint allows us to catch the FK violation + without aborting the outer transaction, then retry with only IDs that + still exist. 
+ """ + try: + with transaction.atomic(): + team.member_roles.add(*to_add) + except IntegrityError as exc: + if not _is_stale_objectrole_fk(exc): + raise + to_add = set(ObjectRole.objects.filter(id__in=to_add).values_list('id', flat=True)) + if to_add: + try: + with transaction.atomic(): + team.member_roles.add(*to_add) + except IntegrityError as exc: + if not _is_stale_objectrole_fk(exc): + raise + logger.warning('Persistent IntegrityError adding member_roles for team %s, will be corrected on next recompute', team.id) + + def compute_team_member_roles(): """ Fills in the ObjectRole.provides_teams relationship for all teams. @@ -154,11 +206,39 @@ def compute_team_member_roles(): to_add = expected_ids - existing_ids to_remove = existing_ids - expected_ids if to_add: - team.member_roles.add(*to_add) + _safe_m2m_add(team, to_add) if to_remove: team.member_roles.remove(*to_remove) +def _safe_bulk_create_evaluations(model, evaluations, ignore_conflicts): + """bulk_create RoleEvaluation rows, handling concurrent ObjectRole deletions. + + ignore_conflicts (ON CONFLICT DO NOTHING) only suppresses unique-constraint + violations. A concurrent ObjectRole deletion causes an FK violation on + the role_id column which is a different IntegrityError. We use a savepoint + so the outer transaction stays healthy, then retry without the stale rows. 
+ """ + if not evaluations: + return + try: + with transaction.atomic(): + model.objects.bulk_create(evaluations, ignore_conflicts=ignore_conflicts) + except IntegrityError as exc: + if not _is_stale_objectrole_fk(exc): + raise + existing_role_ids = set(ObjectRole.objects.filter(id__in={e.role_id for e in evaluations}).values_list('id', flat=True)) + evaluations = [e for e in evaluations if e.role_id in existing_role_ids] + if evaluations: + try: + with transaction.atomic(): + model.objects.bulk_create(evaluations, ignore_conflicts=ignore_conflicts) + except IntegrityError as exc: + if not _is_stale_objectrole_fk(exc): + raise + logger.warning('Persistent IntegrityError in bulk_create for %s, will be corrected on next recompute', model.__name__) + + def compute_object_role_permissions(object_roles=None, types_prefetch=None): """ Assumes the ObjectRole.provides_teams relationship is correct. @@ -194,10 +274,8 @@ def compute_object_role_permissions(object_roles=None, types_prefetch=None): to_add_uuid.append(evaluation) else: raise RuntimeError(f'Could not find a place in cache for {evaluation}') - if to_add_int: - RoleEvaluation.objects.bulk_create(to_add_int, ignore_conflicts=settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS) - if to_add_uuid: - RoleEvaluationUUID.objects.bulk_create(to_add_uuid, ignore_conflicts=settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS) + _safe_bulk_create_evaluations(RoleEvaluation, to_add_int, settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS) + _safe_bulk_create_evaluations(RoleEvaluationUUID, to_add_uuid, settings.ANSIBLE_BASE_EVALUATIONS_IGNORE_CONFLICTS) if to_delete: logger.info(f'Deleting {len(to_delete)} object-permission records') diff --git a/ansible_base/rbac/claims.py b/ansible_base/rbac/claims.py index 64e1e011d..48660b631 100644 --- a/ansible_base/rbac/claims.py +++ b/ansible_base/rbac/claims.py @@ -1,11 +1,13 @@ import hashlib import json import logging +import time from collections import defaultdict from typing import 
Optional, Tuple, Union from django.apps import apps from django.conf import settings +from django.db import transaction from django.db.models import F, Model, OuterRef, QuerySet from django.db.utils import IntegrityError @@ -310,61 +312,81 @@ def get_or_create_resource(objects: dict, content_type: str, data: dict) -> Tupl def save_user_claims(user: Model, objects: dict, object_roles: dict, global_roles: list) -> None: + """Apply RBAC permissions from gateway claims data. + + Wrapped in: + - no_reverse_sync(): Suppresses post_save signals from stub Org/Team + creation that would otherwise make HTTP calls back to Gateway. + - transaction.atomic(): Ensures partial failures roll back cleanly + instead of leaving the user in an inconsistent permission state. + - defer_role_evaluation(): Batches all RoleEvaluation cache updates + into a single pass instead of N individual compute cycles. """ - Apply RBAC permissions from claims data - """ - role_diff = RoleUserAssignment.objects.filter(user=user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES) - - for system_role_name in global_roles: - logger.debug(f"Processing system role {system_role_name} for {user.username}") - rd = get_role_definition(system_role_name) - if rd: - if rd.name in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES: - assignment = rd.give_global_permission(user) - role_diff = role_diff.exclude(pk=assignment.pk) - logger.info(f"Granted user {user.username} global role {system_role_name}") + from ansible_base.rbac.triggers import defer_role_evaluation + from ansible_base.resource_registry.signals.handlers import no_reverse_sync + + start = time.monotonic() + grants = 0 + removals = 0 + + with no_reverse_sync(), transaction.atomic(), defer_role_evaluation(): + role_diff = RoleUserAssignment.objects.filter(user=user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES) + + for system_role_name in global_roles: + logger.debug(f"Processing system role {system_role_name} for 
{user.username}") + rd = get_role_definition(system_role_name) + if rd: + if rd.name in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES: + assignment = rd.give_global_permission(user) + role_diff = role_diff.exclude(pk=assignment.pk) + grants += 1 + logger.info(f"Granted user {user.username} global role {system_role_name}") + else: + logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it is not a JWT managed role") else: - logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it is not a JWT managed role") - else: - logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it does not exist") - continue - - for object_role_name in object_roles.keys(): - rd = get_role_definition(object_role_name) - if rd is None: - logger.error(f"Unable to grant {user.username} object role {object_role_name} because it does not exist") - continue - elif rd.name not in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES: - logger.error(f"Unable to grant {user.username} object role {object_role_name} because it is not a JWT managed role") - continue - - object_type = object_roles[object_role_name]['content_type'] - object_indexes = object_roles[object_role_name]['objects'] + logger.error(f"Unable to grant {user.username} system level role {system_role_name} because it does not exist") + continue - for index in object_indexes: - object_data = objects[object_type][index] - try: - resource, obj = get_or_create_resource(objects, object_type, object_data) - except IntegrityError as e: - logger.warning( - f"Got integrity error ({e}) on {object_data}. Skipping {object_type} assignment. " - "Please make sure the sync task is running to prevent this warning in the future." 
- ) + for object_role_name in object_roles.keys(): + rd = get_role_definition(object_role_name) + if rd is None: + logger.error(f"Unable to grant {user.username} object role {object_role_name} because it does not exist") continue + elif rd.name not in settings.ANSIBLE_BASE_JWT_MANAGED_ROLES: + logger.error(f"Unable to grant {user.username} object role {object_role_name} because it is not a JWT managed role") + continue + + object_type = object_roles[object_role_name]['content_type'] + object_indexes = object_roles[object_role_name]['objects'] + + for index in object_indexes: + object_data = objects[object_type][index] + try: + resource, obj = get_or_create_resource(objects, object_type, object_data) + except IntegrityError as e: + logger.warning( + f"Got integrity error ({e}) on {object_data}. Skipping {object_type} assignment. " + "Please make sure the sync task is running to prevent this warning in the future." + ) + continue + + if resource is not None: + assignment = rd.give_permission(user, obj) + role_diff = role_diff.exclude(pk=assignment.pk) + grants += 1 + logger.info(f"Granted user {user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}") + + for role_assignment in role_diff: + rd = role_assignment.role_definition + content_object = role_assignment.content_object + if content_object: + rd.remove_permission(user, content_object) + else: + rd.remove_global_permission(user) + removals += 1 - if resource is not None: - assignment = rd.give_permission(user, obj) - role_diff = role_diff.exclude(pk=assignment.pk) - logger.info(f"Granted user {user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}") - - # Remove all permissions not authorized by the JWT - for role_assignment in role_diff: - rd = role_assignment.role_definition - content_object = role_assignment.content_object - if content_object: - rd.remove_permission(user, content_object) - else: - 
rd.remove_global_permission(user) + elapsed = time.monotonic() - start + logger.info(f"save_user_claims for {user.username}: {grants} grants, {removals} removals in {elapsed:.3f}s") # ---- for claims hashing ---- diff --git a/ansible_base/rbac/triggers.py b/ansible_base/rbac/triggers.py index 5aa03bbee..4dcf2bafa 100644 --- a/ansible_base/rbac/triggers.py +++ b/ansible_base/rbac/triggers.py @@ -1,4 +1,5 @@ import logging +from contextlib import contextmanager from typing import Union from uuid import UUID @@ -25,6 +26,57 @@ dab_post_migrate = Signal() +class _DeferredEvaluationState: + """Tracks whether RoleEvaluation cache updates should be deferred. + + Used by defer_role_evaluation() to collect object_roles that need + recomputation and process them in a single batch on context exit. + """ + + def __init__(self): + self.enabled = False + self.object_roles = set() + self.needs_team_recompute = False + + +_deferred_evaluation = _DeferredEvaluationState() + + +@contextmanager +def defer_role_evaluation(): + """Defer RoleEvaluation cache updates until the context exits. + + During bulk operations like save_user_claims(), each give_permission() / + remove_permission() call triggers compute_object_role_permissions() + individually. This context manager collects all affected object_roles + and processes them in a single batch when the context exits, reducing + N individual compute+write cycles to one. + + Supports nesting by saving and restoring previous state. 
+ """ + previous_enabled = _deferred_evaluation.enabled + previous_roles = _deferred_evaluation.object_roles + previous_teams = _deferred_evaluation.needs_team_recompute + + _deferred_evaluation.enabled = True + _deferred_evaluation.object_roles = set() + _deferred_evaluation.needs_team_recompute = False + try: + yield + finally: + deferred_roles = _deferred_evaluation.object_roles + deferred_teams = _deferred_evaluation.needs_team_recompute + + _deferred_evaluation.enabled = previous_enabled + _deferred_evaluation.object_roles = previous_roles + _deferred_evaluation.needs_team_recompute = previous_teams + + if deferred_teams: + compute_team_member_roles() + if deferred_roles: + compute_object_role_permissions(object_roles=deferred_roles) + + def team_ancestor_roles(team): """ Return a queryset of all roles that directly or indirectly grant any form of permission to a team. @@ -84,6 +136,11 @@ def needed_updates_on_assignment(role_definition, actor, object_role, created=Fa def update_after_assignment(update_teams, to_update): "Call this with the output of needed_updates_on_assignment" + if _deferred_evaluation.enabled: + _deferred_evaluation.object_roles.update(to_update) + _deferred_evaluation.needs_team_recompute |= update_teams + return + if update_teams: compute_team_member_roles() diff --git a/test_app/tests/jwt_consumer/common/test_auth_timing.py b/test_app/tests/jwt_consumer/common/test_auth_timing.py new file mode 100644 index 000000000..1839c9ff3 --- /dev/null +++ b/test_app/tests/jwt_consumer/common/test_auth_timing.py @@ -0,0 +1,112 @@ +"""Tests for timing logs in process_rbac_permissions().""" + +import logging +from unittest import mock + +import pytest + +from ansible_base.jwt_consumer.common.auth import JWTCommonAuth + + +@pytest.mark.django_db +class TestProcessRbacPermissionsTiming: + def test_logs_timing_on_successful_reconciliation(self, admin_user, caplog): + """Verify timing breakdown is logged after a successful gateway fetch + save.""" + 
authentication = JWTCommonAuth() + authentication.user = admin_user + user_ansible_id = "12345678-1234-5678-9abc-123456789012" + authentication.token = { + "sub": user_ansible_id, + "claims_hash": "new_hash", + } + + gateway_response = { + 'objects': {}, + 'object_roles': {}, + 'global_roles': [], + } + + with ( + mock.patch.object(authentication.cache, 'get_cached_claims_hash', return_value=None), + mock.patch.object(authentication.cache, 'cache_claims_hash'), + mock.patch('ansible_base.rbac.claims.get_user_claims', return_value={}), + mock.patch('ansible_base.rbac.claims.get_user_claims_hashable_form', return_value={}), + mock.patch('ansible_base.rbac.claims.get_claims_hash', return_value="different_hash"), + mock.patch.object(authentication, '_fetch_jwt_claims_from_gateway', return_value=gateway_response), + mock.patch('ansible_base.rbac.claims.save_user_claims') as mock_save, + caplog.at_level(logging.INFO, logger='ansible_base.jwt_consumer.common.auth'), + ): + authentication.process_rbac_permissions() + + mock_save.assert_called_once() + + timing_logs = [r for r in caplog.records if 'Claims reconciliation' in r.message] + assert len(timing_logs) == 1 + msg = timing_logs[0].message + assert f'Claims reconciliation for {user_ansible_id}' in msg + assert 'fetch=' in msg + assert 'save=' in msg + assert 'total=' in msg + + def test_no_timing_log_on_cache_hit(self, admin_user, caplog): + """Cache hit should return early with no timing log.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + authentication.token = { + "sub": "12345678-1234-5678-9abc-123456789012", + "claims_hash": "cached_hash", + } + + with ( + mock.patch.object(authentication.cache, 'get_cached_claims_hash', return_value="cached_hash"), + caplog.at_level(logging.INFO, logger='ansible_base.jwt_consumer.common.auth'), + ): + authentication.process_rbac_permissions() + + timing_logs = [r for r in caplog.records if 'Claims reconciliation' in r.message] + assert len(timing_logs) 
== 0 + + def test_no_timing_log_on_local_match(self, admin_user, caplog): + """Local hash match should return early with no timing log.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + authentication.token = { + "sub": "12345678-1234-5678-9abc-123456789012", + "claims_hash": "matching_hash", + } + + with ( + mock.patch.object(authentication.cache, 'get_cached_claims_hash', return_value=None), + mock.patch('ansible_base.rbac.claims.get_user_claims', return_value={}), + mock.patch('ansible_base.rbac.claims.get_user_claims_hashable_form', return_value={}), + mock.patch('ansible_base.rbac.claims.get_claims_hash', return_value="matching_hash"), + caplog.at_level(logging.INFO, logger='ansible_base.jwt_consumer.common.auth'), + ): + authentication.process_rbac_permissions() + + timing_logs = [r for r in caplog.records if 'Claims reconciliation' in r.message] + assert len(timing_logs) == 0 + + def test_no_timing_log_on_gateway_failure(self, admin_user, caplog): + """Gateway failure should raise, not log timing.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + authentication.token = { + "sub": "12345678-1234-5678-9abc-123456789012", + "claims_hash": "new_hash", + "user_data": {"is_superuser": False}, + } + + with ( + mock.patch.object(authentication.cache, 'get_cached_claims_hash', return_value=None), + mock.patch('ansible_base.rbac.claims.get_user_claims', return_value={}), + mock.patch('ansible_base.rbac.claims.get_user_claims_hashable_form', return_value={}), + mock.patch('ansible_base.rbac.claims.get_claims_hash', return_value="different_hash"), + mock.patch.object(authentication, '_fetch_jwt_claims_from_gateway', side_effect=Exception("network error")), + caplog.at_level(logging.INFO, logger='ansible_base.jwt_consumer.common.auth'), + ): + with pytest.raises(Exception, match="Unable to validate user permissions"): + authentication.process_rbac_permissions() + + timing_logs = [r for r in caplog.records if 'Claims 
reconciliation' in r.message] + assert len(timing_logs) == 0 diff --git a/test_app/tests/rbac/test_caching.py b/test_app/tests/rbac/test_caching.py new file mode 100644 index 000000000..ef0cac487 --- /dev/null +++ b/test_app/tests/rbac/test_caching.py @@ -0,0 +1,343 @@ +"""Tests for ansible_base.rbac.caching — specifically the TOCTOU race-condition +guards: _is_stale_objectrole_fk, _safe_m2m_add, and _safe_bulk_create_evaluations. +""" + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest +from django.db.utils import IntegrityError + +from ansible_base.rbac.caching import ( + _is_stale_objectrole_fk, + _safe_bulk_create_evaluations, + _safe_m2m_add, +) +from ansible_base.rbac.models import RoleEvaluation + +# --------------------------------------------------------------------------- +# Helpers to build mock psycopg-style exceptions +# --------------------------------------------------------------------------- + + +def _make_integrity_error(sqlstate=None, message='some error', use_pgcode=False): + """Build a Django IntegrityError wrapping a fake DB-level cause.""" + cause = Exception(message) + if use_pgcode: + cause.pgcode = sqlstate + else: + cause.sqlstate = sqlstate + exc = IntegrityError(message) + exc.__cause__ = cause + return exc + + +def _make_objectrole_fk_error(use_pgcode=False): + """Build an IntegrityError that mimics a Postgres FK violation on dab_rbac_objectrole.""" + return _make_integrity_error( + sqlstate='23503', + message=( + 'insert or update on table "dab_rbac_objectrole_provides_teams" violates ' + 'foreign key constraint "dab_rbac_objectrole_p_objectrole_id_abc123_fk_dab_rbac_o"\n' + 'DETAIL: Key (objectrole_id)=(999) is not present in table "dab_rbac_objectrole".' 
+ ), + use_pgcode=use_pgcode, + ) + + +@contextmanager +def _noop_atomic(*args, **kwargs): + """Stand-in for transaction.atomic() that skips real savepoint SQL.""" + yield + + +# --------------------------------------------------------------------------- +# _is_stale_objectrole_fk +# --------------------------------------------------------------------------- + + +class TestIsStaleObjectroleFk: + def test_returns_true_for_objectrole_fk_violation_psycopg3(self): + exc = _make_objectrole_fk_error(use_pgcode=False) + assert _is_stale_objectrole_fk(exc) is True + + def test_returns_true_for_objectrole_fk_violation_psycopg2(self): + exc = _make_objectrole_fk_error(use_pgcode=True) + assert _is_stale_objectrole_fk(exc) is True + + def test_returns_false_when_no_cause(self): + exc = IntegrityError('bare error') + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_false_for_unique_constraint_violation(self): + exc = _make_integrity_error( + sqlstate='23505', + message='duplicate key value violates unique constraint "dab_rbac_objectrole_one_object_role"', + ) + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_false_for_fk_violation_on_different_table(self): + exc = _make_integrity_error( + sqlstate='23503', + message='Key (team_id)=(42) is not present in table "main_team".', + ) + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_false_for_check_constraint(self): + exc = _make_integrity_error( + sqlstate='23514', + message='new row violates check constraint "positive_id"', + ) + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_false_for_sqlite_fk_error(self): + """SQLite FK errors have no sqlstate/pgcode — should not be suppressed.""" + cause = Exception('FOREIGN KEY constraint failed') + exc = IntegrityError('FOREIGN KEY constraint failed') + exc.__cause__ = cause + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_false_when_sqlstate_is_none(self): + cause = Exception('some random db error') + 
exc = IntegrityError('some random db error') + exc.__cause__ = cause + assert _is_stale_objectrole_fk(exc) is False + + def test_returns_true_for_role_evaluation_fk_violation(self): + """The role_id FK on dab_rbac_roleevaluation also references dab_rbac_objectrole.""" + exc = _make_integrity_error( + sqlstate='23503', + message=( + 'insert or update on table "dab_rbac_roleevaluation" violates ' + 'foreign key constraint "dab_rbac_roleevaluat_role_id_abc_fk_dab_rbac_o"\n' + 'DETAIL: Key (role_id)=(42) is not present in table "dab_rbac_objectrole".' + ), + ) + assert _is_stale_objectrole_fk(exc) is True + + +# --------------------------------------------------------------------------- +# _safe_m2m_add +# +# team.member_roles is a Django M2M descriptor that returns a fresh manager +# each access, so we cannot patch it on a real Team instance. Instead these +# tests use a MagicMock team whose .member_roles.add is fully controllable, +# and mock transaction.atomic to skip real savepoint SQL (the mock-raised +# IntegrityError never touches the DB). 
+# --------------------------------------------------------------------------- + + +class TestSafeM2mAdd: + @pytest.mark.django_db + def test_happy_path_adds_ids(self, team, member_rd, rando): + """When no race occurs, compute_team_member_roles (which calls + _safe_m2m_add internally) works end-to-end with real DB objects.""" + from ansible_base.rbac.caching import compute_team_member_roles + + member_rd.give_permission(rando, team) + compute_team_member_roles() + assert team.member_roles.exists() + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_retries_on_stale_fk(self, mock_or_qs): + """First add() hits a stale-FK IntegrityError, retry succeeds with filtered IDs.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + team = MagicMock() + call_count = 0 + + def flaky_add(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_objectrole_fk_error() + + team.member_roles.add.side_effect = flaky_add + _safe_m2m_add(team, {1, 9999}) + assert call_count == 2 + mock_or_qs.filter.assert_called_once() + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + @patch('ansible_base.rbac.caching.logger') + def test_logs_warning_on_persistent_fk_error(self, mock_logger, mock_or_qs): + """Both attempts fail with stale-FK error — logs warning, does not raise.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + team = MagicMock() + team.member_roles.add.side_effect = _make_objectrole_fk_error() + + _safe_m2m_add(team, {1}) + mock_logger.warning.assert_called_once() + assert 'Persistent IntegrityError' in mock_logger.warning.call_args[0][0] + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + def test_reraises_non_fk_integrity_error(self): + """A unique-constraint IntegrityError is re-raised, not swallowed.""" + team = 
MagicMock() + team.member_roles.add.side_effect = _make_integrity_error(sqlstate='23505', message='duplicate key') + with pytest.raises(IntegrityError, match='duplicate key'): + _safe_m2m_add(team, {1}) + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_retry_reraises_non_fk_integrity_error(self, mock_or_qs): + """First call hits stale FK (retries), second call hits a different error — re-raised.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + team = MagicMock() + stale_exc = _make_objectrole_fk_error() + unique_exc = _make_integrity_error(sqlstate='23505', message='duplicate key') + call_count = 0 + + def switching_add(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise stale_exc + raise unique_exc + + team.member_roles.add.side_effect = switching_add + with pytest.raises(IntegrityError, match='duplicate key'): + _safe_m2m_add(team, {1}) + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_skips_retry_when_no_valid_ids_remain(self, mock_or_qs): + """After filtering, no IDs remain — second add() is never called.""" + mock_or_qs.filter.return_value.values_list.return_value = [] + team = MagicMock() + exc = _make_objectrole_fk_error() + team.member_roles.add.side_effect = exc + + _safe_m2m_add(team, {9999}) + team.member_roles.add.assert_called_once() + + +# --------------------------------------------------------------------------- +# _safe_bulk_create_evaluations +# --------------------------------------------------------------------------- + + +class TestSafeBulkCreateEvaluations: + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + def test_noop_on_empty_list(self): + """Empty evaluations list returns immediately without DB calls.""" + with patch.object(RoleEvaluation.objects, 'bulk_create') as mock_bc: + 
_safe_bulk_create_evaluations(RoleEvaluation, [], False) + mock_bc.assert_not_called() + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + def test_happy_path_calls_bulk_create(self): + """When no error, bulk_create is called once.""" + evals = [MagicMock(role_id=1, object_id=1)] + with patch.object(RoleEvaluation.objects, 'bulk_create') as mock_bc: + _safe_bulk_create_evaluations(RoleEvaluation, evals, True) + mock_bc.assert_called_once_with(evals, ignore_conflicts=True) + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_retries_on_stale_fk(self, mock_or_qs): + """First bulk_create hits stale FK, retry succeeds after filtering.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + eval1 = MagicMock(role_id=1, object_id=10) + eval2 = MagicMock(role_id=2, object_id=20) + exc = _make_objectrole_fk_error() + call_count = 0 + + def flaky_bulk_create(objs, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise exc + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=flaky_bulk_create): + _safe_bulk_create_evaluations(RoleEvaluation, [eval1, eval2], True) + + assert call_count == 2 + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + @patch('ansible_base.rbac.caching.logger') + def test_logs_warning_on_persistent_fk_error(self, mock_logger, mock_or_qs): + """Both attempts fail with stale-FK — logs warning, does not raise.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + eval1 = MagicMock(role_id=1, object_id=10) + exc = _make_objectrole_fk_error() + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=exc): + _safe_bulk_create_evaluations(RoleEvaluation, [eval1], False) + mock_logger.warning.assert_called_once() + assert 'Persistent IntegrityError' in 
mock_logger.warning.call_args[0][0] + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + def test_reraises_non_fk_integrity_error(self): + """A unique-constraint IntegrityError propagates.""" + eval1 = MagicMock(role_id=1, object_id=10) + exc = _make_integrity_error(sqlstate='23505', message='duplicate key') + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=exc): + with pytest.raises(IntegrityError, match='duplicate key'): + _safe_bulk_create_evaluations(RoleEvaluation, [eval1], False) + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_retry_reraises_non_fk_integrity_error(self, mock_or_qs): + """First call hits stale FK, retry hits different error — re-raised.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + eval1 = MagicMock(role_id=1, object_id=10) + stale_exc = _make_objectrole_fk_error() + unique_exc = _make_integrity_error(sqlstate='23505', message='duplicate key') + call_count = 0 + + def switching_bulk_create(objs, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise stale_exc + raise unique_exc + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=switching_bulk_create): + with pytest.raises(IntegrityError, match='duplicate key'): + _safe_bulk_create_evaluations(RoleEvaluation, [eval1], False) + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_filters_out_stale_role_ids(self, mock_or_qs): + """After first failure, only evaluations with still-valid role_ids are retried.""" + mock_or_qs.filter.return_value.values_list.return_value = [1] + eval_good = MagicMock(role_id=1, object_id=10) + eval_stale = MagicMock(role_id=999, object_id=20) + exc = _make_objectrole_fk_error() + + retried_evals = [] + call_count = 0 + + def flaky_then_capture(objs, **kwargs): + nonlocal 
call_count + call_count += 1 + if call_count == 1: + raise exc + retried_evals.extend(objs) + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=flaky_then_capture): + _safe_bulk_create_evaluations(RoleEvaluation, [eval_good, eval_stale], True) + + assert retried_evals == [eval_good] + + @patch('ansible_base.rbac.caching.transaction.atomic', _noop_atomic) + @patch('ansible_base.rbac.caching.ObjectRole.objects') + def test_skips_retry_when_no_valid_ids_remain(self, mock_or_qs): + """After filtering, no evaluations remain — second bulk_create is never called.""" + mock_or_qs.filter.return_value.values_list.return_value = [] + eval1 = MagicMock(role_id=999, object_id=10) + exc = _make_objectrole_fk_error() + call_count = 0 + + def counting_bulk_create(objs, **kwargs): + nonlocal call_count + call_count += 1 + raise exc + + with patch.object(RoleEvaluation.objects, 'bulk_create', side_effect=counting_bulk_create): + _safe_bulk_create_evaluations(RoleEvaluation, [eval1], False) + + assert call_count == 1 diff --git a/test_app/tests/rbac/test_defer_role_evaluation.py b/test_app/tests/rbac/test_defer_role_evaluation.py new file mode 100644 index 000000000..5713e4a59 --- /dev/null +++ b/test_app/tests/rbac/test_defer_role_evaluation.py @@ -0,0 +1,147 @@ +from unittest.mock import patch + +import pytest + +from ansible_base.rbac.models import RoleEvaluation +from ansible_base.rbac.triggers import _deferred_evaluation, defer_role_evaluation, update_after_assignment +from test_app.models import Inventory + + +@pytest.mark.django_db +class TestDeferRoleEvaluation: + def test_deferred_state_starts_disabled(self): + assert _deferred_evaluation.enabled is False + assert _deferred_evaluation.object_roles == set() + assert _deferred_evaluation.needs_team_recompute is False + + def test_context_manager_enables_and_restores(self): + assert _deferred_evaluation.enabled is False + with defer_role_evaluation(): + assert _deferred_evaluation.enabled is True + assert 
    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_context_manager_nesting(self, mock_teams, mock_perms):
        """Nested defer contexts each see fresh deferred state; the outer
        context's accumulated roles are restored when the inner one exits."""
        assert _deferred_evaluation.enabled is False
        with defer_role_evaluation():
            assert _deferred_evaluation.enabled is True
            _deferred_evaluation.object_roles.add("outer_role")
            with defer_role_evaluation():
                # Inner context starts with an empty role set, not the outer one's.
                assert _deferred_evaluation.enabled is True
                assert _deferred_evaluation.object_roles == set()
            # Outer context's pending roles survive the inner exit.
            assert _deferred_evaluation.enabled is True
            assert "outer_role" in _deferred_evaluation.object_roles
        assert _deferred_evaluation.enabled is False

    def test_context_manager_restores_on_exception(self):
        """The deferred flag is reset even when the context body raises."""
        assert _deferred_evaluation.enabled is False
        with pytest.raises(RuntimeError):
            with defer_role_evaluation():
                assert _deferred_evaluation.enabled is True
                raise RuntimeError("test")
        assert _deferred_evaluation.enabled is False

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_update_after_assignment_defers_when_enabled(self, mock_teams, mock_perms):
        """Inside a defer context, update_after_assignment records the role
        for later instead of recomputing immediately."""
        mock_role = object()
        with defer_role_evaluation():
            update_after_assignment(False, {mock_role})
            assert mock_role in _deferred_evaluation.object_roles
            assert _deferred_evaluation.needs_team_recompute is False

    def test_update_after_assignment_defers_team_recompute(self):
        """A team-affecting assignment only sets the recompute flag while deferred."""
        with defer_role_evaluation():
            update_after_assignment(True, set())
            assert _deferred_evaluation.needs_team_recompute is True

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_update_after_assignment_accumulates(self, mock_teams, mock_perms):
        """Multiple deferred calls accumulate roles and flags into one batch."""
        role_a = object()
        role_b = object()
        with defer_role_evaluation():
            update_after_assignment(False, {role_a})
            update_after_assignment(True, {role_b})
            assert role_a in _deferred_evaluation.object_roles
            assert role_b in _deferred_evaluation.object_roles
            assert _deferred_evaluation.needs_team_recompute is True

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_update_after_assignment_calls_directly_when_not_deferred(self, mock_teams, mock_perms):
        """Outside any defer context the recompute helpers run immediately."""
        mock_roles = {object()}
        update_after_assignment(True, mock_roles)
        mock_teams.assert_called_once()
        mock_perms.assert_called_once_with(object_roles=mock_roles)

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_deferred_batch_executes_on_exit(self, mock_teams, mock_perms):
        """Recompute helpers run exactly once, at context exit, with the
        union of all roles deferred inside the context."""
        role_a = object()
        role_b = object()
        with defer_role_evaluation():
            update_after_assignment(True, {role_a})
            update_after_assignment(False, {role_b})
            # Nothing runs while still inside the context.
            mock_teams.assert_not_called()
            mock_perms.assert_not_called()

        mock_teams.assert_called_once()
        mock_perms.assert_called_once()
        called_roles = mock_perms.call_args[1]['object_roles']
        assert role_a in called_roles
        assert role_b in called_roles

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_deferred_skips_team_recompute_when_not_needed(self, mock_teams, mock_perms):
        """Exit only recomputes team membership when a team-affecting change occurred."""
        with defer_role_evaluation():
            update_after_assignment(False, {object()})

        mock_teams.assert_not_called()
        mock_perms.assert_called_once()

    @patch('ansible_base.rbac.triggers.compute_object_role_permissions')
    @patch('ansible_base.rbac.triggers.compute_team_member_roles')
    def test_deferred_skips_all_when_empty(self, mock_teams, mock_perms):
        """An empty defer context triggers no recomputation at all."""
        with defer_role_evaluation():
            pass

        mock_teams.assert_not_called()
        mock_perms.assert_not_called()

    def test_give_permission_batches_evaluations(self, inv_rd, rando, inventory, organization):
        """Integration test: give_permission under defer batches RoleEvaluation writes."""
        inv2 = Inventory.objects.create(name='inv-2', organization=organization)

        with defer_role_evaluation():
            inv_rd.give_permission(rando, inventory)
            inv_rd.give_permission(rando, inv2)

        # Permissions and their cached evaluations exist once the batch flushes.
        assert rando.has_obj_perm(inventory, 'change')
        assert rando.has_obj_perm(inv2, 'change')
        assert RoleEvaluation.objects.filter(object_id=inventory.pk).exists()
        assert RoleEvaluation.objects.filter(object_id=inv2.pk).exists()

    def test_remove_permission_batches_evaluations(self, inv_rd, rando, inventory, organization):
        """Integration test: remove_permission under defer cleans up correctly."""
        inv_rd.give_permission(rando, inventory)
        assert rando.has_obj_perm(inventory, 'change')

        with defer_role_evaluation():
            inv_rd.remove_permission(rando, inventory)

        assert not rando.has_obj_perm(inventory, 'change')

    def test_mixed_give_remove_under_defer(self, inv_rd, rando, inventory, organization):
        """Integration test: mixed operations produce correct final state."""
        inv2 = Inventory.objects.create(name='inv-mix-2', organization=organization)
        inv_rd.give_permission(rando, inventory)

        with defer_role_evaluation():
            inv_rd.give_permission(rando, inv2)
            inv_rd.remove_permission(rando, inventory)

        assert not rando.has_obj_perm(inventory, 'change')
        assert rando.has_obj_perm(inv2, 'change')
"""Tests for save_user_claims() performance and reliability optimizations.

Verifies that save_user_claims():
- Wraps operations in transaction.atomic() for rollback on failure
- Wraps operations in no_reverse_sync() to suppress signals
- Wraps operations in defer_role_evaluation() to batch cache updates
- Logs timing and grant/removal counts
"""

import logging
import re
from unittest import mock
from uuid import uuid4

import pytest

from ansible_base.rbac.claims import save_user_claims
from ansible_base.rbac.models import RoleDefinition, RoleUserAssignment
from ansible_base.rbac.permission_registry import permission_registry
from test_app.models import Organization


@pytest.fixture
def organization_admin_role():
    """Managed organization-scoped admin role used by every test in this module."""
    return RoleDefinition.objects.create_from_permissions(
        permissions=[
            permission_registry.team_permission,
            f'view_{permission_registry.team_model._meta.model_name}',
            'view_organization',
            'change_organization',
        ],
        name='Organization Admin',
        content_type=permission_registry.content_type_model.objects.get_for_model(Organization),
        managed=True,
    )


@pytest.mark.django_db
class TestSaveUserClaimsAtomicity:
    def test_partial_failure_rolls_back(self, admin_user, organization, organization_admin_role):
        """If save_user_claims fails midway, no partial writes should persist."""
        objects = {'organization': [{'ansible_id': str(organization.resource.ansible_id), 'name': organization.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}

        # Seed one successful assignment as the known-good pre-failure state.
        save_user_claims(admin_user, objects, object_roles, [])
        assert RoleUserAssignment.objects.filter(user=admin_user).count() == 1

        # Now attempt a save that will fail partway through give_permission
        with mock.patch(
            'ansible_base.rbac.models.role.RoleDefinition.give_permission',
            side_effect=RuntimeError("simulated failure"),
        ):
            with pytest.raises(RuntimeError, match="simulated failure"):
                new_org_id = str(uuid4())
                objects_bad = {
                    'organization': [
                        {'ansible_id': str(organization.resource.ansible_id), 'name': organization.name},
                        {'ansible_id': new_org_id, 'name': 'New Org'},
                    ]
                }
                object_roles_bad = {"Organization Admin": {'content_type': 'organization', 'objects': [0, 1]}}
                save_user_claims(admin_user, objects_bad, object_roles_bad, [])

        # Original assignment should still be intact (rolled back to pre-failure state)
        assert RoleUserAssignment.objects.filter(user=admin_user).count() == 1


@pytest.mark.django_db
class TestSaveUserClaimsNoReverseSync:
    def test_no_reverse_sync_during_save(self, admin_user, organization_admin_role):
        """Verify no_reverse_sync is active during save_user_claims execution."""
        from ansible_base.resource_registry.signals.handlers import reverse_sync_enabled

        sync_states_during_save = []

        original_give_permission = RoleDefinition.give_permission

        def spy_give_permission(self, actor, content_object):
            # Sample the reverse-sync flag at the moment each grant is written.
            sync_states_during_save.append(reverse_sync_enabled.enabled)
            return original_give_permission(self, actor, content_object)

        org = Organization.objects.create(name='sync-test-org')
        objects = {'organization': [{'ansible_id': str(org.resource.ansible_id), 'name': org.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}

        with mock.patch.object(RoleDefinition, 'give_permission', spy_give_permission):
            save_user_claims(admin_user, objects, object_roles, [])

        # The spy must have fired at least once, and the flag must have been off every time.
        assert len(sync_states_during_save) > 0
        assert all(state is False for state in sync_states_during_save), "reverse_sync should be disabled during save_user_claims"


@pytest.mark.django_db
class TestSaveUserClaimsDeferredEvaluation:
    def test_deferred_evaluation_during_save(self, admin_user, organization_admin_role):
        """Verify defer_role_evaluation is active during save_user_claims execution."""
        from ansible_base.rbac.triggers import _deferred_evaluation

        deferred_states = []

        original_give_permission = RoleDefinition.give_permission

        def spy_give_permission(self, actor, content_object):
            # Sample the defer flag at the moment each grant is written.
            deferred_states.append(_deferred_evaluation.enabled)
            return original_give_permission(self, actor, content_object)

        org = Organization.objects.create(name='defer-test-org')
        objects = {'organization': [{'ansible_id': str(org.resource.ansible_id), 'name': org.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}

        with mock.patch.object(RoleDefinition, 'give_permission', spy_give_permission):
            save_user_claims(admin_user, objects, object_roles, [])

        assert len(deferred_states) > 0
        assert all(state is True for state in deferred_states), "evaluation should be deferred during save_user_claims"


@pytest.mark.django_db
class TestSaveUserClaimsLogging:
    def test_logs_grant_count(self, admin_user, organization, organization_admin_role, caplog):
        """A single grant emits one summary line with counts and elapsed time."""
        objects = {'organization': [{'ansible_id': str(organization.resource.ansible_id), 'name': organization.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}

        with caplog.at_level(logging.INFO, logger='ansible_base.rbac.claims'):
            save_user_claims(admin_user, objects, object_roles, [])

        summary_logs = [r for r in caplog.records if r.message.startswith('save_user_claims for')]
        assert len(summary_logs) == 1
        assert '1 grants' in summary_logs[0].message
        assert '0 removals' in summary_logs[0].message
        # Fixed: the original `assert 's' in message` was vacuous ('s' occurs in
        # "save_user_claims" itself). Require an actual seconds token like "0.012s".
        # NOTE(review): assumes the float-formatted "...s" elapsed style used by the
        # claims logging — confirm against ansible_base.rbac.claims if the format changes.
        assert re.search(r'\d+\.\d+s', summary_logs[0].message)  # elapsed time

    def test_logs_removal_count(self, admin_user, organization, organization_admin_role, caplog):
        """Clearing all claims reports the removal count in the summary line."""
        objects = {'organization': [{'ansible_id': str(organization.resource.ansible_id), 'name': organization.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}
        save_user_claims(admin_user, objects, object_roles, [])

        caplog.clear()
        with caplog.at_level(logging.INFO, logger='ansible_base.rbac.claims'):
            save_user_claims(admin_user, {}, {}, [])

        summary_logs = [r for r in caplog.records if r.message.startswith('save_user_claims for')]
        assert len(summary_logs) == 1
        assert '0 grants' in summary_logs[0].message
        assert '1 removals' in summary_logs[0].message

    def test_logs_mixed_grants_and_removals(self, admin_user, organization_admin_role, caplog):
        """A claims swap (one grant, one removal) reports both counts."""
        org1 = Organization.objects.create(name='log-org-1')
        org2 = Organization.objects.create(name='log-org-2')

        objects = {'organization': [{'ansible_id': str(org1.resource.ansible_id), 'name': org1.name}]}
        object_roles = {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}
        save_user_claims(admin_user, objects, object_roles, [])

        caplog.clear()
        # Now grant org2 and remove org1
        objects2 = {'organization': [{'ansible_id': str(org2.resource.ansible_id), 'name': org2.name}]}
        with caplog.at_level(logging.INFO, logger='ansible_base.rbac.claims'):
            save_user_claims(admin_user, objects2, object_roles, [])

        summary_logs = [r for r in caplog.records if r.message.startswith('save_user_claims for')]
        assert len(summary_logs) == 1
        assert '1 grants' in summary_logs[0].message
        assert '1 removals' in summary_logs[0].message