Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to `Semantic Versioning <https://semver.org/>`__.

Fixed
~~~~~
- Fix indefinite key caching in PyJWKClient by replacing lru_cache with TTL-aware cache in `#1070 <https://github.com/jpadilla/pyjwt/pull/1070>`__
- Validate key against allowed types for Algorithm family in `#964 <https://github.com/jpadilla/pyjwt/pull/964>`__
- Add iterator for JWKSet in `#1041 <https://github.com/jpadilla/pyjwt/pull/1041>`__
- Validate `iss` claim is a string during encoding and decoding by @pachewise in `#1040 <https://github.com/jpadilla/pyjwt/pull/1040>`__
Expand Down
54 changes: 47 additions & 7 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import json
import time
import urllib.request
from functools import lru_cache
from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError
Expand Down Expand Up @@ -44,12 +44,17 @@ def __init__(
else:
self.jwk_set_cache = None

# Replace lru_cache with TTL-aware individual key cache
# Use the same TTL as JWKSetCache for consistency
if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(
self.get_signing_key
) # type: ignore
self._key_cache_enabled = True
self._key_cache: Dict[
str, tuple[PyJWK, float]
] = {} # kid -> (key, timestamp)
self._max_cached_keys = max_cached_keys
self._key_cache_ttl = lifespan # Use same TTL as JWKSetCache
else:
self._key_cache_enabled = False

def fetch_data(self) -> Any:
jwk_set: Any = None
Expand Down Expand Up @@ -95,12 +100,45 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:

return signing_keys

def _get_cached_key(self, kid: str) -> Optional[PyJWK]:
"""Get a cached key if it exists and hasn't expired."""
if not self._key_cache_enabled or kid not in self._key_cache:
return None

key, timestamp = self._key_cache[kid]

# Check and remove if expired (use same logic as JWKSetCache)
if time.monotonic() - timestamp > self._key_cache_ttl:
del self._key_cache[kid]
return None

return key

def _cache_key(self, kid: str, key: PyJWK) -> None:
"""Cache a key with current timestamp."""
if not self._key_cache_enabled:
return

# Evict oldest if at capacity
if len(self._key_cache) >= self._max_cached_keys and kid not in self._key_cache:
# Simple eviction: remove oldest timestamp
oldest_kid = min(
self._key_cache.keys(), key=lambda k: self._key_cache[k][1]
)
del self._key_cache[oldest_kid]

self._key_cache[kid] = (key, time.monotonic())

def get_signing_key(self, kid: str) -> PyJWK:
# Check TTL-aware cache first
cached_key = self._get_cached_key(kid)
if cached_key is not None:
return cached_key

signing_keys = self.get_signing_keys()
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = self.get_signing_keys(refresh=True)
signing_key = self.match_kid(signing_keys, kid)

Expand All @@ -109,6 +147,8 @@ def get_signing_key(self, kid: str) -> PyJWK:
f'Unable to find a signing key that matches: "{kid}"'
)

# Cache the key with TTL (not lru)
self._cache_key(kid, signing_key)
return signing_key

def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,79 @@ def test_get_jwt_set_sslcontext_no_ca(self):
jwks_client.get_jwk_set()

assert "Failed to get an expected error"

def test_security_fix_revoked_keys_expire(self):
"""
Test that demonstrates the security fix working.

This test should:
- FAIL with old lru_cache implementation (serves revoked keys forever)
- PASS with new TTL implementation (revoked keys expire)
"""
from unittest.mock import MagicMock, patch

client = PyJWKClient("https://example.com", cache_keys=True, lifespan=0.1)

# Use the real RSA key from existing tests
real_rsa_key = {
"kid": "revoked-key-123",
"kty": "RSA",
"use": "sig",
"n": "0wtlJRY9-ru61LmOgieeI7_rD1oIna9QpBMAOWw8wTuoIhFQFwcIi7MFB7IEfelCPj08vkfLsuFtR8cG07EE4uvJ78bAqRjMsCvprWp4e2p7hqPnWcpRpDEyHjzirEJle1LPpjLLVaSWgkbrVaOD0lkWkP1T1TkrOset_Obh8BwtO-Ww-UfrEwxTyz1646AGkbT2nL8PX0trXrmira8GnrCkFUgTUS61GoTdb9bCJ19PLX9Gnxw7J0BtR0GubopXq8KlI0ThVql6ZtVGN2dvmrCPAVAZleM5TVB61m0VSXvGWaF6_GeOhbFoyWcyUmFvzWhBm8Q38vWgsSI7oHTkEw",
"e": "AQAB",
}

different_key = {
"kid": "different-key-456",
"kty": "RSA",
"use": "sig",
"n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ",
"e": "AQAB",
}

jwks_with_key = {"keys": [real_rsa_key]}
jwks_key_revoked = {"keys": [different_key]}

mock_time = MagicMock()
mock_time.return_value = 1000.0

with (
patch("time.monotonic", mock_time),
patch.object(client, "fetch_data") as mock_fetch,
):
mock_fetch.return_value = jwks_with_key
key1 = client.get_signing_key("revoked-key-123")
assert key1.key_id == "revoked-key-123"

mock_time.return_value = 1000.15

mock_fetch.return_value = jwks_key_revoked

with pytest.raises(PyJWKClientError, match="Unable to find a signing key"):
client.get_signing_key("revoked-key-123")

def test_key_cache_eviction_when_at_capacity(self):
"""Test that key cache evicts oldest entries when at capacity."""
from unittest.mock import MagicMock

client = PyJWKClient("https://example.com", cache_keys=True, max_cached_keys=2)

key1 = MagicMock()
key1.key_id = "key1"
key2 = MagicMock()
key2.key_id = "key2"
key3 = MagicMock()
key3.key_id = "key3"

# Fill cache to capacity
client._cache_key("key1", key1)
client._cache_key("key2", key2)
assert len(client._key_cache) == 2

# Add third key - should evict oldest (key1)
client._cache_key("key3", key3)
assert len(client._key_cache) == 2

assert "key1" not in client._key_cache
assert "key2" in client._key_cache
assert "key3" in client._key_cache
Loading