Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 70 additions & 34 deletions mcpgateway/middleware/token_usage_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
# Standard
import logging
import time
from typing import Optional

# Third-Party
import jwt as _jwt
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
Expand Down Expand Up @@ -108,62 +110,96 @@ async def send_wrapper(message: dict) -> None:
# Calculate response time
response_time_ms = round((time.time() - start_time) * 1000)

# Only log if this was an API token request
# Log API token usage — covers both successful requests and auth-rejected attempts.
# Every request that uses (or tries to use) an API token is recorded,
# including blocked calls with revoked/expired tokens, so that usage stats are accurate.
state = scope.get("state", {})
auth_method = state.get("auth_method") if state else None

if auth_method != "api_token":
return
jti: Optional[str] = None
user_email: Optional[str] = None
blocked: bool = False
block_reason: Optional[str] = None

if auth_method == "api_token":
# --- Successfully authenticated API token request ---
jti = state.get("jti") if state else None
user = state.get("user") if state else None
user_email = getattr(user, "email", None) if user else None
if not user_email:
user_email = state.get("user_email") if state else None

# If we don't have JTI or email, try to decode the token from the header
if not jti or not user_email:
try:
headers = Headers(scope=scope)
auth_header = headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
token = auth_header.replace("Bearer ", "")
request = Request(scope, receive)
try:
payload = await verify_jwt_token_cached(token, request)
jti = jti or payload.get("jti")
user_email = user_email or payload.get("sub") or payload.get("email")
except Exception as decode_error:
logger.debug(f"Failed to decode token for usage logging: {decode_error}")
return
except Exception as e:
logger.debug(f"Error extracting token information: {e}")
return

# Extract token information from scope state
jti = state.get("jti") if state else None
user = state.get("user") if state else None
user_email = getattr(user, "email", None) if user else None
if not user_email:
user_email = state.get("user_email") if state else None
if not jti or not user_email:
logger.debug("Missing JTI or user_email for token usage logging")
return

# If we don't have JTI or email, try to decode the token
if not jti or not user_email:
# Bug 3a fix: reflect the actual outcome — 4xx responses mark the attempt
# as blocked (e.g. RBAC denied, rate-limited, or server-scoping violation).
blocked = status_code >= 400
if blocked:
block_reason = f"http_{status_code}"

elif status_code in (401, 403):
# --- Auth-rejected request: check if the Bearer token was an API token ---
# When a revoked or expired API token is used, auth middleware rejects the
# request before setting auth_method="api_token", so the path above is
# never reached. We detect the attempt here by decoding the JWT payload
# without re-verifying it (the token identity is valid even if rejected).
try:
# Get token from Authorization header
headers = Headers(scope=scope)
auth_header = headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
raw_token = auth_header[7:] # strip "Bearer "

token = auth_header.replace("Bearer ", "")
# Decode without signature/expiry check — for identification only, not auth.
unverified = _jwt.decode(raw_token, options={"verify_signature": False})
user_info = unverified.get("user", {})
if user_info.get("auth_provider") != "api_token":
return # Not an API token — nothing to log

# Decode token to get JTI and user email
# Note: We need to create a minimal Request-like object
request = Request(scope, receive)
try:
payload = await verify_jwt_token_cached(token, request)
jti = jti or payload.get("jti")
user_email = user_email or payload.get("sub") or payload.get("email")
except Exception as decode_error:
logger.debug(f"Failed to decode token for usage logging: {decode_error}")
jti = unverified.get("jti")
user_email = unverified.get("sub") or unverified.get("email")
if not jti or not user_email:
return

blocked = True
block_reason = "revoked_or_expired" if status_code == 401 else f"http_{status_code}"
except Exception as e:
logger.debug(f"Error extracting token information: {e}")
logger.debug(f"Failed to extract API token identity from rejected request: {e}")
return
else:
return # Not an API token request — nothing to log

if not jti or not user_email:
logger.debug("Missing JTI or user_email for token usage logging")
return

