77
88from __future__ import annotations
99
10- from collections .abc import Container , Iterable , Sequence
1110from datetime import timedelta
12- from functools import lru_cache
13- from typing import Any , override
11+ from functools import lru_cache , partial
12+ from typing import Any
1413
15- from jwt import DecodeError , PyJWK , PyJWS , PyJWT
16- from jwt .algorithms import AllowedPublicKeys
17- from jwt .types import Options
14+ from jwt import DecodeError , PyJWS , PyJWT
1815
1916from homeassistant .util .json import json_loads
2017
2118JWT_TOKEN_CACHE_SIZE = 16
2219MAX_TOKEN_SIZE = 8192
2320
24- _NO_VERIFY_OPTIONS = Options (
25- verify_signature = False ,
26- verify_exp = False ,
27- verify_nbf = False ,
28- verify_iat = False ,
29- verify_aud = False ,
30- verify_iss = False ,
31- verify_sub = False ,
32- verify_jti = False ,
33- require = [],
34- )
21+ _VERIFY_KEYS = ("signature" , "exp" , "nbf" , "iat" , "aud" , "iss" , "sub" , "jti" )
22+
23+ _VERIFY_OPTIONS : dict [str , Any ] = {f"verify_{ key } " : True for key in _VERIFY_KEYS } | {
24+ "require" : []
25+ }
26+ _NO_VERIFY_OPTIONS = {f"verify_{ key } " : False for key in _VERIFY_KEYS }
3527
3628
3729class _PyJWSWithLoadCache (PyJWS ):
@@ -46,6 +38,9 @@ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
4638 return super ()._load (jwt )
4739
4840
41+ _jws = _PyJWSWithLoadCache ()
42+
43+
4944@lru_cache (maxsize = JWT_TOKEN_CACHE_SIZE )
5045def _decode_payload (json_payload : str ) -> dict [str , Any ]:
5146 """Decode the payload from a JWS dictionary."""
@@ -61,12 +56,21 @@ def _decode_payload(json_payload: str) -> dict[str, Any]:
6156class _PyJWTWithVerify (PyJWT ):
6257 """PyJWT with a fast decode implementation."""
6358
64- def __init__ (self ) -> None :
65- """Initialize the PyJWT instance."""
66- # We require exp and iat claims to be present
67- super ().__init__ (Options (require = ["exp" , "iat" ]))
68- # Override the _jws instance with our cached version
69- self ._jws = _PyJWSWithLoadCache ()
59+ def decode_payload (
60+ self , jwt : str , key : str , options : dict [str , Any ], algorithms : list [str ]
61+ ) -> dict [str , Any ]:
62+ """Decode a JWT's payload."""
63+ if len (jwt ) > MAX_TOKEN_SIZE :
64+ # Avoid caching impossible tokens
65+ raise DecodeError ("Token too large" )
66+ return _decode_payload (
67+ _jws .decode_complete (
68+ jwt = jwt ,
69+ key = key ,
70+ algorithms = algorithms ,
71+ options = options ,
72+ )["payload" ]
73+ )
7074
7175 def verify_and_decode (
7276 self ,
@@ -75,70 +79,37 @@ def verify_and_decode(
7579 algorithms : list [str ],
7680 issuer : str | None = None ,
7781 leeway : float | timedelta = 0 ,
78- options : Options | None = None ,
82+ options : dict [ str , Any ] | None = None ,
7983 ) -> dict [str , Any ]:
8084 """Verify a JWT's signature and claims."""
81- return self .decode (
85+ merged_options = {** _VERIFY_OPTIONS , ** (options or {})}
86+ payload = self .decode_payload (
8287 jwt = jwt ,
8388 key = key ,
89+ options = merged_options ,
8490 algorithms = algorithms ,
85- issuer = issuer ,
86- leeway = leeway ,
87- options = options ,
8891 )
89-
90- @override
91- def decode (
92- self ,
93- jwt : str | bytes ,
94- key : AllowedPublicKeys | PyJWK | str | bytes = "" ,
95- algorithms : Sequence [str ] | None = None ,
96- options : Options | None = None ,
97- verify : bool | None = None ,
98- detached_payload : bytes | None = None ,
99- audience : str | Iterable [str ] | None = None ,
100- subject : str | None = None ,
101- issuer : str | Container [str ] | None = None ,
102- leeway : float | timedelta = 0 ,
103- ** kwargs : Any ,
104- ) -> dict [str , Any ]:
105- """Decode a JWT, verifying the signature and claims."""
106- if len (jwt ) > MAX_TOKEN_SIZE :
107- # Avoid caching impossible tokens
108- raise DecodeError ("Token too large" )
109- return super ().decode (
110- jwt = jwt ,
111- key = key ,
112- algorithms = algorithms ,
113- options = options ,
114- verify = verify ,
115- detached_payload = detached_payload ,
116- audience = audience ,
117- subject = subject ,
92+ # These should never be missing since we verify them
93+ # but this is an additional safeguard to make sure
94+ # nothing slips through.
95+ assert "exp" in payload , "exp claim is required"
96+ assert "iat" in payload , "iat claim is required"
97+ self ._validate_claims (
98+ payload = payload ,
99+ options = merged_options ,
118100 issuer = issuer ,
119101 leeway = leeway ,
120- ** kwargs ,
121102 )
122-
123- @override
124- def _decode_payload (self , decoded : dict [str , Any ]) -> dict [str , Any ]:
125- return _decode_payload (decoded ["payload" ])
103+ return payload
126104
127105
128106_jwt = _PyJWTWithVerify ()
129107verify_and_decode = _jwt .verify_and_decode
130-
131-
132- @lru_cache (maxsize = JWT_TOKEN_CACHE_SIZE )
133- def unverified_hs256_token_decode (jwt : str ) -> dict [str , Any ]:
134- """Decode a JWT without verifying the signature."""
135- return _jwt .decode (
136- jwt = jwt ,
137- key = "" ,
138- algorithms = ["HS256" ],
139- options = _NO_VERIFY_OPTIONS ,
108+ unverified_hs256_token_decode = lru_cache (maxsize = JWT_TOKEN_CACHE_SIZE )(
109+ partial (
110+ _jwt .decode_payload , key = "" , algorithms = ["HS256" ], options = _NO_VERIFY_OPTIONS
140111 )
141-
112+ )
142113
143114__all__ = [
144115 "unverified_hs256_token_decode" ,
0 commit comments