99JWKS_URL = settings .JWKS_URL
1010
1111
12- def get_jwk ():
12+ def get_jwks ():
1313 if settings .JWKS_CACHE :
14- key = settings .JWKS_CACHE
15- return key
14+ # Check to make sure we have a cached JWK for each key, otherwise refetch all.
15+ missing = False
16+ for key in settings .JWKS_URL .keys ():
17+ if key not in settings .JWKS_CACHE :
18+ missing = True
19+
20+ if not missing :
21+ return settings .JWKS_CACHE
1622
1723 # Make a request to get the JWKS
18- try :
19- response = requests .get (JWKS_URL )
20- jwks = jwk .JWKSet .from_json (response .text )
21- settings .JWKS_CACHE = jwks
22- except Exception as e :
23- raise HTTPException (status_code = 500 , detail = str (e ))
24-
24+ for key in settings .JWKS_URL :
25+ try :
26+ response = requests .get (JWKS_URL )
27+ jwks = jwk .JWKSet .from_json (response .text )
28+ settings .JWKS_CACHE [key ] = jwks
29+ except Exception as e :
30+ del settings .JWKS_CACHE [key ]
31+ raise HTTPException (status_code = 500 , detail = str (e ))
2532 return settings .JWKS_CACHE
2633
2734
@@ -41,11 +48,12 @@ def get_token_from_header(request: Request):
4148
4249
4350def verify_jwt (token : str = Depends (get_token_from_header )):
44- try :
45- # Load the public key
46- public_key = get_jwk ()
47- # Decode and verify the JWT
48- decoded_token = jwt .JWT (key = public_key , jwt = token )
49- return decoded_token .claims
50- except Exception as e :
51- raise HTTPException (status_code = 401 , detail = str (e ))
51+ public_keys = get_jwks ()
52+ for key in public_keys :
53+ try :
54+ decoded_token = jwt .JWT (key = key , jwt = token )
55+ return decoded_token .claims
56+ except :
57+ pass
58+
59+ raise HTTPException (status_code = 401 , detail = "Failed to verify JWT token against public keys array." )
0 commit comments