Skip to content

Commit c1a6cf0

Browse files
DX-115674: JWKS-based JWT verification to fix 401 token refresh
When the Dremio API returns 401 (expired token), the MCP SDK's except-Exception catch-all wraps it in a 200 tool error response. The MCP client never sees HTTP 401 so its OAuth refresh flow is never triggered. Fix: add optional `jwks_uri` and `jwks_cache_lifespan` settings. When configured, verify JWT signatures via JWKS and extract claims (exp, aud) in verify_token(). BearerAuthBackend rejects expired tokens with HTTP 401 *before* tool execution and SSE headers are sent, triggering client refresh. - JWKSVerifier: async PyJWKClient wrapper (blocking JWKS fetch runs in thread pool), configurable cache lifespan, auto-refresh on key rotation or fetch errors - Reconciled _extract_jwt_aud: when JWKS is configured, aud is extracted from verified claims; falls back to unverified decode - Graceful degradation: verification failures return None (don't block), letting Dremio handle validation as before - E2E tests proving the SDK swallows 401 as 200 - 8 unit tests for JWKSVerifier Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ec4498d commit c1a6cf0

5 files changed

Lines changed: 445 additions & 11 deletions

File tree

src/dremioai/config/settings.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
field_serializer,
2727
AliasChoices,
2828
)
29-
from pydantic_settings import BaseSettings, SettingsConfigDict
29+
from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource, YamlConfigSettingsSource
3030
from typing import (
3131
Optional,
3232
Union,
@@ -44,7 +44,7 @@
4444
from dremioai.config.tools import ToolType
4545
from enum import auto, StrEnum
4646
from pathlib import Path
47-
from yaml import safe_load, add_representer, dump
47+
from yaml import add_representer, dump
4848
from functools import reduce
4949
from operator import ior
5050
from shutil import which
@@ -253,6 +253,17 @@ class Dremio(FlagAwareModel):
253253
description="Extract org ID from JWT aud claim for LD context targeting",
254254
)
255255
auth_issuer_uri_override: Optional[str] = None
256+
jwks_uri: Optional[str] = Field(
257+
default=None,
258+
description="JWKS endpoint URL for JWT signature verification and expiry checking. "
259+
"When set, the MCP server validates token expiry before tool execution "
260+
"so expired tokens trigger HTTP 401 and the client's OAuth refresh flow. "
261+
"Example: https://your-auth0-tenant.auth0.com/.well-known/jwks.json",
262+
)
263+
jwks_cache_lifespan: Optional[int] = Field(
264+
default=3600,
265+
description="How long (seconds) to cache JWKS keys before refetching. Default: 3600 (1 hour).",
266+
)
256267
wlm: Optional[Wlm] = None
257268
api: Optional[ApiSettings] = Field(default_factory=ApiSettings)
258269
metrics: Optional[Metrics] = None
@@ -400,6 +411,23 @@ class Settings(FlagAwareMixin, BaseSettings):
400411
validate_assignment=True,
401412
)
402413

414+
@classmethod
415+
def settings_customise_sources(
416+
cls,
417+
settings_cls: type[BaseSettings],
418+
init_settings: PydanticBaseSettingsSource,
419+
env_settings: PydanticBaseSettingsSource,
420+
dotenv_settings: PydanticBaseSettingsSource,
421+
file_secret_settings: PydanticBaseSettingsSource,
422+
) -> tuple[PydanticBaseSettingsSource, ...]:
423+
return (
424+
init_settings,
425+
env_settings,
426+
dotenv_settings,
427+
YamlConfigSettingsSource(settings_cls, yaml_file=_yaml_file),
428+
file_secret_settings,
429+
)
430+
403431
def model_post_init(self, __context):
404432
_propagate_flag_prefixes(self, "")
405433
if self.launchdarkly and self.launchdarkly.sdk_key:
@@ -459,6 +487,10 @@ def collect_flag_keys(model_cls: type, prefix: str = "") -> list[str]:
459487
return sorted(keys)
460488

461489

490+
# Module-level holder so configure() can pass the YAML path to the Settings constructor
491+
_yaml_file: Path | None = None
492+
493+
462494
_settings: ContextVar[Settings] = ContextVar("settings", default=None)
463495

464496

@@ -498,9 +530,9 @@ def configure(cfg: Union[str, Path] = None, force=False) -> ContextVar[Settings]
498530
cfg.parent.mkdir(parents=True, exist_ok=True)
499531
cfg.touch()
500532

501-
with cfg.open() as f:
502-
s = safe_load(f)
503-
_settings.set(Settings.model_validate(s if s else {}))
533+
global _yaml_file
534+
_yaml_file = cfg
535+
_settings.set(Settings())
504536

