Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions docker-compose-embedded.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ services:
- AUTH_REQUIRED=true
- EMAIL_AUTH_ENABLED=true
- MCP_CLIENT_AUTH_ENABLED=true
- JWT_SECRET_KEY=jwt-secret-key
- JWT_SECRET_KEY=${JWT_SECRET_KEY:-my-test-key}
- PLATFORM_ADMIN_EMAIL=admin@example.com
- PLATFORM_ADMIN_PASSWORD=Changeme123!
- PLATFORM_ADMIN_FULL_NAME=Platform Administrator
Expand Down Expand Up @@ -149,7 +149,7 @@ services:
register_benchmark:
image: ghcr.io/ibm/mcp-context-forge:1ba8130f7fb82e6f393435be8d064879f234ace1
environment:
- JWT_SECRET_KEY=jwt-secret-key
- JWT_SECRET_KEY=${JWT_SECRET_KEY:-my-test-key}
- BENCHMARK_SERVER_COUNT=${BENCHMARK_SERVER_COUNT:-10}
- BENCHMARK_START_PORT=${BENCHMARK_START_PORT:-9000}
command:
Expand Down Expand Up @@ -209,11 +209,3 @@ services:
print(f'Registration complete: {registered}/{server_count} benchmark servers')
"

# NOTE: register_fast_time uses hardcoded secret in its command script.
# It will fail with 401 when JWT_SECRET_KEY differs from base compose.
# This is non-critical — benchmark servers provide the bulk test data.
# To also register fast_time, manually run:
# curl -X POST http://localhost:8080/gateways \
# -H "Authorization: Bearer $(python3 -m mcpgateway.utils.create_jwt_token --username admin@example.com --exp 10080 --secret jwt-secret-key --algo HS256)" \
# -H "Content-Type: application/json" \
# -d '{"name":"fast_time","url":"http://fast_time_server:8080/http","transport":"STREAMABLEHTTP"}'
11 changes: 9 additions & 2 deletions mcpgateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,8 +1286,15 @@ async def _set_auth_method_from_payload(payload: dict) -> None:
except HTTPException:
raise
except Exception as revoke_check_error:
# Log the error but don't fail authentication for admin tokens
logger.warning(f"Token revocation check failed for JTI {jti}: {revoke_check_error}")
# Fail-secure: if the revocation check itself errors, reject the token.
# Allowing through on error would let revoked tokens bypass enforcement
# when the DB is unreachable or the table is missing.
logger.warning(f"Token revocation check failed for JTI {jti} — denying access (fail-secure): {revoke_check_error}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token validation failed",
headers={"WWW-Authenticate": "Bearer"},
)

