Skip to content

Commit c06a13f

Browse files
authored
Merge pull request #7 from eadwinCode/dynamic_config
Dynamic Config
2 parents 10f8e4b + 7fa9f23 commit c06a13f

File tree

2 files changed

+53
-37
lines changed

2 files changed

+53
-37
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,16 @@ JSON Encoder class that will be used by the `PYJWT` to encode the `jwt_payload`.
196196

197197
The `JwtService` uses [PYJWT](https://pypi.org/project/PyJWT/) underneath.
198198

199-
### _jwt_service.sign(payload: dict, headers: Dict[str, t.Any] = None) -> str_
200-
Creates a jwt token for the provided payload
199+
### _jwt_service.sign(payload: dict, headers: Dict[str, t.Any] = None, **jwt_config: t.Any) -> str_
200+
Creates a jwt token for the provided payload. Also, you can override default jwt config by using passing some keyword argument as a `jwt_config`
201201

202-
### _jwt_service.sign_async(payload: dict, headers: Dict[str, t.Any] = None) -> str_
202+
### _jwt_service.sign_async(payload: dict, headers: Dict[str, t.Any] = None, **jwt_config: t.Any) -> str_
203203
Async action for `jwt_service.sign`
204204

205-
### _jwt_service.decode(token: str, verify: bool = True) -> t.Dict[str, t.Any]:_
205+
### _jwt_service.decode(token: str, verify: bool = True, **jwt_config: t.Any) -> t.Dict[str, t.Any]:_
206206
Verifies and decodes provided token. And raises JWTException exception if token is invalid or expired
207207

208-
### _jwt_service.decode_async(token: str, verify: bool = True) -> t.Dict[str, t.Any]:_
208+
### _jwt_service.decode_async(token: str, verify: bool = True, **jwt_config: t.Any) -> t.Dict[str, t.Any]:_
209209
Async action for `jwt_service.decode`
210210

211211

ellar_jwt/services.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,55 +18,68 @@ class JWTService:
1818
def __init__(self, jwt_config: JWTConfiguration) -> None:
1919
self.jwt_config = jwt_config
2020

21-
self.jwks_client = (
22-
PyJWKClient(self.jwt_config.jwk_url) if self.jwt_config.jwk_url else None
23-
)
24-
self.leeway = self.jwt_config.leeway
21+
def get_jwks_client(self, jwt_config: JWTConfiguration) -> t.Optional[PyJWKClient]:
22+
jwks_client = PyJWKClient(jwt_config.jwk_url) if jwt_config.jwk_url else None
23+
return jwks_client
2524

26-
def get_leeway(self) -> timedelta:
27-
if self.leeway is None:
25+
def get_leeway(self, jwt_config: JWTConfiguration) -> timedelta:
26+
if jwt_config.leeway is None:
2827
return timedelta(seconds=0)
29-
elif isinstance(self.leeway, (int, float)):
30-
return timedelta(seconds=self.leeway)
31-
elif isinstance(self.leeway, timedelta):
32-
return self.leeway
28+
elif isinstance(jwt_config.leeway, (int, float)):
29+
return timedelta(seconds=jwt_config.leeway)
30+
elif isinstance(jwt_config.leeway, timedelta):
31+
return jwt_config.leeway
3332

34-
def get_verifying_key(self, token: t.Any) -> bytes:
33+
def get_verifying_key(self, token: t.Any, jwt_config: JWTConfiguration) -> bytes:
3534
if self.jwt_config.algorithm.startswith("HS"):
36-
return self.jwt_config.signing_secret_key.encode()
35+
return jwt_config.signing_secret_key.encode()
3736

38-
if self.jwks_client:
37+
jwks_client = self.get_jwks_client(jwt_config)
38+
if jwks_client:
3939
try:
40-
p_jwk = self.jwks_client.get_signing_key_from_jwt(token)
40+
p_jwk = jwks_client.get_signing_key_from_jwt(token)
4141
return p_jwk.key # type:ignore[no-any-return]
4242
except PyJWKClientError as ex:
4343
raise JWTTokenException("Token is invalid or expired") from ex
4444

45-
return self.jwt_config.verifying_secret_key.encode()
45+
return jwt_config.verifying_secret_key.encode()
46+
47+
def _merge_configurations(self, **jwt_config: t.Any) -> JWTConfiguration:
48+
jwt_config_default = self.jwt_config.dict()
49+
jwt_config_default.update(jwt_config)
50+
return JWTConfiguration(**jwt_config_default)
4651

4752
def sign(
48-
self, payload: dict, headers: t.Optional[t.Dict[str, t.Any]] = None
53+
self,
54+
payload: dict,
55+
headers: t.Optional[t.Dict[str, t.Any]] = None,
56+
**jwt_config: t.Any,
4957
) -> str:
5058
"""
5159
Returns an encoded token for the given payload dictionary.
5260
"""
53-
54-
jwt_payload = Token(jwt_config=self.jwt_config).build(payload.copy())
61+
_jwt_config = self._merge_configurations(**jwt_config)
62+
jwt_payload = Token(jwt_config=_jwt_config).build(payload.copy())
5563

5664
return jwt.encode(
5765
jwt_payload,
58-
self.jwt_config.signing_secret_key,
59-
algorithm=self.jwt_config.algorithm,
60-
json_encoder=self.jwt_config.json_encoder,
66+
_jwt_config.signing_secret_key,
67+
algorithm=_jwt_config.algorithm,
68+
json_encoder=_jwt_config.json_encoder,
6169
headers=headers,
6270
)
6371

6472
async def sign_async(
65-
self, payload: dict, headers: t.Optional[t.Dict[str, t.Any]] = None
73+
self,
74+
payload: dict,
75+
headers: t.Optional[t.Dict[str, t.Any]] = None,
76+
**jwt_config: t.Any,
6677
) -> str:
67-
return await anyio.to_thread.run_sync(self.sign, payload, headers)
78+
return await anyio.to_thread.run_sync(self.sign, payload, headers, **jwt_config)
6879

69-
def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
80+
def decode(
81+
self, token: str, verify: bool = True, **jwt_config: t.Any
82+
) -> t.Dict[str, t.Any]:
7083
"""
7184
Performs a validation of the given token and returns its payload
7285
dictionary.
@@ -75,15 +88,16 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
7588
signature check fails, or if its 'exp' claim indicates it has expired.
7689
"""
7790
try:
91+
_jwt_config = self._merge_configurations(**jwt_config)
7892
return jwt.decode( # type: ignore[no-any-return]
7993
token,
80-
self.get_verifying_key(token),
81-
algorithms=[self.jwt_config.algorithm],
82-
audience=self.jwt_config.audience,
83-
issuer=self.jwt_config.issuer,
84-
leeway=self.get_leeway(),
94+
self.get_verifying_key(token, _jwt_config),
95+
algorithms=[_jwt_config.algorithm],
96+
audience=_jwt_config.audience,
97+
issuer=_jwt_config.issuer,
98+
leeway=self.get_leeway(_jwt_config),
8599
options={
86-
"verify_aud": self.jwt_config.audience is not None,
100+
"verify_aud": _jwt_config.audience is not None,
87101
"verify_signature": verify,
88102
},
89103
)
@@ -92,5 +106,7 @@ def decode(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
92106
except InvalidTokenError as ex:
93107
raise JWTTokenException("Token is invalid or expired") from ex
94108

95-
async def decode_async(self, token: str, verify: bool = True) -> t.Dict[str, t.Any]:
96-
return await anyio.to_thread.run_sync(self.decode, token, verify)
109+
async def decode_async(
110+
self, token: str, verify: bool = True, **jwt_config: t.Any
111+
) -> t.Dict[str, t.Any]:
112+
return await anyio.to_thread.run_sync(self.decode, token, verify, **jwt_config)

0 commit comments

Comments
 (0)