505537
return _settings
506538

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#
2+
# Copyright (C) 2017-2025 Dremio Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
JWKS-based JWT token verifier for the MCP server.
18+
19+
Verifies JWT signatures and extracts claims (``exp``, ``aud``, etc.)
20+
so that expired tokens are rejected with HTTP 401 *before* any tool
21+
execution, triggering the MCP client's OAuth token refresh flow.
22+
23+
The JWKS keyset is cached and refreshed automatically on cache miss
24+
or verification error (e.g. key rotation).
25+
"""
26+
27+
import asyncio
28+
from dataclasses import dataclass
29+
from typing import Optional
30+
31+
import jwt as pyjwt
32+
from jwt import PyJWKClient, PyJWKClientError, ExpiredSignatureError
33+
from dremioai import log
34+
35+
logger = log.logger(__name__)
36+
37+
_DEFAULT_JWKS_CACHE_LIFESPAN = 3600 # 1 hour in seconds
38+
39+
40+
@dataclass
41+
class VerifiedClaims:
42+
"""Subset of JWT claims extracted after signature verification."""
43+
exp: Optional[int] = None
44+
aud: Optional[str] = None
45+
46+
47+
class JWKSVerifier:
48+
"""Verify JWT tokens using a remote JWKS endpoint.
49+
50+
Uses ``PyJWKClient`` with built-in caching (``lifespan`` seconds).
51+
On verification failure due to a key-related error the cache is
52+
invalidated and verification is retried once with fresh keys.
53+
"""
54+
55+
def __init__(self, jwks_uri: str, lifespan: int = _DEFAULT_JWKS_CACHE_LIFESPAN):
56+
self._jwks_uri = jwks_uri
57+
self._lifespan = lifespan
58+
self._client = PyJWKClient(
59+
jwks_uri,
60+
cache_jwk_set=True,
61+
lifespan=lifespan,
62+
)
63+
logger.info(f"JWKS verifier initialised with uri={jwks_uri}, cache={lifespan}s")
64+
65+
async def verify(self, token: str) -> Optional[VerifiedClaims]:
66+
"""Verify *token* and return its claims.
67+
68+
Returns ``None`` when verification cannot be performed — the
69+
token will still be forwarded to Dremio for real validation on
70+
the tool call.
71+
72+
Returns ``VerifiedClaims(exp=0)`` for expired tokens so that
73+
``BearerAuthBackend`` rejects them with HTTP 401.
74+
75+
PyJWKClient.get_signing_key_from_jwt() makes blocking HTTP calls
76+
to fetch JWKS on cache miss, so we run it in a thread pool to
77+
avoid blocking the event loop.
78+
"""
79+
loop = asyncio.get_running_loop()
80+
try:
81+
return await loop.run_in_executor(None, self._verify, token)
82+
except (pyjwt.InvalidKeyError, PyJWKClientError, KeyError):
83+
logger.info("JWKS cache miss or fetch error, refreshing and retrying")
84+
try:
85+
self._client = PyJWKClient(
86+
self._jwks_uri,
87+
cache_jwk_set=True,
88+
lifespan=self._lifespan,
89+
)
90+
return await loop.run_in_executor(None, self._verify, token)
91+
except Exception:
92+
logger.warning("JWKS verification failed after cache refresh", exc_info=True)
93+
return None
94+
except ExpiredSignatureError:
95+
logger.debug("Token expired")
96+
return VerifiedClaims(exp=0)
97+
except Exception:
98+
logger.debug("JWT verification failed, skipping enforcement", exc_info=True)
99+
return None
100+
101+
def _verify(self, token: str) -> VerifiedClaims:
102+
signing_key = self._client.get_signing_key_from_jwt(token)
103+
claims = pyjwt.decode(
104+
token,
105+
signing_key.key,
106+
algorithms=["RS256"],
107+
options={
108+
"verify_aud": False,
109+
"verify_iss": False,
110+
"verify_exp": True,
111+
},
112+
)
113+
aud = claims.get("aud")
114+
if isinstance(aud, list):
115+
aud = aud[0] if aud else None
116+
return VerifiedClaims(
117+
exp=claims.get("exp"),
118+
aud=aud,
119+
)

src/dremioai/servers/mcp.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,22 @@ class Transports(StrEnum):
108108
class FastMCPServerWithAuthToken(FastMCP):
109109
class DelegatingTokenVerifier(TokenVerifier):
110110

111+
def __init__(self):
112+
self._jwks_verifier = None
113+
dremio = settings.instance().dremio
114+
jwks_uri = dremio.get("jwks_uri")
115+
if jwks_uri:
116+
from dremioai.servers.jwks_verifier import JWKSVerifier
117+
lifespan = dremio.get("jwks_cache_lifespan") or 3600
118+
self._jwks_verifier = JWKSVerifier(jwks_uri, lifespan=lifespan)
119+
111120
@staticmethod
112121
def _extract_jwt_aud(token: str) -> str | None:
113122
"""Extract the aud claim from a JWT without signature verification.
114123
115-
The token is always forwarded to Dremio for real validation;
116-
we only read the audience (org ID) for LD context targeting.
124+
Fallback for when JWKS is not configured. The token is always
125+
forwarded to Dremio for real validation; we only read the
126+
audience (org ID) for LD context targeting.
117127
"""
118128
try:
119129
import jwt
@@ -125,13 +135,25 @@ def _extract_jwt_aud(token: str) -> str | None:
125135

126136
async def verify_token(self, token: str) -> AccessToken | None:
127137
if token:
128-
if settings.instance().dremio.get("extract_org_id_from_jwt"):
129-
if org_id := self._extract_jwt_aud(token):
130-
FeatureFlagManager.set_org_id(org_id)
138+
expires_at = None
139+
org_id = None
140+
141+
if self._jwks_verifier:
142+
verified = await self._jwks_verifier.verify(token)
143+
if verified:
144+
expires_at = verified.exp
145+
org_id = verified.aud
146+
elif settings.instance().dremio.get("extract_org_id_from_jwt"):
147+
org_id = self._extract_jwt_aud(token)
148+
149+
if org_id:
150+
FeatureFlagManager.set_org_id(org_id)
151+
131152
return AccessToken(
132-
token=token, # Include the token itself
153+
token=token,
133154
client_id="unused-client",
134155
scopes=["read"],
156+
expires_at=expires_at,
135157
)
136158
else:
137159
log.logger("verify_token").info(f"Token not provided")

0 commit comments

Comments
 (0)