diff --git a/docker-compose-embedded.yml b/docker-compose-embedded.yml index 71fb12969f..0653425bf5 100644 --- a/docker-compose-embedded.yml +++ b/docker-compose-embedded.yml @@ -73,7 +73,7 @@ services: - AUTH_REQUIRED=true - EMAIL_AUTH_ENABLED=true - MCP_CLIENT_AUTH_ENABLED=true - - JWT_SECRET_KEY=jwt-secret-key + - JWT_SECRET_KEY=${JWT_SECRET_KEY:-my-test-key} - PLATFORM_ADMIN_EMAIL=admin@example.com - PLATFORM_ADMIN_PASSWORD=Changeme123! - PLATFORM_ADMIN_FULL_NAME=Platform Administrator @@ -149,7 +149,7 @@ services: register_benchmark: image: ghcr.io/ibm/mcp-context-forge:1ba8130f7fb82e6f393435be8d064879f234ace1 environment: - - JWT_SECRET_KEY=jwt-secret-key + - JWT_SECRET_KEY=${JWT_SECRET_KEY:-my-test-key} - BENCHMARK_SERVER_COUNT=${BENCHMARK_SERVER_COUNT:-10} - BENCHMARK_START_PORT=${BENCHMARK_START_PORT:-9000} command: @@ -209,11 +209,3 @@ services: print(f'Registration complete: {registered}/{server_count} benchmark servers') " - # NOTE: register_fast_time uses hardcoded secret in its command script. - # It will fail with 401 when JWT_SECRET_KEY differs from base compose. - # This is non-critical — benchmark servers provide the bulk test data. - # To also register fast_time, manually run: - # curl -X POST http://localhost:8080/gateways \ - # -H "Authorization: Bearer $(python3 -m mcpgateway.utils.create_jwt_token --username admin@example.com --exp 10080 --secret jwt-secret-key --algo HS256)" \ - # -H "Content-Type: application/json" \ - # -d '{"name":"fast_time","url":"http://fast_time_server:8080/http","transport":"STREAMABLEHTTP"}' diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index fc2de726f3..a3f6c4aca4 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -1286,8 +1286,15 @@ async def _set_auth_method_from_payload(payload: dict) -> None: except HTTPException: raise except Exception as revoke_check_error: - # Log the error but don't fail authentication for admin tokens - logger.warning(f"Token revocation check failed for JTI {jti}: {revoke_check_error}") + # Fail-secure: if the revocation check itself errors, reject the token. + # Allowing through on error would let revoked tokens bypass enforcement + # when the DB is unreachable or the table is missing. + logger.warning(f"Token revocation check failed for JTI {jti} — denying access (fail-secure): {revoke_check_error}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token validation failed", + headers={"WWW-Authenticate": "Bearer"}, + ) # Resolve teams based on token_use token_use = payload.get("token_use") diff --git a/mcpgateway/cache/auth_cache.py b/mcpgateway/cache/auth_cache.py index afd749ae32..450952f191 100644 --- a/mcpgateway/cache/auth_cache.py +++ b/mcpgateway/cache/auth_cache.py @@ -266,6 +266,26 @@ async def get_auth_context( self._hit_count += 1 return CachedAuthContext(is_token_revoked=True) + # Cross-worker revocation check: when a token is revoked on another worker, + # the revoking worker writes a Redis revocation marker. Check it BEFORE the + # L1 in-memory cache so that stale L1 entries cannot bypass revocation. + if jti: + redis = await self._get_redis_client() + if redis: + try: + revoke_key = self._get_redis_key("revoke", jti) + if await redis.exists(revoke_key): + # Promote to local set so subsequent requests skip the Redis call + with self._lock: + self._revoked_jtis.add(jti) + # Evict any stale L1 context entries for this JTI + for k in [k for k in self._context_cache if k.endswith(f":{jti}")]: + self._context_cache.pop(k, None) + self._hit_count += 1 + return CachedAuthContext(is_token_revoked=True) + except Exception as exc: + logger.debug(f"AuthCache: Redis revocation check failed for {jti[:8]}: {exc}") + cache_key = f"{email}:{jti or 'no-jti'}" # Check L1 in-memory cache first (no network I/O) diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py index 1f74a50092..a8e45e97f4 100644 --- a/mcpgateway/middleware/auth_middleware.py +++ b/mcpgateway/middleware/auth_middleware.py @@ -20,10 +20,11 @@ from typing import Callable # Third-Party +from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import JSONResponse, Response # First-Party from mcpgateway.auth import get_current_user @@ -35,6 +36,12 @@ logger = logging.getLogger(__name__) security_logger = get_security_logger() +# HTTPException detail strings that indicate security-critical rejections +# (revoked tokens, disabled accounts, fail-secure validation errors). +# Only these trigger a hard JSON deny in the auth middleware; all other +# 401/403s fall through to route-level auth for backwards compatibility. +_HARD_DENY_DETAILS = frozenset({"Token has been revoked", "Account disabled", "Token validation failed"}) + def _should_log_auth_success() -> bool: """Check if successful authentication should be logged based on settings. @@ -144,8 +151,60 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: except Exception as close_error: logger.debug(f"Failed to close database session: {close_error}") + except HTTPException as e: + if e.status_code in (401, 403) and e.detail in _HARD_DENY_DETAILS: + logger.info(f"✗ Auth rejected ({e.status_code}): {e.detail}") + + if log_failure: + db = SessionLocal() + try: + security_logger.log_authentication_attempt( + user_id="unknown", + user_email=None, + auth_method="bearer_token", + success=False, + client_ip=request.client.host if request.client else "unknown", + user_agent=request.headers.get("user-agent"), + failure_reason=str(e.detail), + db=db, + ) + db.commit() + except Exception as log_error: + logger.debug(f"Failed to log auth failure: {log_error}") + finally: + try: + db.close() + except Exception as close_error: + logger.debug(f"Failed to close database session: {close_error}") + + # Browser/admin requests with stale cookies: let the request continue + # without user context so the RBAC layer can redirect to /admin/login. + # API requests: return a hard JSON 401/403 deny. + # Detection must match rbac.py's is_browser_request logic (Accept, + # HX-Request, and Referer: /admin) to avoid breaking admin UI flows. + accept_header = request.headers.get("accept", "") + is_htmx = request.headers.get("hx-request") == "true" + referer = request.headers.get("referer", "") + is_browser = "text/html" in accept_header or is_htmx or "/admin" in referer + if is_browser: + logger.debug("Browser request with rejected auth — continuing without user for redirect") + return await call_next(request) + + # Include essential security headers since this response bypasses + # SecurityHeadersMiddleware (it returns before call_next). + resp_headers = dict(e.headers) if e.headers else {} + resp_headers.setdefault("X-Content-Type-Options", "nosniff") + resp_headers.setdefault("Referrer-Policy", "strict-origin-when-cross-origin") + return JSONResponse( + status_code=e.status_code, + content={"detail": e.detail}, + headers=resp_headers, + ) + + # Non-security HTTP errors (e.g. 500 from a downstream service) — continue as anonymous + logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") except Exception as e: - # Silently fail - let route handlers enforce auth if needed + # Non-HTTP errors (network, decode, etc.) — continue as anonymous logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}") # Log failed authentication attempt (based on logging level) diff --git a/mcpgateway/middleware/token_usage_middleware.py b/mcpgateway/middleware/token_usage_middleware.py index 050e45e5af..052ec771d3 100644 --- a/mcpgateway/middleware/token_usage_middleware.py +++ b/mcpgateway/middleware/token_usage_middleware.py @@ -21,8 +21,10 @@ # Standard import logging import time +from typing import Optional # Third-Party +import jwt as _jwt from starlette.datastructures import Headers from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send @@ -108,62 +110,120 @@ async def send_wrapper(message: dict) -> None: # Calculate response time response_time_ms = round((time.time() - start_time) * 1000) - # Only log if this was an API token request + # Log API token usage — covers both successful requests and auth-rejected attempts. + # Every request that uses (or tries to use) an API token is recorded, + # including blocked calls with revoked/expired tokens, so that usage stats are accurate. state = scope.get("state", {}) auth_method = state.get("auth_method") if state else None - if auth_method != "api_token": - return + jti: Optional[str] = None + user_email: Optional[str] = None + blocked: bool = False + block_reason: Optional[str] = None + + if auth_method == "api_token": + # --- Successfully authenticated API token request --- + jti = state.get("jti") if state else None + user = state.get("user") if state else None + user_email = getattr(user, "email", None) if user else None + if not user_email: + user_email = state.get("user_email") if state else None + + # If we don't have JTI or email, try to decode the token from the header + if not jti or not user_email: + try: + headers = Headers(scope=scope) + auth_header = headers.get("authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return + token = auth_header.replace("Bearer ", "") + request = Request(scope, receive) + try: + payload = await verify_jwt_token_cached(token, request) + jti = jti or payload.get("jti") + user_email = user_email or payload.get("sub") or payload.get("email") + except Exception as decode_error: + logger.debug(f"Failed to decode token for usage logging: {decode_error}") + return + except Exception as e: + logger.debug(f"Error extracting token information: {e}") + return - # Extract token information from scope state - jti = state.get("jti") if state else None - user = state.get("user") if state else None - user_email = getattr(user, "email", None) if user else None - if not user_email: - user_email = state.get("user_email") if state else None + if not jti or not user_email: + logger.debug("Missing JTI or user_email for token usage logging") + return - # If we don't have JTI or email, try to decode the token - if not jti or not user_email: + # Bug 3a fix: reflect the actual outcome — 4xx responses mark the attempt + # as blocked (e.g. RBAC denied, rate-limited, or server-scoping violation). + # 5xx errors are backend failures, not security denials, so exclude them. + blocked = 400 <= status_code < 500 + if blocked: + block_reason = f"http_{status_code}" + + elif status_code in (401, 403): + # --- Auth-rejected request: check if the Bearer token was an API token --- + # When a revoked or expired API token is used, auth middleware rejects the + # request before setting auth_method="api_token", so the path above is + # never reached. We detect the attempt here by decoding the JWT payload + # without re-verifying it (the token identity is valid even if rejected). try: - # Get token from Authorization header headers = Headers(scope=scope) auth_header = headers.get("authorization") if not auth_header or not auth_header.startswith("Bearer "): return + raw_token = auth_header[7:] # strip "Bearer " - token = auth_header.replace("Bearer ", "") + # Decode without signature/expiry check — for identification only, not auth. + unverified = _jwt.decode(raw_token, options={"verify_signature": False}) + user_info = unverified.get("user", {}) + if user_info.get("auth_provider") != "api_token": + return # Not an API token — nothing to log - # Decode token to get JTI and user email - # Note: We need to create a minimal Request-like object - request = Request(scope, receive) - try: - payload = await verify_jwt_token_cached(token, request) - jti = jti or payload.get("jti") - user_email = user_email or payload.get("sub") or payload.get("email") - except Exception as decode_error: - logger.debug(f"Failed to decode token for usage logging: {decode_error}") + jti = unverified.get("jti") + user_email = unverified.get("sub") or unverified.get("email") + if not jti or not user_email: return + + # Verify JTI belongs to a real API token before logging. + # Without this check, an attacker can craft a JWT with fake + # jti/sub and auth_provider=api_token to pollute usage logs. + # Verify JTI belongs to a real API token and use the DB-stored + # owner email instead of the unverified JWT claim. Without this, + # an attacker who knows a valid JTI could forge a JWT with an + # arbitrary sub/email to poison another user's usage stats. + try: + # Third-Party + from sqlalchemy import select # pylint: disable=import-outside-toplevel + + # First-Party + from mcpgateway.db import EmailApiToken # pylint: disable=import-outside-toplevel + + with fresh_db_session() as verify_db: + token_row = verify_db.execute(select(EmailApiToken.id, EmailApiToken.user_email).where(EmailApiToken.jti == jti)).first() + if token_row is None: + return # JTI not in DB — forged token, skip logging + # Use the DB-stored owner, not the unverified JWT claim + user_email = token_row.user_email + except Exception: + return # DB error — skip logging rather than log unverified data + + blocked = True + block_reason = "revoked_or_expired" if status_code == 401 else f"http_{status_code}" except Exception as e: - logger.debug(f"Error extracting token information: {e}") + logger.debug(f"Failed to extract API token identity from rejected request: {e}") return + else: + return # Not an API token request — nothing to log - if not jti or not user_email: - logger.debug("Missing JTI or user_email for token usage logging") - return - - # Log token usage + # Shared logging path for both authenticated and blocked API token requests try: with fresh_db_session() as db: token_service = TokenCatalogService(db) - # Get client IP client = scope.get("client") ip_address = client[0] if client else None - - # Get user agent headers = Headers(scope=scope) user_agent = headers.get("user-agent") - # Log usage await token_service.log_token_usage( jti=jti, user_email=user_email, @@ -173,8 +233,8 @@ async def send_wrapper(message: dict) -> None: user_agent=user_agent, status_code=status_code, response_time_ms=response_time_ms, - blocked=False, - block_reason=None, + blocked=blocked, + block_reason=block_reason, ) except Exception as e: logger.debug(f"Failed to log token usage: {e}") diff --git a/mcpgateway/routers/tokens.py b/mcpgateway/routers/tokens.py index 45d8f7fb8e..8aa244e41c 100644 --- a/mcpgateway/routers/tokens.py +++ b/mcpgateway/routers/tokens.py @@ -30,18 +30,18 @@ def _require_authenticated_session(current_user: dict) -> None: - """Block anonymous and unauthenticated access to token management endpoints. + """Block anonymous, unauthenticated, and API-token access to token management endpoints. - Rejects requests where authentication could not be determined or where - the caller is anonymous. All authenticated methods (JWT, API tokens, - OAuth, SSO, proxy, etc.) are allowed — RBAC permission checks and - scope containment (via _get_caller_permissions) handle authorization. + Enforces Management Plane isolation: only interactive sessions (JWT from web + login, SSO, or OAuth) may create, list, or revoke tokens. API tokens are + Data Plane credentials and must never be able to manage other tokens + (token-chaining attack vector). Args: current_user: User context from get_current_user_with_permissions Raises: - HTTPException: 403 if auth_method is None or anonymous + HTTPException: 403 if auth_method is None, anonymous, or api_token """ auth_method = current_user.get("auth_method") @@ -60,6 +60,16 @@ def _require_authenticated_session(current_user: dict) -> None: detail="Token management requires authentication. Anonymous access is not permitted.", ) + # Block API tokens from managing other tokens (Management Plane isolation). + # Token CRUD endpoints require an interactive session (JWT from web login or SSO). + # Allowing API tokens here would let a compromised token create new long-lived + # tokens and escalate persistence — a token-chaining attack. + if auth_method == "api_token": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=("Token management requires an interactive session (JWT from web login or SSO). " "API tokens cannot create, list, or revoke other tokens."), + ) + async def _get_caller_permissions( db: Session, @@ -114,12 +124,24 @@ async def create_token( """ _require_authenticated_session(current_user) + # Auto-inherit team_id from the caller's single team when not explicitly provided. + # This prevents tokens from being silently scoped to public-only (team_id=None) + # when the user belongs to exactly one team, maintaining RBAC context at token level. + # Multi-team users must specify team_id explicitly to avoid ambiguity. + # Admins with teams=null are exempt and may still create global-scope tokens. + effective_team_id = request.team_id + if effective_team_id is None and not current_user.get("is_admin"): + user_teams = current_user.get("token_teams") or [] + if len(user_teams) == 1: + effective_team_id = user_teams[0] + logger.debug("Auto-inherited team_id=%s for token creation by %s", effective_team_id, current_user["email"]) + service = TokenCatalogService(db) # Get caller permissions for scope containment (if custom scope requested) caller_permissions = None if request.scope and request.scope.permissions: - caller_permissions = await _get_caller_permissions(db, current_user, request.team_id) + caller_permissions = await _get_caller_permissions(db, current_user, effective_team_id) # Convert request to TokenScope if provided scope = None @@ -140,7 +162,7 @@ async def create_token( scope=scope, expires_in_days=request.expires_in_days, tags=request.tags, - team_id=request.team_id, + team_id=effective_team_id, caller_permissions=caller_permissions, is_active=request.is_active, ) diff --git a/mcpgateway/services/token_catalog_service.py b/mcpgateway/services/token_catalog_service.py index 870706f97c..28c028e52c 100644 --- a/mcpgateway/services/token_catalog_service.py +++ b/mcpgateway/services/token_catalog_service.py @@ -852,14 +852,14 @@ async def revoke_token(self, token_id: str, user_email: str, revoked_by: str, re self.db.add(revocation) self.db.commit() - # Invalidate auth cache for revoked token + # Invalidate auth cache synchronously so revoked tokens are rejected immediately + # (fire-and-forget via create_task risks a race where the next request arrives + # before the invalidation task runs, allowing the revoked token through). try: # First-Party from mcpgateway.cache.auth_cache import auth_cache # pylint: disable=import-outside-toplevel - task = asyncio.create_task(auth_cache.invalidate_revocation(token.jti)) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + await auth_cache.invalidate_revocation(token.jti) except Exception as cache_error: logger.debug(f"Failed to invalidate auth cache for revoked token: {cache_error}") @@ -899,9 +899,7 @@ async def admin_revoke_token(self, token_id: str, revoked_by: str, reason: Optio # First-Party from mcpgateway.cache.auth_cache import auth_cache # pylint: disable=import-outside-toplevel - task = asyncio.create_task(auth_cache.invalidate_revocation(token.jti)) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + await auth_cache.invalidate_revocation(token.jti) except Exception as cache_error: logger.debug(f"Failed to invalidate auth cache: {cache_error}") diff --git a/tests/unit/mcpgateway/cache/test_auth_cache_l1_l2.py b/tests/unit/mcpgateway/cache/test_auth_cache_l1_l2.py index 6c350f663d..ebe82f029e 100644 --- a/tests/unit/mcpgateway/cache/test_auth_cache_l1_l2.py +++ b/tests/unit/mcpgateway/cache/test_auth_cache_l1_l2.py @@ -589,6 +589,7 @@ async def test_auth_cache_redis_error_paths(monkeypatch): redis = AsyncMock() redis.get = AsyncMock(side_effect=RuntimeError("boom")) + redis.exists = AsyncMock(side_effect=RuntimeError("boom")) redis.setex = AsyncMock(side_effect=RuntimeError("boom")) monkeypatch.setattr(cache, "_get_redis_client", AsyncMock(return_value=redis)) @@ -613,6 +614,48 @@ def test_auth_cache_team_membership_sync_hit(): assert cache.get_team_membership_valid_sync("user@example.com", ["team-1"]) is True +@pytest.mark.asyncio +async def test_redis_revocation_marker_detected(monkeypatch): + """When Redis has a revocation marker for a JTI, get_auth_context returns revoked and promotes to local set.""" + cache = AuthCache(enabled=True) + jti = "revoked-cross-worker-jti" + + redis = AsyncMock() + redis.exists = AsyncMock(return_value=1) # Revocation marker exists + monkeypatch.setattr(cache, "_get_redis_client", AsyncMock(return_value=redis)) + + # Also seed L1 with a stale non-revoked entry to prove Redis check wins + cache_key = f"user@example.com:{jti}" + cache._context_cache[cache_key] = CacheEntry( + value=CachedAuthContext(user={"email": "user@example.com"}, is_token_revoked=False), + expiry=time.time() + 60, + ) + + result = await cache.get_auth_context("user@example.com", jti) + + assert result is not None + assert result.is_token_revoked is True + # JTI should be promoted to local set for fast subsequent lookups + assert jti in cache._revoked_jtis + # Stale L1 entry should be evicted + assert cache_key not in cache._context_cache + + +@pytest.mark.asyncio +async def test_redis_revocation_check_error_falls_through(monkeypatch): + """When the Redis revocation check errors, fall through to L1/L2 cache.""" + cache = AuthCache(enabled=True) + + redis = AsyncMock() + redis.exists = AsyncMock(side_effect=RuntimeError("Redis timeout")) + redis.get = AsyncMock(return_value=None) + monkeypatch.setattr(cache, "_get_redis_client", AsyncMock(return_value=redis)) + + # No L1 entry, no Redis context → should return None (cache miss) + result = await cache.get_auth_context("user@example.com", "some-jti") + assert result is None + + @pytest.mark.asyncio async def test_auth_cache_invalidation_redis_warning_paths(monkeypatch): cache = AuthCache(enabled=True) diff --git a/tests/unit/mcpgateway/middleware/test_auth_middleware.py b/tests/unit/mcpgateway/middleware/test_auth_middleware.py index 94b5ba17b0..aff946b26e 100644 --- a/tests/unit/mcpgateway/middleware/test_auth_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_auth_middleware.py @@ -298,3 +298,262 @@ async def test_failure_logging_close_exception(monkeypatch): assert response.status_code == 200 mock_db.close.assert_called_once() + + +# ============================================================================ +# HTTPException 401/403 hard-deny tests (auth_middleware lines 148-194) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_http_401_returns_json_deny_for_api_request(): + """HTTPException 401 from get_current_user returns JSON 401 for API requests.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "revoked_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + assert response.status_code == 401 + call_next.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_http_403_returns_json_deny_for_api_request(): + """HTTPException 403 from get_current_user returns JSON 403 for API requests.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "bad_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=403, detail="Account disabled"))): + response = await middleware.dispatch(request, call_next) + + assert response.status_code == 403 + call_next.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_http_401_browser_request_continues_for_redirect(): + """HTTPException 401 for browser/HTMX requests continues to allow RBAC redirect.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("login page", status_code=200)) + request = MagicMock(spec=Request) + request.url.path = "/admin/overview/partial" + request.cookies = {"jwt_token": "stale_cookie"} + request.headers = {"accept": "text/html", "hx-request": "true"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + # Browser request should pass through for RBAC redirect, not get JSON 401 + call_next.assert_awaited_once_with(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_http_401_with_failure_logging_enabled(): + """HTTPException 401 logs the failure when auth failure logging is enabled.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "revoked_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "10.0.0.1" + + mock_security_logger = MagicMock() + mock_db = MagicMock() + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=True), \ + patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=mock_db), \ + patch("mcpgateway.middleware.auth_middleware.security_logger", mock_security_logger), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + assert response.status_code == 401 + mock_security_logger.log_authentication_attempt.assert_called_once() + mock_db.commit.assert_called_once() + mock_db.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_http_401_logging_db_error_handled(): + """DB error during 401 failure logging is caught gracefully.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "bad_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "10.0.0.1" + + mock_security_logger = MagicMock() + mock_security_logger.log_authentication_attempt = MagicMock(side_effect=Exception("DB down")) + mock_db = MagicMock() + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=True), \ + patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=mock_db), \ + patch("mcpgateway.middleware.auth_middleware.security_logger", mock_security_logger), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + # Should still return 401 despite logging failure + assert response.status_code == 401 + mock_db.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_http_401_logging_db_close_error_handled(): + """DB close error during 401 failure logging is caught gracefully.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "bad_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "10.0.0.1" + + mock_security_logger = MagicMock() + mock_db = MagicMock() + mock_db.close.side_effect = Exception("close failed") + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=True), \ + patch("mcpgateway.middleware.auth_middleware.SessionLocal", return_value=mock_db), \ + patch("mcpgateway.middleware.auth_middleware.security_logger", mock_security_logger), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_non_401_403_http_exception_continues_as_anonymous(): + """HTTPException with non-401/403 status continues as anonymous.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "some_token"} + request.headers = {} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=500, detail="Internal error"))): + response = await middleware.dispatch(request, call_next) + + # Non-security error: continue as anonymous + call_next.assert_awaited_once_with(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_non_revocation_401_falls_through_as_anonymous(): + """Non-revocation 401 (e.g. malformed token) continues as anonymous for route-level auth.""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "minimal_jwt"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Invalid authentication credentials"))): + response = await middleware.dispatch(request, call_next) + + # Non-revocation 401 should fall through, not hard-deny + call_next.assert_awaited_once_with(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_http_401_referer_admin_continues_for_redirect(): + """HTTPException 401 with Referer: /admin continues for RBAC redirect (not JSON deny).""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("login page", status_code=200)) + request = MagicMock(spec=Request) + request.url.path = "/admin/tools/partial" + request.cookies = {"jwt_token": "stale_cookie"} + request.headers = {"accept": "*/*", "referer": "http://localhost:8080/admin/"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token revoked"))): + response = await middleware.dispatch(request, call_next) + + # Referer-based admin detection should let the request through for redirect + call_next.assert_awaited_once_with(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_http_401_json_deny_includes_security_headers(): + """JSON 401 response includes essential security headers (X-Content-Type-Options, Referrer-Policy).""" + from fastapi import HTTPException + + middleware = AuthContextMiddleware(app=AsyncMock()) + call_next = AsyncMock(return_value=Response("ok")) + request = MagicMock(spec=Request) + request.url.path = "/api/tools" + request.cookies = {"jwt_token": "revoked_token"} + request.headers = {"accept": "application/json"} + request.client = MagicMock() + request.client.host = "127.0.0.1" + + with patch("mcpgateway.middleware.auth_middleware._should_log_auth_success", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware._should_log_auth_failure", return_value=False), \ + patch("mcpgateway.middleware.auth_middleware.get_current_user", AsyncMock(side_effect=HTTPException(status_code=401, detail="Token has been revoked"))): + response = await middleware.dispatch(request, call_next) + + assert response.status_code == 401 + assert response.headers.get("x-content-type-options") == "nosniff" + assert response.headers.get("referrer-policy") == "strict-origin-when-cross-origin" diff --git a/tests/unit/mcpgateway/middleware/test_token_usage_middleware.py b/tests/unit/mcpgateway/middleware/test_token_usage_middleware.py index 5e87df5a0f..bd7c958469 100644 --- a/tests/unit/mcpgateway/middleware/test_token_usage_middleware.py +++ b/tests/unit/mcpgateway/middleware/test_token_usage_middleware.py @@ -495,3 +495,404 @@ async def app_impl(scope, receive, send): call_args = mock_token_service.log_token_usage.call_args assert call_args.kwargs["jti"] == "jti-opaque-123" assert call_args.kwargs["user_email"] == "opaque@example.com" + + +@pytest.mark.asyncio +async def test_logs_blocked_api_token_on_4xx(): + """Middleware marks usage as blocked and sets block_reason when an api_token request returns a 4xx status.""" + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 429, "headers": []}) + await send({"type": "http.response.body", "body": b"rate limited"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + user_mock = MagicMock() + user_mock.email = "user@example.com" + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": { + "auth_method": "api_token", + "jti": "jti-blocked-789", + "user": user_mock, + }, + "client": ("192.168.1.100", 12345), + "headers": [], + } + + mock_db = MagicMock() + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session") as mock_fresh_session, + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + mock_fresh_session.return_value.__enter__.return_value = mock_db + await _make_asgi_call(middleware, scope) + + mock_token_service.log_token_usage.assert_awaited_once() + call_args = mock_token_service.log_token_usage.call_args + assert call_args.kwargs["blocked"] is True + assert call_args.kwargs["block_reason"] == "http_429" + + +@pytest.mark.asyncio +async def test_logs_revoked_api_token_on_401(): + """Middleware logs a blocked attempt using the DB-stored owner email, not the JWT claim.""" + # Standard + import jwt as _jwt_lib + from collections import namedtuple + + token_payload = { + "jti": "jti-revoked-abc", + "sub": "revoked@example.com", + "user": {"auth_provider": "api_token"}, + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + # Mock the JTI verification to return a DB row with the real owner email + TokenRow = namedtuple("TokenRow", ["id", "user_email"]) + mock_verify_db = MagicMock() + mock_verify_db.execute.return_value.first.return_value = TokenRow(id="tok-1", user_email="real-owner@example.com") + mock_log_db = MagicMock() + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + call_count = [0] + + def fresh_session_factory(): + ctx = MagicMock() + ctx.__enter__ = MagicMock(return_value=mock_verify_db if call_count[0] == 0 else mock_log_db) + ctx.__exit__ = MagicMock(return_value=False) + call_count[0] += 1 + return ctx + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session", side_effect=fresh_session_factory), + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + await _make_asgi_call(middleware, scope) + + mock_token_service.log_token_usage.assert_awaited_once() + call_args = mock_token_service.log_token_usage.call_args + assert call_args.kwargs["jti"] == "jti-revoked-abc" + # Must use the DB-stored email, NOT the unverified JWT "sub" claim + assert call_args.kwargs["user_email"] == "real-owner@example.com" + assert call_args.kwargs["blocked"] is True + assert call_args.kwargs["block_reason"] == "revoked_or_expired" + + +@pytest.mark.asyncio +async def test_logs_rejected_api_token_on_403(): + """Middleware logs a blocked attempt with http_403 reason using DB-stored email.""" + # Standard + import jwt as _jwt_lib + from collections import namedtuple + + token_payload = { + "jti": "jti-rejected-def", + "sub": "rejected@example.com", + "user": {"auth_provider": "api_token"}, + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 403, "headers": []}) + await send({"type": "http.response.body", "body": b"forbidden"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + TokenRow = namedtuple("TokenRow", ["id", "user_email"]) + mock_verify_db = MagicMock() + mock_verify_db.execute.return_value.first.return_value = TokenRow(id="tok-2", user_email="db-owner@example.com") + mock_log_db = MagicMock() + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + call_count = [0] + + def fresh_session_factory(): + ctx = MagicMock() + ctx.__enter__ = MagicMock(return_value=mock_verify_db if call_count[0] == 0 else mock_log_db) + ctx.__exit__ = MagicMock(return_value=False) + call_count[0] += 1 + return ctx + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session", side_effect=fresh_session_factory), + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + await _make_asgi_call(middleware, scope) + + mock_token_service.log_token_usage.assert_awaited_once() + call_args = mock_token_service.log_token_usage.call_args + assert call_args.kwargs["blocked"] is True + assert call_args.kwargs["block_reason"] == "http_403" + assert call_args.kwargs["user_email"] == "db-owner@example.com" + + +@pytest.mark.asyncio +async def test_skips_rejected_api_token_missing_jti(): + """Middleware skips logging when a rejected API token payload has no jti or sub/email.""" + # Standard + import jwt as _jwt_lib + + token_payload = { + # No jti, no sub/email + "user": {"auth_provider": "api_token"}, + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + mock_db = MagicMock() + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session") as mock_fresh_session, + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + mock_fresh_session.return_value.__enter__.return_value = mock_db + await _make_asgi_call(middleware, scope) + + mock_token_service.log_token_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_skips_rejected_request_without_bearer_header(): + """Middleware skips logging when a 401/403 response has no Bearer authorization header.""" + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [], # No authorization header + } + + with patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session") as mock_session: + await _make_asgi_call(middleware, scope) + + mock_session.assert_not_called() + + +@pytest.mark.asyncio +async def test_skips_rejected_non_api_token_jwt(): + """Middleware skips logging when a 401/403 response carries a JWT that is not an API token.""" + # Standard + import jwt as _jwt_lib + + token_payload = { + "jti": "jti-jwt-session", + "sub": "user@example.com", + "user": {"auth_provider": "email"}, # Not an API token + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 403, "headers": []}) + await send({"type": "http.response.body", "body": b"forbidden"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + with patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session") as mock_session: + await _make_asgi_call(middleware, scope) + + mock_session.assert_not_called() + + +@pytest.mark.asyncio +async def test_skips_rejected_request_with_malformed_token(): + """Middleware skips logging when a 401/403 response carries a malformed token that cannot be decoded.""" + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", b"Bearer not-a-valid-jwt-at-all")], + } + + with patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session") as mock_session: + await _make_asgi_call(middleware, scope) + + mock_session.assert_not_called() + + +@pytest.mark.asyncio +async def test_skips_forged_jwt_with_unknown_jti(): + """Middleware skips logging when a rejected API token's JTI doesn't exist in the database.""" + # Standard + import jwt as _jwt_lib + + token_payload = { + "jti": "forged-jti-not-in-db", + "sub": "attacker@evil.com", + "user": {"auth_provider": "api_token"}, + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + # Mock fresh_db_session: first call (JTI verify) returns no match, second would be for logging + mock_verify_db = MagicMock() + mock_verify_db.execute.return_value.first.return_value = None + + mock_ctx = MagicMock() + mock_ctx.__enter__ = MagicMock(return_value=mock_verify_db) + mock_ctx.__exit__ = MagicMock(return_value=False) + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session", return_value=mock_ctx), + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + await _make_asgi_call(middleware, scope) + + # Should NOT log usage because JTI was not found in DB + mock_token_service.log_token_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_skips_logging_on_jti_verification_db_error(): + """Middleware skips logging when the JTI verification DB query fails.""" + # Standard + import jwt as _jwt_lib + + token_payload = { + "jti": "jti-db-error", + "sub": "user@example.com", + "user": {"auth_provider": "api_token"}, + } + raw_token = _jwt_lib.encode(token_payload, "test-secret-key-for-unit-tests-only", algorithm="HS256") + + app = AsyncMock() + + async def app_impl(scope, receive, send): + await send({"type": "http.response.start", "status": 401, "headers": []}) + await send({"type": "http.response.body", "body": b"unauthorized"}) + + app.side_effect = app_impl + middleware = TokenUsageMiddleware(app=app) + + scope = { + "type": "http", + "path": "/api/tools", + "method": "GET", + "state": {}, + "client": ("10.0.0.1", 9000), + "headers": [(b"authorization", f"Bearer {raw_token}".encode())], + } + + mock_token_service = MagicMock() + mock_token_service.log_token_usage = AsyncMock() + + with ( + patch("mcpgateway.middleware.token_usage_middleware.fresh_db_session", side_effect=Exception("DB down")), + patch("mcpgateway.middleware.token_usage_middleware.TokenCatalogService", return_value=mock_token_service), + ): + await _make_asgi_call(middleware, scope) + + # Should NOT log usage because DB verification failed + mock_token_service.log_token_usage.assert_not_awaited() diff --git a/tests/unit/mcpgateway/routers/test_tokens.py b/tests/unit/mcpgateway/routers/test_tokens.py index fc47d3f05e..aca338efad 100644 --- a/tests/unit/mcpgateway/routers/test_tokens.py +++ b/tests/unit/mcpgateway/routers/test_tokens.py @@ -110,9 +110,13 @@ def mock_token_record(): class TestAuthenticatedSessionGate: """Test authenticated session gating for token endpoints.""" - def test_api_token_allowed(self): - """API tokens are allowed for token management (RBAC handles authorization).""" - _require_authenticated_session({"auth_method": "api_token"}) + def test_api_token_blocked(self): + """API tokens are blocked from token management (Management Plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + _require_authenticated_session({"auth_method": "api_token"}) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert "interactive session" in exc_info.value.detail def test_missing_auth_method_blocked(self): """Missing auth_method fails secure.""" @@ -660,13 +664,16 @@ async def test_admin_revoke_token_success(self, mock_db, mock_admin_user): mock_service.admin_revoke_token.assert_called_once() @pytest.mark.asyncio - async def test_admin_revoke_token_allowed_for_api_token(self, mock_db, mock_admin_user): - """Admin API tokens are allowed for admin token management (RBAC handles authorization).""" + async def test_admin_revoke_token_blocked_for_api_token(self, mock_db, mock_admin_user): + """API tokens are blocked from admin token management (Management Plane isolation).""" current_user = dict(mock_admin_user) current_user["auth_method"] = "api_token" - # Should not raise 403 for api_token auth_method — succeeds or raises non-403 error - await admin_revoke_token(token_id="token-123", request=None, current_user=current_user, db=mock_db) + # API tokens must NEVER manage tokens, even for admins — security invariant + with pytest.raises(HTTPException) as exc_info: + await admin_revoke_token(token_id="token-123", request=None, current_user=current_user, db=mock_db) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio async def test_admin_revoke_token_non_admin(self, mock_db, mock_current_user): @@ -827,11 +834,12 @@ async def test_list_team_tokens_unauthorized(self, mock_db, mock_current_user): class TestApiTokenAuth: - """Test that API token auth_method works for all token endpoints (bug #2870). + """Test that API token auth_method is blocked from all token management endpoints. - Previously, _require_interactive_session blocked API tokens from all token - management endpoints. After the fix, API tokens should be able to manage - tokens like any other authenticated method — RBAC handles authorization. + API tokens represent the data plane; token management (CRUD) is a management-plane + operation that requires an interactive session (JWT from web login or SSO). + Allowing API tokens here would enable token-chaining attacks where a compromised + token creates new long-lived tokens to maintain persistence. """ @pytest.fixture @@ -857,122 +865,83 @@ def admin_api_token_user(self, mock_db): } @pytest.mark.asyncio - async def test_create_token_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can create new tokens (core bug #2870 fix).""" + async def test_create_token_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot create new tokens (token-chaining prevention).""" request = TokenCreateRequest(name="Created-Via-API-Token", description="Test") - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.create_token = AsyncMock(return_value=(mock_token_record, "new-raw-token")) - - response = await create_token(request, current_user=api_token_user, db=mock_db) + with pytest.raises(HTTPException) as exc_info: + await create_token(request, current_user=api_token_user, db=mock_db) - assert isinstance(response, TokenCreateResponse) - assert response.access_token == "new-raw-token" - mock_service.create_token.assert_called_once() + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert "interactive session" in exc_info.value.detail @pytest.mark.asyncio - async def test_list_tokens_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can list own tokens.""" - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.list_user_and_team_tokens = AsyncMock(return_value=[mock_token_record]) - mock_service.count_user_and_team_tokens = AsyncMock(return_value=1) - mock_service.get_token_revocations_batch = AsyncMock(return_value={}) - - response = await list_tokens(include_inactive=False, limit=50, offset=0, db=mock_db, current_user=api_token_user) + async def test_list_tokens_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot list tokens (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + await list_tokens(include_inactive=False, limit=50, offset=0, db=mock_db, current_user=api_token_user) - assert isinstance(response, TokenListResponse) - assert len(response.tokens) == 1 + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_get_token_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can get token details.""" - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.get_token = AsyncMock(return_value=mock_token_record) - - response = await get_token(token_id="token-123", current_user=api_token_user, db=mock_db) + async def test_get_token_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot get token details (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + await get_token(token_id="token-123", current_user=api_token_user, db=mock_db) - assert isinstance(response, TokenResponse) - assert response.id == "token-123" + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_revoke_token_with_api_token(self, mock_db, api_token_user): - """API token can revoke tokens.""" - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.revoke_token = AsyncMock(return_value=True) - + async def test_revoke_token_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot revoke tokens (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: await revoke_token(token_id="token-123", request=None, current_user=api_token_user, db=mock_db) - mock_service.revoke_token.assert_called_once() + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_update_token_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can update tokens.""" + async def test_update_token_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot update tokens (management plane isolation).""" request = TokenUpdateRequest(name="Updated-Via-API-Token") - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_token_record.name = "Updated-Via-API-Token" - mock_service.update_token = AsyncMock(return_value=mock_token_record) - - response = await update_token(token_id="token-123", request=request, current_user=api_token_user, db=mock_db) + with pytest.raises(HTTPException) as exc_info: + await update_token(token_id="token-123", request=request, current_user=api_token_user, db=mock_db) - assert response.name == "Updated-Via-API-Token" + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_get_usage_stats_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can view usage stats.""" - stats_data = { - "period_days": 7, - "total_requests": 100, - "successful_requests": 95, - "blocked_requests": 5, - "success_rate": 0.95, - "average_response_time_ms": 150.0, - "top_endpoints": [], - } - - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.get_token = AsyncMock(return_value=mock_token_record) - mock_service.get_token_usage_stats = AsyncMock(return_value=stats_data) - - response = await get_token_usage_stats(token_id="token-123", days=7, current_user=api_token_user, db=mock_db) + async def test_get_usage_stats_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot view usage stats (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + await get_token_usage_stats(token_id="token-123", days=7, current_user=api_token_user, db=mock_db) - assert isinstance(response, TokenUsageStatsResponse) - assert response.total_requests == 100 + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_admin_list_all_with_api_token(self, mock_db, admin_api_token_user, mock_token_record): - """Admin API token can list all tokens.""" - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.list_user_tokens = AsyncMock(return_value=[mock_token_record]) - mock_service.count_user_tokens = AsyncMock(return_value=1) - mock_service.get_token_revocations_batch = AsyncMock(return_value={}) - - response = await list_all_tokens(user_email="other@example.com", include_inactive=False, limit=100, offset=0, current_user=admin_api_token_user, db=mock_db) + async def test_admin_list_all_blocked_for_api_token(self, mock_db, admin_api_token_user): + """Admin API token cannot list all tokens (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + await list_all_tokens(user_email="other@example.com", include_inactive=False, limit=100, offset=0, current_user=admin_api_token_user, db=mock_db) - assert isinstance(response, TokenListResponse) + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_create_team_token_with_api_token(self, mock_db, api_token_user, mock_token_record): - """API token can create team tokens.""" + async def test_create_team_token_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot create team tokens (management plane isolation).""" request = TokenCreateRequest(name="Team-Token-Via-API", description="Test") - mock_token_record.team_id = "team-456" - with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: - mock_service = mock_service_class.return_value - mock_service.create_token = AsyncMock(return_value=(mock_token_record, "team-token")) + with pytest.raises(HTTPException) as exc_info: + await create_team_token(team_id="team-456", request=request, current_user=api_token_user, db=mock_db) - response = await create_team_token(team_id="team-456", request=request, current_user=api_token_user, db=mock_db) + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN - assert response.access_token == "team-token" - call_args = mock_service.create_token.call_args - assert call_args[1]["team_id"] == "team-456" + @pytest.mark.asyncio + async def test_list_team_tokens_blocked_for_api_token(self, mock_db, api_token_user): + """API token cannot list team tokens (management plane isolation).""" + with pytest.raises(HTTPException) as exc_info: + await list_team_tokens(team_id="team-456", include_inactive=False, limit=50, offset=0, current_user=api_token_user, db=mock_db) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN class TestAuthenticatedSessionErrorMessages: @@ -1082,3 +1051,32 @@ async def test_create_token_with_complex_scope(self, mock_db, mock_current_user, assert len(scope.permissions) == 3 assert len(scope.ip_restrictions) == 2 assert scope.usage_limits["max_calls"] == 10000 + + @pytest.mark.asyncio + async def test_create_token_auto_inherits_single_team(self, mock_db, mock_token_record): + """Non-admin user belonging to exactly one team auto-inherits team_id when not set in the request.""" + single_team_user = { + "email": "dev@example.com", + "is_admin": False, + "permissions": ["tokens.create"], + "db": mock_db, + "auth_method": "jwt", + "token_teams": ["team-auto"], + } + request = MagicMock(spec=TokenCreateRequest) + request.name = "Auto Team Token" + request.description = None + request.scope = None + request.expires_in_days = 30 + request.tags = [] + request.team_id = None + request.is_active = True + + with patch("mcpgateway.routers.tokens.TokenCatalogService") as mock_service_class: + mock_service = mock_service_class.return_value + mock_service.create_token = AsyncMock(return_value=(mock_token_record, "auto-inherit-token")) + + await create_token(request, current_user=single_team_user, db=mock_db) + + call_args = mock_service.create_token.call_args + assert call_args[1]["team_id"] == "team-auto" diff --git a/tests/unit/mcpgateway/test_auth.py b/tests/unit/mcpgateway/test_auth.py index 7b5bf642fc..42aa8a417b 100644 --- a/tests/unit/mcpgateway/test_auth.py +++ b/tests/unit/mcpgateway/test_auth.py @@ -231,32 +231,22 @@ async def test_revoked_jwt_token_raises_401(self): assert exc_info.value.detail == "Token has been revoked" @pytest.mark.asyncio - async def test_token_revocation_check_failure_logs_warning(self, caplog): - """Test that token revocation check failure logs warning but doesn't fail auth.""" + async def test_token_revocation_check_failure_denies_access(self, caplog): + """Test that token revocation check failure denies access (fail-secure).""" credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="jwt_with_jti") jwt_payload = {"sub": "test@example.com", "jti": "token_id_456", "exp": (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp()} - mock_user = EmailUser( - email="test@example.com", - password_hash="hash", - full_name="Test User", - is_admin=False, - is_active=True, - email_verified_at=datetime.now(timezone.utc), - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - caplog.set_level(logging.WARNING) with patch("mcpgateway.auth.verify_jwt_token_cached", AsyncMock(return_value=jwt_payload)): with patch("mcpgateway.auth._check_token_revoked_sync", side_effect=Exception("Database error")): - with patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user): + with patch("mcpgateway.auth._get_user_by_email_sync", return_value=None): with patch("mcpgateway.auth._get_personal_team_sync", return_value=None): - user = await get_current_user(credentials=credentials) + with pytest.raises(HTTPException) as exc_info: + await get_current_user(credentials=credentials) - assert user.email == mock_user.email + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert "Token revocation check failed for JTI token_id_456" in caplog.text @pytest.mark.asyncio @@ -1205,20 +1195,21 @@ async def test_api_token_last_used_updated_on_jwt_auth(self, monkeypatch): with patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user): with patch("mcpgateway.auth._get_personal_team_sync", return_value=None): with patch("mcpgateway.auth._update_api_token_last_used_sync") as mock_update: - with patch("mcpgateway.auth.asyncio.to_thread", AsyncMock(side_effect=lambda f, *args: f(*args))): - user = await get_current_user(credentials=credentials, request=request) + with patch("mcpgateway.auth._check_token_revoked_sync", return_value=False): + with patch("mcpgateway.auth.asyncio.to_thread", AsyncMock(side_effect=lambda f, *args: f(*args))): + user = await get_current_user(credentials=credentials, request=request) - # Verify user was authenticated - assert user.email == "api@example.com" + # Verify user was authenticated + assert user.email == "api@example.com" - # Verify auth_method was set to api_token - assert request.state.auth_method == "api_token" + # Verify auth_method was set to api_token + assert request.state.auth_method == "api_token" - # Verify JTI was stored in request.state - assert request.state.jti == "jti-api-456" + # Verify JTI was stored in request.state + assert request.state.jti == "jti-api-456" - # Verify last_used update was called - mock_update.assert_called_once_with("jti-api-456") + # Verify last_used update was called + mock_update.assert_called_once_with("jti-api-456") @pytest.mark.asyncio async def test_api_token_last_used_update_failure_continues_auth(self, monkeypatch): @@ -1298,14 +1289,15 @@ async def test_api_token_jti_stored_in_request_state(self, monkeypatch): with patch("mcpgateway.auth.verify_jwt_token_cached", AsyncMock(return_value=jwt_payload)): with patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user): with patch("mcpgateway.auth._get_personal_team_sync", return_value="team_123"): - user = await get_current_user(credentials=credentials, request=request) + with patch("mcpgateway.auth._check_token_revoked_sync", return_value=False): + user = await get_current_user(credentials=credentials, request=request) - # Verify user was authenticated - assert user.email == "test@example.com" + # Verify user was authenticated + assert user.email == "test@example.com" - # Verify JTI was stored in request.state - assert hasattr(request.state, "jti") - assert request.state.jti == "jti-store-test-789" + # Verify JTI was stored in request.state + assert hasattr(request.state, "jti") + assert request.state.jti == "jti-store-test-789" @pytest.mark.asyncio async def test_legacy_api_token_last_used_updated(self, monkeypatch): @@ -1340,18 +1332,19 @@ async def test_legacy_api_token_last_used_updated(self, monkeypatch): with patch("mcpgateway.auth._get_personal_team_sync", return_value=None): with patch("mcpgateway.auth._is_api_token_jti_sync", return_value=True): with patch("mcpgateway.auth._update_api_token_last_used_sync") as mock_update: - with patch("mcpgateway.auth.asyncio.to_thread", AsyncMock(side_effect=lambda f, *args: f(*args))): - user = await get_current_user(credentials=credentials, request=request) + with patch("mcpgateway.auth._check_token_revoked_sync", return_value=False): + with patch("mcpgateway.auth.asyncio.to_thread", AsyncMock(side_effect=lambda f, *args: f(*args))): + user = await get_current_user(credentials=credentials, request=request) - # Verify user was authenticated - assert user.email == "legacy@example.com" + # Verify user was authenticated + assert user.email == "legacy@example.com" - # Verify auth_method was set to api_token - assert request.state.auth_method == "api_token" + # Verify auth_method was set to api_token + assert request.state.auth_method == "api_token" - # Verify last_used update was called for legacy token - assert mock_update.call_count == 1 - mock_update.assert_called_with("jti-legacy-999") + # Verify last_used update was called for legacy token + assert mock_update.call_count == 1 + mock_update.assert_called_with("jti-legacy-999") @pytest.mark.asyncio async def test_legacy_api_token_last_used_update_failure_continues_auth(self, monkeypatch): @@ -1767,6 +1760,7 @@ async def test_legacy_non_api_token_jti(self): patch("mcpgateway.auth._is_api_token_jti_sync", return_value=False), patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user), patch("mcpgateway.auth._get_personal_team_sync", return_value=None), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), ): user = await get_current_user(credentials=credentials, request=request) @@ -2136,6 +2130,7 @@ async def test_cache_user_missing_fallthrough(self, monkeypatch): patch("mcpgateway.cache.auth_cache.auth_cache.get_auth_context", AsyncMock(return_value=cached_ctx)), patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user), patch("mcpgateway.auth._get_personal_team_sync", return_value=None), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), ): user = await get_current_user(credentials=credentials, request=request) @@ -2159,6 +2154,7 @@ async def test_cache_exception_fallthrough(self, monkeypatch): patch("mcpgateway.cache.auth_cache.auth_cache.get_auth_context", AsyncMock(side_effect=RuntimeError("cache down"))), patch("mcpgateway.auth._get_user_by_email_sync", return_value=mock_user), patch("mcpgateway.auth._get_personal_team_sync", return_value=None), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), ): user = await get_current_user(credentials=credentials)