# Resolve teams based on token_use
token_use = payload.get("token_use")
Expand Down
20 changes: 20 additions & 0 deletions mcpgateway/cache/auth_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,26 @@ async def get_auth_context(
self._hit_count += 1
return CachedAuthContext(is_token_revoked=True)

# Cross-worker revocation check: when a token is revoked on another worker,
# the revoking worker writes a Redis revocation marker. Check it BEFORE the
# L1 in-memory cache so that stale L1 entries cannot bypass revocation.
if jti:
redis = await self._get_redis_client()
if redis:
try:
revoke_key = self._get_redis_key("revoke", jti)
if await redis.exists(revoke_key):
# Promote to local set so subsequent requests skip the Redis call
with self._lock:
self._revoked_jtis.add(jti)
# Evict any stale L1 context entries for this JTI
for k in [k for k in self._context_cache if k.endswith(f":{jti}")]:
self._context_cache.pop(k, None)
self._hit_count += 1
return CachedAuthContext(is_token_revoked=True)
except Exception as exc:
logger.debug(f"AuthCache: Redis revocation check failed for {jti[:8]}: {exc}")

cache_key = f"{email}:{jti or 'no-jti'}"

# Check L1 in-memory cache first (no network I/O)
Expand Down
63 changes: 61 additions & 2 deletions mcpgateway/middleware/auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from typing import Callable

# Third-Party
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import JSONResponse, Response

# First-Party
from mcpgateway.auth import get_current_user
Expand All @@ -35,6 +36,12 @@
logger = logging.getLogger(__name__)
security_logger = get_security_logger()

# HTTPException detail strings that indicate security-critical rejections
# (revoked tokens, disabled accounts, fail-secure validation errors).
# Only these trigger a hard JSON deny in the auth middleware; all other
# 401/403s fall through to route-level auth for backwards compatibility.
_HARD_DENY_DETAILS = frozenset({"Token has been revoked", "Account disabled", "Token validation failed"})


def _should_log_auth_success() -> bool:
"""Check if successful authentication should be logged based on settings.
Expand Down Expand Up @@ -144,8 +151,60 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
except Exception as close_error:
logger.debug(f"Failed to close database session: {close_error}")

except HTTPException as e:
if e.status_code in (401, 403) and e.detail in _HARD_DENY_DETAILS:
logger.info(f"✗ Auth rejected ({e.status_code}): {e.detail}")

if log_failure:
db = SessionLocal()
try:
security_logger.log_authentication_attempt(
user_id="unknown",
user_email=None,
auth_method="bearer_token",
success=False,
client_ip=request.client.host if request.client else "unknown",
user_agent=request.headers.get("user-agent"),
failure_reason=str(e.detail),
db=db,
)
db.commit()
except Exception as log_error:
logger.debug(f"Failed to log auth failure: {log_error}")
finally:
try:
db.close()
except Exception as close_error:
logger.debug(f"Failed to close database session: {close_error}")

# Browser/admin requests with stale cookies: let the request continue
# without user context so the RBAC layer can redirect to /admin/login.
# API requests: return a hard JSON 401/403 deny.
# Detection must match rbac.py's is_browser_request logic (Accept,
# HX-Request, and Referer: /admin) to avoid breaking admin UI flows.
accept_header = request.headers.get("accept", "")
is_htmx = request.headers.get("hx-request") == "true"
referer = request.headers.get("referer", "")
is_browser = "text/html" in accept_header or is_htmx or "/admin" in referer
if is_browser:
logger.debug("Browser request with rejected auth — continuing without user for redirect")
return await call_next(request)

# Include essential security headers since this response bypasses
# SecurityHeadersMiddleware (it returns before call_next).
resp_headers = dict(e.headers) if e.headers else {}
resp_headers.setdefault("X-Content-Type-Options", "nosniff")
resp_headers.setdefault("Referrer-Policy", "strict-origin-when-cross-origin")
return JSONResponse(
status_code=e.status_code,
content={"detail": e.detail},
headers=resp_headers,
)

# Non-security HTTP errors (e.g. 500 from a downstream service) — continue as anonymous
logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}")
except Exception as e:
# Silently fail - let route handlers enforce auth if needed
# Non-HTTP errors (network, decode, etc.) — continue as anonymous
logger.info(f"✗ Auth context extraction failed (continuing as anonymous): {e}")

# Log failed authentication attempt (based on logging level)
Expand Down
128 changes: 94 additions & 34 deletions mcpgateway/middleware/token_usage_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
# Standard
import logging
import time
from typing import Optional

# Third-Party
import jwt as _jwt
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
Expand Down Expand Up @@ -108,62 +110,120 @@ async def send_wrapper(message: dict) -> None:
# Calculate response time
response_time_ms = round((time.time() - start_time) * 1000)

# Only log if this was an API token request
# Log API token usage — covers both successful requests and auth-rejected attempts.
# Every request that uses (or tries to use) an API token is recorded,
# including blocked calls with revoked/expired tokens, so that usage stats are accurate.
state = scope.get("state", {})
auth_method = state.get("auth_method") if state else None

if auth_method != "api_token":
return
jti: Optional[str] = None
user_email: Optional[str] = None
blocked: bool = False
block_reason: Optional[str] = None

if auth_method == "api_token":
# --- Successfully authenticated API token request ---
jti = state.get("jti") if state else None
user = state.get("user") if state else None
user_email = getattr(user, "email", None) if user else None
if not user_email:
user_email = state.get("user_email") if state else None

# If we don't have JTI or email, try to decode the token from the header
if not jti or not user_email:
try:
headers = Headers(scope=scope)
auth_header = headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
token = auth_header.replace("Bearer ", "")
request = Request(scope, receive)
try:
payload = await verify_jwt_token_cached(token, request)
jti = jti or payload.get("jti")
user_email = user_email or payload.get("sub") or payload.get("email")
except Exception as decode_error:
logger.debug(f"Failed to decode token for usage logging: {decode_error}")
return
except Exception as e:
logger.debug(f"Error extracting token information: {e}")
return

# Extract token information from scope state
jti = state.get("jti") if state else None
user = state.get("user") if state else None
user_email = getattr(user, "email", None) if user else None
if not user_email:
user_email = state.get("user_email") if state else None
if not jti or not user_email:
logger.debug("Missing JTI or user_email for token usage logging")
return

# If we don't have JTI or email, try to decode the token
if not jti or not user_email:
# Bug 3a fix: reflect the actual outcome — 4xx responses mark the attempt
# as blocked (e.g. RBAC denied, rate-limited, or server-scoping violation).
# 5xx errors are backend failures, not security denials, so exclude them.
blocked = 400 <= status_code < 500
if blocked:
block_reason = f"http_{status_code}"

elif status_code in (401, 403):
# --- Auth-rejected request: check if the Bearer token was an API token ---
# When a revoked or expired API token is used, auth middleware rejects the
# request before setting auth_method="api_token", so the path above is
# never reached. We detect the attempt here by decoding the JWT payload
# without re-verifying it (the token identity is valid even if rejected).
try:
# Get token from Authorization header
headers = Headers(scope=scope)
auth_header = headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
raw_token = auth_header[7:] # strip "Bearer "

token = auth_header.replace("Bearer ", "")
# Decode without signature/expiry check — for identification only, not auth.
unverified = _jwt.decode(raw_token, options={"verify_signature": False})
user_info = unverified.get("user", {})
if user_info.get("auth_provider") != "api_token":
return # Not an API token — nothing to log

# Decode token to get JTI and user email
# Note: We need to create a minimal Request-like object
request = Request(scope, receive)
try:
payload = await verify_jwt_token_cached(token, request)
jti = jti or payload.get("jti")
user_email = user_email or payload.get("sub") or payload.get("email")
except Exception as decode_error:
logger.debug(f"Failed to decode token for usage logging: {decode_error}")
jti = unverified.get("jti")
user_email = unverified.get("sub") or unverified.get("email")
if not jti or not user_email:
return

# Verify JTI belongs to a real API token before logging.
# Without this check, an attacker can craft a JWT with fake
# jti/sub and auth_provider=api_token to pollute usage logs.
# Verify JTI belongs to a real API token and use the DB-stored
# owner email instead of the unverified JWT claim. Without this,
# an attacker who knows a valid JTI could forge a JWT with an
# arbitrary sub/email to poison another user's usage stats.
try:
# Third-Party
from sqlalchemy import select # pylint: disable=import-outside-toplevel

# First-Party
from mcpgateway.db import EmailApiToken # pylint: disable=import-outside-toplevel

with fresh_db_session() as verify_db:
token_row = verify_db.execute(select(EmailApiToken.id, EmailApiToken.user_email).where(EmailApiToken.jti == jti)).first()
if token_row is None:
return # JTI not in DB — forged token, skip logging
# Use the DB-stored owner, not the unverified JWT claim
user_email = token_row.user_email
except Exception:
return # DB error — skip logging rather than log unverified data

blocked = True
block_reason = "revoked_or_expired" if status_code == 401 else f"http_{status_code}"
except Exception as e:
logger.debug(f"Error extracting token information: {e}")
logger.debug(f"Failed to extract API token identity from rejected request: {e}")
return
else:
return # Not an API token request — nothing to log

if not jti or not user_email:
logger.debug("Missing JTI or user_email for token usage logging")
return

# Log token usage
# Shared logging path for both authenticated and blocked API token requests
try:
with fresh_db_session() as db:
token_service = TokenCatalogService(db)
# Get client IP
client = scope.get("client")
ip_address = client[0] if client else None

# Get user agent
headers = Headers(scope=scope)
user_agent = headers.get("user-agent")

# Log usage
await token_service.log_token_usage(
jti=jti,
user_email=user_email,
Expand All @@ -173,8 +233,8 @@ async def send_wrapper(message: dict) -> None:
user_agent=user_agent,
status_code=status_code,
response_time_ms=response_time_ms,
blocked=False,
block_reason=None,
blocked=blocked,
block_reason=block_reason,
)
except Exception as e:
logger.debug(f"Failed to log token usage: {e}")
Loading
Loading