@@ -18,55 +18,68 @@ class JWTService:
18
18
def __init__ (self , jwt_config : JWTConfiguration ) -> None :
19
19
self .jwt_config = jwt_config
20
20
21
- self .jwks_client = (
22
- PyJWKClient (self .jwt_config .jwk_url ) if self .jwt_config .jwk_url else None
23
- )
24
- self .leeway = self .jwt_config .leeway
21
+ def get_jwks_client (self , jwt_config : JWTConfiguration ) -> t .Optional [PyJWKClient ]:
22
+ jwks_client = PyJWKClient (jwt_config .jwk_url ) if jwt_config .jwk_url else None
23
+ return jwks_client
25
24
26
- def get_leeway (self ) -> timedelta :
27
- if self .leeway is None :
25
+ def get_leeway (self , jwt_config : JWTConfiguration ) -> timedelta :
26
+ if jwt_config .leeway is None :
28
27
return timedelta (seconds = 0 )
29
- elif isinstance (self .leeway , (int , float )):
30
- return timedelta (seconds = self .leeway )
31
- elif isinstance (self .leeway , timedelta ):
32
- return self .leeway
28
+ elif isinstance (jwt_config .leeway , (int , float )):
29
+ return timedelta (seconds = jwt_config .leeway )
30
+ elif isinstance (jwt_config .leeway , timedelta ):
31
+ return jwt_config .leeway
33
32
34
- def get_verifying_key (self , token : t .Any ) -> bytes :
33
+ def get_verifying_key (self , token : t .Any , jwt_config : JWTConfiguration ) -> bytes :
35
34
if self .jwt_config .algorithm .startswith ("HS" ):
36
- return self . jwt_config .signing_secret_key .encode ()
35
+ return jwt_config .signing_secret_key .encode ()
37
36
38
- if self .jwks_client :
37
+ jwks_client = self .get_jwks_client (jwt_config )
38
+ if jwks_client :
39
39
try :
40
- p_jwk = self . jwks_client .get_signing_key_from_jwt (token )
40
+ p_jwk = jwks_client .get_signing_key_from_jwt (token )
41
41
return p_jwk .key # type:ignore[no-any-return]
42
42
except PyJWKClientError as ex :
43
43
raise JWTTokenException ("Token is invalid or expired" ) from ex
44
44
45
- return self .jwt_config .verifying_secret_key .encode ()
45
+ return jwt_config .verifying_secret_key .encode ()
46
+
47
+ def _merge_configurations (self , ** jwt_config : t .Any ) -> JWTConfiguration :
48
+ jwt_config_default = self .jwt_config .dict ()
49
+ jwt_config_default .update (jwt_config )
50
+ return JWTConfiguration (** jwt_config_default )
46
51
47
52
def sign (
48
- self , payload : dict , headers : t .Optional [t .Dict [str , t .Any ]] = None
53
+ self ,
54
+ payload : dict ,
55
+ headers : t .Optional [t .Dict [str , t .Any ]] = None ,
56
+ ** jwt_config : t .Any ,
49
57
) -> str :
50
58
"""
51
59
Returns an encoded token for the given payload dictionary.
52
60
"""
53
-
54
- jwt_payload = Token (jwt_config = self . jwt_config ).build (payload .copy ())
61
+ _jwt_config = self . _merge_configurations ( ** jwt_config )
62
+ jwt_payload = Token (jwt_config = _jwt_config ).build (payload .copy ())
55
63
56
64
return jwt .encode (
57
65
jwt_payload ,
58
- self . jwt_config .signing_secret_key ,
59
- algorithm = self . jwt_config .algorithm ,
60
- json_encoder = self . jwt_config .json_encoder ,
66
+ _jwt_config .signing_secret_key ,
67
+ algorithm = _jwt_config .algorithm ,
68
+ json_encoder = _jwt_config .json_encoder ,
61
69
headers = headers ,
62
70
)
63
71
64
72
async def sign_async (
65
- self , payload : dict , headers : t .Optional [t .Dict [str , t .Any ]] = None
73
+ self ,
74
+ payload : dict ,
75
+ headers : t .Optional [t .Dict [str , t .Any ]] = None ,
76
+ ** jwt_config : t .Any ,
66
77
) -> str :
67
- return await anyio .to_thread .run_sync (self .sign , payload , headers )
78
+ return await anyio .to_thread .run_sync (self .sign , payload , headers , ** jwt_config )
68
79
69
- def decode (self , token : str , verify : bool = True ) -> t .Dict [str , t .Any ]:
80
+ def decode (
81
+ self , token : str , verify : bool = True , ** jwt_config : t .Any
82
+ ) -> t .Dict [str , t .Any ]:
70
83
"""
71
84
Performs a validation of the given token and returns its payload
72
85
dictionary.
@@ -75,15 +88,16 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
75
88
signature check fails, or if its 'exp' claim indicates it has expired.
76
89
"""
77
90
try :
91
+ _jwt_config = self ._merge_configurations (** jwt_config )
78
92
return jwt .decode ( # type: ignore[no-any-return]
79
93
token ,
80
- self .get_verifying_key (token ),
81
- algorithms = [self . jwt_config .algorithm ],
82
- audience = self . jwt_config .audience ,
83
- issuer = self . jwt_config .issuer ,
84
- leeway = self .get_leeway (),
94
+ self .get_verifying_key (token , _jwt_config ),
95
+ algorithms = [_jwt_config .algorithm ],
96
+ audience = _jwt_config .audience ,
97
+ issuer = _jwt_config .issuer ,
98
+ leeway = self .get_leeway (_jwt_config ),
85
99
options = {
86
- "verify_aud" : self . jwt_config .audience is not None ,
100
+ "verify_aud" : _jwt_config .audience is not None ,
87
101
"verify_signature" : verify ,
88
102
},
89
103
)
@@ -92,5 +106,7 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
92
106
except InvalidTokenError as ex :
93
107
raise JWTTokenException ("Token is invalid or expired" ) from ex
94
108
95
- async def decode_async (self , token : str , verify : bool = True ) -> t .Dict [str , t .Any ]:
96
- return await anyio .to_thread .run_sync (self .decode , token , verify )
109
+ async def decode_async (
110
+ self , token : str , verify : bool = True , ** jwt_config : t .Any
111
+ ) -> t .Dict [str , t .Any ]:
112
+ return await anyio .to_thread .run_sync (self .decode , token , verify , ** jwt_config )
0 commit comments