11import base64
22import hashlib
3- import json
43import logging
54
65import inspect
6+ import jwt
77import requests
88from django .contrib .auth import get_user_model
99from django .contrib .auth .backends import ModelBackend
1010from django .core .exceptions import ImproperlyConfigured , SuspiciousOperation
1111from django .urls import reverse
12- from django .utils .encoding import force_bytes , smart_bytes , smart_str
12+ from django .utils .encoding import force_bytes , smart_str
1313from django .utils .module_loading import import_string
14- from josepy .b64 import b64decode
15- from josepy .jwk import JWK
16- from josepy .jws import JWS , Header
1714from requests .auth import HTTPBasicAuth
1815from requests .exceptions import HTTPError
1916
@@ -127,10 +124,10 @@ def update_user(self, user, claims):
127124
128125 def _verify_jws (self , payload , key ):
129126 """Verify the given JWS payload with the given key and return the payload"""
130- jws = JWS . from_compact (payload )
127+ jws = jwt . get_unverified_header (payload )
131128
132129 try :
133- alg = jws . signature . combined . alg . name
130+ alg = jws [ " alg" ]
134131 except KeyError :
135132 msg = "No alg value found in header"
136133 raise SuspiciousOperation (msg )
@@ -142,21 +139,19 @@ def _verify_jws(self, payload, key):
142139 )
143140 raise SuspiciousOperation (msg )
144141
145- if isinstance (key , str ):
146- # Use smart_bytes here since the key string comes from settings.
147- jwk = JWK .load (smart_bytes (key ))
148- else :
149- # The key is a json returned from the IDP JWKS endpoint.
150- jwk = JWK .from_json (key )
151-
152- if not jws .verify (jwk ):
142+ try :
143+ # Maybe add a settings to enforce audiance validation
144+ return jwt .decode (payload , key , algorithms = alg , options = {"verify_aud" : False })
145+ except jwt .DecodeError :
153146 msg = "JWS token verification failed."
154147 raise SuspiciousOperation (msg )
155148
156- return jws .payload
157-
158149 def retrieve_matching_jwk (self , token ):
159- """Get the signing key by exploring the JWKS endpoint of the OP."""
150+ """Get the signing key by exploring the JWKS endpoint of the OP.
151+
152+ Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check
153+ the algorithm in case of multiple jwk with the same kid.
154+ """
160155 response_jwks = requests .get (
161156 self .OIDC_OP_JWKS_ENDPOINT ,
162157 verify = self .get_settings ("OIDC_VERIFY_SSL" , True ),
@@ -167,32 +162,29 @@ def retrieve_matching_jwk(self, token):
167162 jwks = response_jwks .json ()
168163
169164 # Compute the current header from the given token to find a match
170- jws = JWS .from_compact (token )
171- json_header = jws .signature .protected
172- header = Header .json_loads (json_header )
165+ jws = jwt .get_unverified_header (token )
173166
174167 key = None
175168 for jwk in jwks ["keys" ]:
176169 if import_from_settings ("OIDC_VERIFY_KID" , True ) and jwk [
177170 "kid"
178- ] != smart_str (header . kid ):
171+ ] != smart_str (jws [ " kid" ] ):
179172 continue
180- if "alg" in jwk and jwk ["alg" ] != smart_str (header . alg ):
173+ if "alg" in jwk and jwk ["alg" ] != smart_str (jws [ " alg" ] ):
181174 continue
182175 key = jwk
183176 if key is None :
184177 raise SuspiciousOperation ("Could not find a valid JWKS." )
185- return key
178+ return jwt . PyJWK ( key )
186179
187180 def get_payload_data (self , token , key ):
188181 """Helper method to get the payload of the JWT token."""
189182 if self .get_settings ("OIDC_ALLOW_UNSECURED_JWT" , False ):
190- header , payload_data , signature = token .split (b"." )
191- header = json .loads (smart_str (b64decode (header )))
183+ header = jwt .get_unverified_header (token )
192184
193185 # If config allows unsecured JWTs check the header and return the decoded payload
194186 if "alg" in header and header ["alg" ] == "none" :
195- return b64decode ( payload_data )
187+ return jwt . decode ( token , options = { "verify_signature" : False } )
196188
197189 # By default fallback to verify JWT signatures
198190 return self ._verify_jws (token , key )
@@ -201,7 +193,6 @@ def verify_token(self, token, **kwargs):
201193 """Validate the token signature."""
202194 nonce = kwargs .get ("nonce" )
203195
204- token = force_bytes (token )
205196 if self .OIDC_RP_SIGN_ALGO .startswith ("RS" ) or self .OIDC_RP_SIGN_ALGO .startswith (
206197 "ES"
207198 ):
@@ -212,16 +203,7 @@ def verify_token(self, token, **kwargs):
212203 else :
213204 key = self .OIDC_RP_CLIENT_SECRET
214205
215- payload_data = self .get_payload_data (token , key )
216-
217- # The 'token' will always be a byte string since it's
218- # the result of base64.urlsafe_b64decode().
219- # The payload is always the result of base64.urlsafe_b64decode().
220- # In Python 3 and 2, that's always a byte string.
221- # In Python3.6, the json.loads() function can accept a byte string
222- # as it will automagically decode it to a unicode string before
223- # deserializing https://bugs.python.org/issue17909
224- payload = json .loads (payload_data .decode ("utf-8" ))
206+ payload = self .get_payload_data (token , key )
225207 token_nonce = payload .get ("nonce" )
226208
227209 if self .get_settings ("OIDC_USE_NONCE" , True ) and nonce != token_nonce :
0 commit comments