# Log token usage
# Shared logging path for both authenticated and blocked API token requests
try:
with fresh_db_session() as db:
token_service = TokenCatalogService(db)
# Get client IP
client = scope.get("client")
ip_address = client[0] if client else None

# Get user agent
headers = Headers(scope=scope)
user_agent = headers.get("user-agent")

# Log usage
await token_service.log_token_usage(
jti=jti,
user_email=user_email,
Expand All @@ -173,8 +209,8 @@ async def send_wrapper(message: dict) -> None:
user_agent=user_agent,
status_code=status_code,
response_time_ms=response_time_ms,
blocked=False,
block_reason=None,
blocked=blocked,
block_reason=block_reason,
)
except Exception as e:
logger.debug(f"Failed to log token usage: {e}")
38 changes: 30 additions & 8 deletions mcpgateway/routers/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@


def _require_authenticated_session(current_user: dict) -> None:
"""Block anonymous and unauthenticated access to token management endpoints.
"""Block anonymous, unauthenticated, and API-token access to token management endpoints.

Rejects requests where authentication could not be determined or where
the caller is anonymous. All authenticated methods (JWT, API tokens,
OAuth, SSO, proxy, etc.) are allowed — RBAC permission checks and
scope containment (via _get_caller_permissions) handle authorization.
Enforces Management Plane isolation: only interactive sessions (JWT from web
login, SSO, or OAuth) may create, list, or revoke tokens. API tokens are
Data Plane credentials and must never be able to manage other tokens
(token-chaining attack vector).

Args:
current_user: User context from get_current_user_with_permissions

Raises:
HTTPException: 403 if auth_method is None or anonymous
HTTPException: 403 if auth_method is None, anonymous, or api_token
"""
auth_method = current_user.get("auth_method")

Expand All @@ -60,6 +60,16 @@ def _require_authenticated_session(current_user: dict) -> None:
detail="Token management requires authentication. Anonymous access is not permitted.",
)

# Block API tokens from managing other tokens (Management Plane isolation).
# Token CRUD endpoints require an interactive session (JWT from web login or SSO).
# Allowing API tokens here would let a compromised token create new long-lived
# tokens and escalate persistence — a token-chaining attack.
if auth_method == "api_token":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=("Token management requires an interactive session (JWT from web login or SSO). " "API tokens cannot create, list, or revoke other tokens."),
)


async def _get_caller_permissions(
db: Session,
Expand Down Expand Up @@ -114,12 +124,24 @@ async def create_token(
"""
_require_authenticated_session(current_user)

# Auto-inherit team_id from the caller's single team when not explicitly provided.
# This prevents tokens from being silently scoped to public-only (team_id=None)
# when the user belongs to exactly one team, maintaining RBAC context at token level.
# Multi-team users must specify team_id explicitly to avoid ambiguity.
# Admins with teams=null are exempt and may still create global-scope tokens.
effective_team_id = request.team_id
if effective_team_id is None and not current_user.get("is_admin"):
user_teams = current_user.get("token_teams") or []
if len(user_teams) == 1:
effective_team_id = user_teams[0]
logger.debug("Auto-inherited team_id=%s for token creation by %s", effective_team_id, current_user["email"])

service = TokenCatalogService(db)

# Get caller permissions for scope containment (if custom scope requested)
caller_permissions = None
if request.scope and request.scope.permissions:
caller_permissions = await _get_caller_permissions(db, current_user, request.team_id)
caller_permissions = await _get_caller_permissions(db, current_user, effective_team_id)

# Convert request to TokenScope if provided
scope = None
Expand All @@ -140,7 +162,7 @@ async def create_token(
scope=scope,
expires_in_days=request.expires_in_days,
tags=request.tags,
team_id=request.team_id,
team_id=effective_team_id,
caller_permissions=caller_permissions,
is_active=request.is_active,
)
Expand Down
Loading
Loading