Skip to content

Commit 7fa9f23

Browse files
committed
made configuration more dynamic
1 parent e9c72f2 commit 7fa9f23

File tree

4 files changed

+53
-39
lines changed

4 files changed

+53
-39
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,16 @@ JSON Encoder class that will be used by the `PYJWT` to encode the `jwt_payload`.
198198

199199
The `JwtService` uses [PYJWT](https://pypi.org/project/PyJWT/) underneath.
200200

201-
### _jwt_service.sign(payload: dict, headers: Dict[str, t.Any] = None) -> str_
202-
Creates a jwt token for the provided payload
201+
### _jwt_service.sign(payload: dict, headers: Dict[str, t.Any] = None, **jwt_config: t.Any) -> str_
202+
Creates a jwt token for the provided payload. Also, you can override default jwt config by using passing some keyword argument as a `jwt_config`
203203

204-
### _jwt_service.sign_async(payload: dict, headers: Dict[str, t.Any] = None) -> str_
204+
### _jwt_service.sign_async(payload: dict, headers: Dict[str, t.Any] = None, **jwt_config: t.Any) -> str_
205205
Async action for `jwt_service.sign`
206206

207-
### _jwt_service.decode(token: str, verify: bool = True) -> t.Dict[str, t.Any]:_
207+
### _jwt_service.decode(token: str, verify: bool = True, **jwt_config: t.Any) -> t.Dict[str, t.Any]:_
208208
Verifies and decodes provided token. And raises JWTException exception if token is invalid or expired
209209

210-
### _jwt_service.decode_async(token: str, verify: bool = True) -> t.Dict[str, t.Any]:_
210+
### _jwt_service.decode_async(token: str, verify: bool = True, **jwt_config: t.Any) -> t.Dict[str, t.Any]:_
211211
Async action for `jwt_service.decode`
212212

213213

ellar_jwt/services.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,55 +18,68 @@ class JWTService:
1818
def __init__(self, jwt_config: JWTConfiguration) -> None:
1919
self.jwt_config = jwt_config
2020

21-
self.jwks_client = (
22-
PyJWKClient(self.jwt_config.jwk_url) if self.jwt_config.jwk_url else None
23-
)
24-
self.leeway = self.jwt_config.leeway
21+
def get_jwks_client(self, jwt_config: JWTConfiguration) -> t.Optional[PyJWKClient]:
22+
jwks_client = PyJWKClient(jwt_config.jwk_url) if jwt_config.jwk_url else None
23+
return jwks_client
2524

26-
def get_leeway(self) -> timedelta:
27-
if self.leeway is None:
25+
def get_leeway(self, jwt_config: JWTConfiguration) -> timedelta:
26+
if jwt_config.leeway is None:
2827
return timedelta(seconds=0)
29-
elif isinstance(self.leeway, (int, float)):
30-
return timedelta(seconds=self.leeway)
31-
elif isinstance(self.leeway, timedelta):
32-
return self.leeway
28+
elif isinstance(jwt_config.leeway, (int, float)):
29+
return timedelta(seconds=jwt_config.leeway)
30+
elif isinstance(jwt_config.leeway, timedelta):
31+
return jwt_config.leeway
3332

34-
def get_verifying_key(self, token: t.Any) -> bytes:
33+
def get_verifying_key(self, token: t.Any, jwt_config: JWTConfiguration) -> bytes:
3534
if self.jwt_config.algorithm.startswith("HS"):
36-
return self.jwt_config.signing_secret_key.encode()
35+
return jwt_config.signing_secret_key.encode()
3736

38-
if self.jwks_client:
37+
jwks_client = self.get_jwks_client(jwt_config)
38+
if jwks_client:
3939
try:
40-
p_jwk = self.jwks_client.get_signing_key_from_jwt(token)
40+
p_jwk = jwks_client.get_signing_key_from_jwt(token)
4141
return p_jwk.key # type:ignore[no-any-return]
4242
except PyJWKClientError as ex:
4343
raise JWTTokenException("Token is invalid or expired") from ex
4444

45-
return self.jwt_config.verifying_secret_key.encode()
45+
return jwt_config.verifying_secret_key.encode()
46+
47+
def _merge_configurations(self, **jwt_config: t.Any) -> JWTConfiguration:
48+
jwt_config_default = self.jwt_config.dict()
49+
jwt_config_default.update(jwt_config)
50+
return JWTConfiguration(**jwt_config_default)
4651

4752
def sign(
48-
self, payload: dict, headers: t.Optional[t.Dict[str, t.Any]] = None
53+
self,
54+
payload: dict,
55+
headers: t.Optional[t.Dict[str, t.Any]] = None,
56+
**jwt_config: t.Any,
4957
) -> str:
5058
"""
5159
Returns an encoded token for the given payload dictionary.
5260
"""
53-
54-
jwt_payload = Token(jwt_config=self.jwt_config).build(payload.copy())
61+
_jwt_config = self._merge_configurations(**jwt_config)
62+
jwt_payload = Token(jwt_config=_jwt_config).build(payload.copy())
5563

5664
return jwt.encode(
5765
jwt_payload,
58-
self.jwt_config.signing_secret_key,
59-
algorithm=self.jwt_config.algorithm,
60-
json_encoder=self.jwt_config.json_encoder,
66+
_jwt_config.signing_secret_key,
67+
algorithm=_jwt_config.algorithm,
68+
json_encoder=_jwt_config.json_encoder,
6169
headers=headers,
6270
)
6371

6472
async def sign_async(
65-
self, payload: dict, headers: t.Optional[t.Dict[str, t.Any]] = None
73+
self,
74+
payload: dict,
75+
headers: t.Optional[t.Dict[str, t.Any]] = None,
76+
**jwt_config: t.Any,
6677
) -> str:
67-
return await anyio.to_thread.run_sync(self.sign, payload, headers)
78+
return await anyio.to_thread.run_sync(self.sign, payload, headers, **jwt_config)
6879

69-
def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
80+
def decode(
81+
self, token: str, verify: bool = True, **jwt_config: t.Any
82+
) -> t.Dict[str, t.Any]:
7083
"""
7184
Performs a validation of the given token and returns its payload
7285
dictionary.
@@ -75,15 +88,16 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
7588
signature check fails, or if its 'exp' claim indicates it has expired.
7689
"""
7790
try:
91+
_jwt_config = self._merge_configurations(**jwt_config)
7892
return jwt.decode( # type: ignore[no-any-return]
7993
token,
80-
self.get_verifying_key(token),
81-
algorithms=[self.jwt_config.algorithm],
82-
audience=self.jwt_config.audience,
83-
issuer=self.jwt_config.issuer,
84-
leeway=self.get_leeway(),
94+
self.get_verifying_key(token, _jwt_config),
95+
algorithms=[_jwt_config.algorithm],
96+
audience=_jwt_config.audience,
97+
issuer=_jwt_config.issuer,
98+
leeway=self.get_leeway(_jwt_config),
8599
options={
86-
"verify_aud": self.jwt_config.audience is not None,
100+
"verify_aud": _jwt_config.audience is not None,
87101
"verify_signature": verify,
88102
},
89103
)
@@ -92,5 +106,7 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
92106
except InvalidTokenError as ex:
93107
raise JWTTokenException("Token is invalid or expired") from ex
94108

95-
async def decode_async(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
96-
return await anyio.to_thread.run_sync(self.decode, token, verify)
109+
async def decode_async(
110+
self, token: str, verify: bool = True, **jwt_config: t.Any
111+
) -> t.Dict[str, t.Any]:
112+
return await anyio.to_thread.run_sync(self.decode, token, verify, **jwt_config)

ellar_jwt/token.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def __init__(self, jwt_config: JWTConfiguration) -> None:
1616
self.payload: t.Dict = {}
1717

1818
def build(self, payload: t.Dict) -> t.Dict:
19-
2019
# Set "exp" and "iat" claims with default value
2120
self.set_exp()
2221
self.set_iat()

tests/test_jwt_service.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def test_decode_with_expiry(self, title, backend):
242242
],
243243
)
244244
def test_decode_with_invalid_sig(self, title, backend):
245-
246245
payload = self.payload.copy()
247246
payload["exp"] = aware_utcnow() + timedelta(days=1)
248247
token_1 = jwt.encode(

0 commit comments

Comments
 (0)