11from __future__ import annotations
22
33import json
4+ import time
45import urllib .request
5- from functools import lru_cache
66from ssl import SSLContext
77from typing import Any , Dict , List , Optional
88from urllib .error import URLError
@@ -44,12 +44,15 @@ def __init__(
4444 else :
4545 self .jwk_set_cache = None
4646
47+ # Replace lru_cache with TTL-aware individual key cache
48+ # Use the same TTL as JWKSetCache for consistency
4749 if cache_keys :
48- # Cache signing keys
49- # Ignore mypy (https://github.com/python/mypy/issues/2427)
50- self .get_signing_key = lru_cache (maxsize = max_cached_keys )(
51- self .get_signing_key
52- ) # type: ignore
50+ self ._key_cache_enabled = True
51+ self ._key_cache : Dict [str , tuple [PyJWK , float ]] = {} # kid -> (key, timestamp)
52+ self ._max_cached_keys = max_cached_keys
53+ self ._key_cache_ttl = lifespan # Use same TTL as JWKSetCache
54+ else :
55+ self ._key_cache_enabled = False
5356
5457 def fetch_data (self ) -> Any :
5558 jwk_set : Any = None
@@ -95,12 +98,44 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
9598
9699 return signing_keys
97100
101+ def _get_cached_key (self , kid : str ) -> Optional [PyJWK ]:
102+ """Get a cached key if it exists and hasn't expired."""
103+ if not self ._key_cache_enabled or kid not in self ._key_cache :
104+ return None
105+
106+ key , timestamp = self ._key_cache [kid ]
107+
108+ # Check and remove if expired (use same logic as JWKSetCache)
109+ if time .monotonic () - timestamp > self ._key_cache_ttl :
110+ del self ._key_cache [kid ]
111+ return None
112+
113+ return key
114+
115+ def _cache_key (self , kid : str , key : PyJWK ) -> None :
116+ """Cache a key with current timestamp."""
117+ if not self ._key_cache_enabled :
118+ return
119+
120+ # Evict oldest if at capacity
121+ if len (self ._key_cache ) >= self ._max_cached_keys and kid not in self ._key_cache :
122+ # Simple eviction: remove oldest timestamp
123+ oldest_kid = min (self ._key_cache .keys (),
124+ key = lambda k : self ._key_cache [k ][1 ])
125+ del self ._key_cache [oldest_kid ]
126+
127+ self ._key_cache [kid ] = (key , time .monotonic ())
128+
98129 def get_signing_key (self , kid : str ) -> PyJWK :
130+ # Check TTL-aware cache first
131+ cached_key = self ._get_cached_key (kid )
132+ if cached_key is not None :
133+ return cached_key
134+
99135 signing_keys = self .get_signing_keys ()
100136 signing_key = self .match_kid (signing_keys , kid )
101137
102138 if not signing_key :
103- # If no matching signing key from the jwk set, refresh the jwk set and try again.
104139 signing_keys = self .get_signing_keys (refresh = True )
105140 signing_key = self .match_kid (signing_keys , kid )
106141
@@ -109,6 +144,8 @@ def get_signing_key(self, kid: str) -> PyJWK:
109144 f'Unable to find a signing key that matches: "{ kid } "'
110145 )
111146
147+ # Cache the key with TTL (not lru)
148+ self ._cache_key (kid , signing_key )
112149 return signing_key
113150
114151 def get_signing_key_from_jwt (self , token : str | bytes ) -> PyJWK :
0 commit comments