-
Notifications
You must be signed in to change notification settings - Fork 636
Expand file tree
/
Copy pathtoken_validation_service.py
More file actions
294 lines (239 loc) · 10.9 KB
/
token_validation_service.py
File metadata and controls
294 lines (239 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/services/token_validation_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Akshay Shinde
OAuth Token Claim Validation Service for ContextForge.
This module provides best-effort JWT claim validation (audience, scopes, issuer)
for OAuth access tokens before they are forwarded to upstream MCP servers.
Validation is advisory — the upstream MCP server remains the authoritative validator.
Tokens are issued by external Identity Providers (Entra ID, Keycloak, etc.) and
ContextForge does NOT possess their signing keys, so JWT signatures are not verified.
Opaque (non-JWT) tokens are handled gracefully.
"""
# Standard
from dataclasses import dataclass, field
import logging
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
# Third-Party
import jwt
logger = logging.getLogger(__name__)
@dataclass
class TokenValidationResult:
    """Result of OAuth token claim validation.

    The ``audience_match``, ``scopes_sufficient`` and ``issuer_match`` flags
    are tri-state: ``None`` (the default) means no verdict was recorded,
    ``True`` means the check passed, and ``False`` means the claim was present
    but did not match.

    Examples:
        >>> r = TokenValidationResult(is_jwt=True)
        >>> r.is_jwt
        True
        >>> r.warnings
        []
        >>> r.audience_match is None
        True
        >>> r.blocking_errors
        []
    """

    # Whether the access token could be decoded as a JWT.
    is_jwt: bool = False
    # Human-readable descriptions of every problem that was detected.
    warnings: List[str] = field(default_factory=list)
    # Tri-state verdicts for the individual claim checks (see class docstring).
    audience_match: Optional[bool] = None
    scopes_sufficient: Optional[bool] = None
    issuer_match: Optional[bool] = None
    # Verdict for the token type check; never contributes to blocking_errors.
    token_type_valid: Optional[bool] = None

    @property
    def blocking_errors(self) -> List[str]:
        """Return warnings that correspond to claims present-but-mismatched.

        A claim that is simply *absent* from the token produces ``None`` (not
        ``False``) for the corresponding flag and is **not** blocking — this
        preserves backward compatibility with legacy IdPs that omit ``aud``,
        ``scope``/``scp``, or ``iss``.

        Only ``False`` (claim present and wrong) is treated as blocking,
        because in that case the upstream MCP server is expected to reject
        the token and there is no value in making the network round-trip.

        Warnings are matched to failed flags by keyword. Because a warning can
        embed arbitrary token data (for example an issuer URL that happens to
        contain the word "audience"), one warning may match several keywords;
        each warning is nevertheless reported at most once.

        Returns:
            List of warning strings whose corresponding claim check evaluated
            to ``False`` (present + mismatched), without duplicates, ordered
            by check (audience, scope, issuer) and then by position in
            ``warnings``. Empty list means no definitive mismatch was detected.

        Examples:
            >>> r = TokenValidationResult(is_jwt=True)
            >>> r.blocking_errors
            []
            >>> r.audience_match = False
            >>> r.warnings.append("Token audience mismatch: token aud does not match expected resource or gateway URL")
            >>> r.blocking_errors
            ['Token audience mismatch: token aud does not match expected resource or gateway URL']
            >>> r.issuer_match = False
            >>> r.warnings.append("Token issuer mismatch: iss is https://audience.example")
            >>> len(r.blocking_errors)
            2
        """
        if not self.warnings:
            return []

        # Map each failed flag to the warning it generated by keyword presence.
        # This avoids tight coupling between the flag and the exact warning text.
        checks = (
            (self.audience_match, "audience"),
            (self.scopes_sufficient, "scope"),
            (self.issuer_match, "issuer"),
        )

        blocking: List[str] = []
        for verdict, keyword in checks:
            if verdict is not False:
                continue
            for warning in self.warnings:
                # Skip warnings already collected through another keyword so
                # callers never see the same message twice.
                if keyword in warning.lower() and warning not in blocking:
                    blocking.append(warning)
        return blocking
def _derive_issuer_from_token_url(token_url: str) -> Optional[str]:
"""Derive the expected issuer URL from the token endpoint URL.
For Entra ID: ``https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token``
→ ``https://login.microsoftonline.com/{tenant}/v2.0``
For generic IdPs: returns the base URL (scheme + host).
Args:
token_url: The OAuth token endpoint URL.
Returns:
Derived issuer string, or None if the URL cannot be parsed.
"""
if not token_url:
return None
parsed = urlparse(token_url)
if not parsed.scheme or not parsed.netloc:
return None
# Entra ID pattern: /tenant-id/oauth2/v2.0/token → issuer is /tenant-id/v2.0
path_parts = [p for p in parsed.path.split("/") if p]
if len(path_parts) >= 4 and path_parts[1] == "oauth2" and path_parts[2] == "v2.0" and path_parts[3] == "token":
tenant = path_parts[0]
return f"{parsed.scheme}://{parsed.netloc}/{tenant}/v2.0"
return f"{parsed.scheme}://{parsed.netloc}"
def _normalize_scope(scope_str: str) -> set:
"""Normalize a scope string by stripping resource URI prefixes.
IdPs like Entra ID may return scopes as ``api://app-id/Scope.Name``
while the gateway config stores the full URI form. We compare both
the full form and the short form (after the last ``/``).
Args:
scope_str: Space-delimited scope string from the token.
Returns:
Set of normalized scope names (both full and short forms).
"""
scopes = set()
for s in scope_str.split():
scopes.add(s)
# Also add the short form (after last '/')
if "/" in s:
scopes.add(s.rsplit("/", 1)[-1])
return scopes
def _validate_audience(claims: Dict[str, Any], oauth_config: Dict[str, Any], gateway_url: str, gateway_name: str, result: TokenValidationResult) -> None:
    """Check the ``aud`` claim against the expected audience.

    The expected audience is ``oauth_config["resource"]`` when configured,
    otherwise ``gateway_url``. When neither is available, or the token carries
    no usable ``aud`` claim, ``result.audience_match`` is left untouched
    (``None``), so the absence of information is never reported as a mismatch.

    Args:
        claims: Decoded JWT claims.
        oauth_config: Gateway OAuth configuration.
        gateway_url: Fallback audience when ``resource`` is not configured.
        gateway_name: Gateway name for log messages.
        result: Validation result to update in-place.
    """
    expected = oauth_config.get("resource") or gateway_url
    if not expected:
        return

    token_aud = claims.get("aud")
    # An empty ``aud`` ("" or []) names no audience at all, so it is handled
    # like a missing claim — the same way empty ``iss`` and ``scope`` values
    # are handled — rather than as a definitive (blocking) mismatch.
    if not token_aud:
        logger.debug("OAuth token for gateway %s has no 'aud' claim", gateway_name)
        return

    # Normalize both sides to lists for a simple membership check.
    # Per RFC 7519 Section 4.1.3, aud can be a string or array.
    expected_list = list(expected) if isinstance(expected, (list, tuple)) else [expected]
    aud_list = token_aud if isinstance(token_aud, list) else [token_aud]

    if any(aud in expected_list for aud in aud_list):
        result.audience_match = True
        return

    result.audience_match = False
    result.warnings.append("Token audience mismatch: token aud does not match expected resource or gateway URL")
def _validate_scopes(claims: Dict[str, Any], oauth_config: Dict[str, Any], gateway_name: str, result: TokenValidationResult) -> None:
    """Check the ``scope`` / ``scp`` claim against configured scopes.

    A configured scope counts as granted when either its full form or its
    short form (the part after the last ``/``) appears among the token's
    normalized scopes. When no scopes are configured, or the token carries no
    scope claim, ``result.scopes_sufficient`` is left untouched (``None``).

    Args:
        claims: Decoded JWT claims.
        oauth_config: Gateway OAuth configuration.
        gateway_name: Gateway name for log messages.
        result: Validation result to update in-place.
    """
    configured_scopes = oauth_config.get("scopes") or []
    if isinstance(configured_scopes, str):
        # Tolerate a space-delimited string; iterating it directly would
        # compare single characters against the granted scopes.
        configured_scopes = configured_scopes.split()
    configured_scopes = [str(scope) for scope in configured_scopes]
    if not configured_scopes:
        return

    # Entra ID uses 'scp' claim; standard OAuth uses 'scope'
    token_scope = claims.get("scope") or claims.get("scp") or ""
    if isinstance(token_scope, (list, tuple)):
        # Some IdPs emit the scope claim as a JSON array rather than a
        # space-delimited string; flatten it so it can be split uniformly.
        token_scope_str = " ".join(str(item) for item in token_scope if item)
    else:
        token_scope_str = str(token_scope)

    if not token_scope_str.strip():
        logger.debug("OAuth token for gateway %s has no 'scope'/'scp' claim", gateway_name)
        return

    granted_scopes = _normalize_scope(token_scope_str)

    # rsplit returns the whole string when there is no '/', so the short form
    # of an unprefixed scope is the scope itself.
    missing = [cfg_scope for cfg_scope in configured_scopes if cfg_scope not in granted_scopes and cfg_scope.rsplit("/", 1)[-1] not in granted_scopes]

    if not missing:
        result.scopes_sufficient = True
        return

    result.scopes_sufficient = False
    # Cap both the number and the length of reported scopes to keep the
    # warning short.
    safe_missing = ", ".join(s[:60] for s in missing[:5])
    result.warnings.append(f"Token may be missing required scopes: [{safe_missing}]")
def _validate_issuer(claims: Dict[str, Any], oauth_config: Dict[str, Any], gateway_name: str, result: TokenValidationResult) -> None:
    """Check the ``iss`` claim against the expected issuer.

    The expected issuer is ``oauth_config["issuer"]`` when configured;
    otherwise it is derived heuristically from ``oauth_config["token_url"]``.
    Trailing slashes are ignored on both sides.

    An explicitly configured issuer must match exactly. A *derived* issuer
    that is only a bare origin (``scheme://host``) is a guess, so any token
    issuer located under that origin is accepted; IdPs whose issuer includes a
    path (for example a realm or tenant segment) would otherwise always be
    reported as a definitive mismatch.

    When no expected issuer can be determined, or the token has no ``iss``
    claim, ``result.issuer_match`` is left untouched (``None``).

    Args:
        claims: Decoded JWT claims.
        oauth_config: Gateway OAuth configuration.
        gateway_name: Gateway name for log messages.
        result: Validation result to update in-place.
    """
    configured_issuer = oauth_config.get("issuer")
    expected_issuer = configured_issuer or _derive_issuer_from_token_url(oauth_config.get("token_url", ""))
    if not expected_issuer:
        return

    token_iss = claims.get("iss")
    if not token_iss:
        logger.debug("OAuth token for gateway %s has no 'iss' claim", gateway_name)
        return

    # The token is unverified, so ``iss`` may be any JSON type; coerce before
    # using string methods so a malformed claim is reported as a mismatch
    # instead of raising.
    iss_norm = str(token_iss).rstrip("/")
    expected_norm = str(expected_issuer).rstrip("/")

    matches = iss_norm == expected_norm
    if not matches and not configured_issuer and not urlparse(expected_norm).path:
        # Derived bare origin: accept issuers that live under the same origin.
        # The "/" suffix prevents "https://idp.example" from matching
        # "https://idp.example.evil.test".
        matches = iss_norm.startswith(expected_norm + "/")

    if matches:
        result.issuer_match = True
        return

    result.issuer_match = False
    safe_iss = str(token_iss)[:80]
    safe_expected = str(expected_issuer)[:80]
    result.warnings.append(f"Token issuer mismatch: token iss='{safe_iss}', expected '{safe_expected}'")
def validate_oauth_token_claims(
    access_token: str,
    oauth_config: Dict[str, Any],
    gateway_url: str,
    gateway_name: str,
    token_type: str = "Bearer",  # nosec B107
) -> TokenValidationResult:
    """Validate JWT claims on an OAuth access token before forwarding to an MCP server.

    This is a best-effort, advisory validation. The upstream MCP server is
    the authoritative token validator. Problems are collected as warnings on
    the returned result; this function does not itself block token forwarding.

    The token is decoded without signature or expiry verification, so the
    claims are untrusted input. A token that cannot be decoded as a JWT is
    treated as opaque and only the token type check is reported.

    Args:
        access_token: The OAuth access token (JWT or opaque).
        oauth_config: The gateway's OAuth configuration dict (contains scopes, resource, token_url, issuer).
        gateway_url: The gateway URL, used as fallback audience if ``resource`` is not configured.
        gateway_name: The gateway name, used for log messages.
        token_type: The stored token type (from OAuthToken model). A missing
            or non-string value is reported as an unexpected type.

    Returns:
        TokenValidationResult with validation details and any warnings.

    Examples:
        >>> result = validate_oauth_token_claims("opaque-token", {}, "https://example.com", "gw")
        >>> result.is_jwt
        False
        >>> result.warnings
        []
    """
    result = TokenValidationResult()

    # Validate token_type. Guard against None / non-string stored values so a
    # bad record yields a warning rather than an AttributeError.
    normalized_type = token_type.lower() if isinstance(token_type, str) else ""
    if normalized_type == "bearer":
        result.token_type_valid = True
    else:
        result.token_type_valid = False
        result.warnings.append(f"Unexpected token_type '{token_type}', expected 'Bearer'")

    # Attempt JWT decode (no signature verification — we don't have IdP keys)
    try:
        claims = jwt.decode(
            access_token,
            options={"verify_signature": False, "verify_aud": False, "verify_iss": False, "verify_exp": False},
            algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "HS256", "HS384", "HS512", "EdDSA"],
        )
    except jwt.InvalidTokenError:
        # Opaque (non-JWT) token — skip claim validation. InvalidTokenError is
        # the parent of DecodeError, so any other token-level rejection raised
        # by the JWT library is also absorbed here instead of escaping an
        # advisory check.
        logger.debug("OAuth token for gateway %s is not a JWT; skipping claim validation", gateway_name)
        result.is_jwt = False
        return result

    result.is_jwt = True

    _validate_audience(claims, oauth_config, gateway_url, gateway_name, result)
    _validate_scopes(claims, oauth_config, gateway_name, result)
    _validate_issuer(claims, oauth_config, gateway_name, result)

    return result