6565 load_ssh_public_key ,
6666 )
6767
68+ # pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
69+ # in Py >= 3.10, we can replace this with the Union types below
70+ ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey , RSAPublicKey )
71+ ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey , EllipticCurvePublicKey )
72+ ALLOWED_OKP_KEY_TYPES = (
73+ Ed25519PrivateKey ,
74+ Ed25519PublicKey ,
75+ Ed448PrivateKey ,
76+ Ed448PublicKey ,
77+ )
78+ ALLOWED_KEY_TYPES = (
79+ ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
80+ )
81+ ALLOWED_PRIVATE_KEY_TYPES = (
82+ RSAPrivateKey ,
83+ EllipticCurvePrivateKey ,
84+ Ed25519PrivateKey ,
85+ Ed448PrivateKey ,
86+ )
87+ ALLOWED_PUBLIC_KEY_TYPES = (
88+ RSAPublicKey ,
89+ EllipticCurvePublicKey ,
90+ Ed25519PublicKey ,
91+ Ed448PublicKey ,
92+ )
93+
6894 has_crypto = True
6995except ModuleNotFoundError :
7096 has_crypto = False
7197
7298
7399if TYPE_CHECKING :
100+ from typing import TypeAlias
101+
102+ from cryptography .hazmat .primitives .asymmetric .types import (
103+ PrivateKeyTypes ,
104+ PublicKeyTypes ,
105+ )
106+
74107 # Type aliases for convenience in algorithms method signatures
75- AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
76- AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
77- AllowedOKPKeys = (
108+ AllowedRSAKeys : TypeAlias = RSAPrivateKey | RSAPublicKey
109+ AllowedECKeys : TypeAlias = EllipticCurvePrivateKey | EllipticCurvePublicKey
110+ AllowedOKPKeys : TypeAlias = (
78111 Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
79112 )
80- AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
81- AllowedPrivateKeys = (
113+ AllowedKeys : TypeAlias = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
114+ AllowedPrivateKeys : TypeAlias = (
82115 RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
83116 )
84- AllowedPublicKeys = (
117+ AllowedPublicKeys : TypeAlias = (
85118 RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
86119 )
87120
@@ -141,6 +174,9 @@ class Algorithm(ABC):
141174 The interface for an algorithm used to sign and verify tokens.
142175 """
143176
177+ # pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
178+ _crypto_key_types : tuple [type [AllowedKeys ], ...] | None = None
179+
144180 def compute_hash_digest (self , bytestr : bytes ) -> bytes :
145181 """
146182 Compute a hash digest using the specified algorithm's hash algorithm.
@@ -163,6 +199,30 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
163199 else :
164200 return bytes (hash_alg (bytestr ).digest ())
165201
202+ def check_crypto_key_type (self , key : PublicKeyTypes | PrivateKeyTypes ):
203+ """Check that the key belongs to the right cryptographic family.
204+
205+ Note that this method only works when `cryptography` is installed.
206+
207+ Args:
208+ key (Any): Potentially a cryptography key
209+ Raises:
210+ ValueError: if `cryptography` is not installed, or this method is called by a non-cryptography algorithm
211+ InvalidKeyError: if the key doesn't match the expected key classes
212+ """
213+ if not has_crypto or self ._crypto_key_types is None :
214+ raise ValueError (
215+ "This method requires the cryptography library, and should only be used by cryptography-based algorithms."
216+ )
217+
218+ if not isinstance (key , self ._crypto_key_types ):
219+ valid_classes = (cls .__name__ for cls in self ._crypto_key_types )
220+ actual_class = key .__class__ .__name__
221+ self_class = self .__class__ .__name__
222+ raise InvalidKeyError (
223+ f"Expected one of { valid_classes } , got: { actual_class } . Invalid Key type for { self_class } "
224+ )
225+
166226 @abstractmethod
167227 def prepare_key (self , key : Any ) -> Any :
168228 """
@@ -323,11 +383,13 @@ class RSAAlgorithm(Algorithm):
323383 SHA384 : ClassVar [type [hashes .HashAlgorithm ]] = hashes .SHA384
324384 SHA512 : ClassVar [type [hashes .HashAlgorithm ]] = hashes .SHA512
325385
386+ _crypto_key_types = ALLOWED_RSA_KEY_TYPES
387+
326388 def __init__ (self , hash_alg : type [hashes .HashAlgorithm ]) -> None :
327389 self .hash_alg = hash_alg
328390
329391 def prepare_key (self , key : AllowedRSAKeys | str | bytes ) -> AllowedRSAKeys :
330- if isinstance (key , ( RSAPrivateKey , RSAPublicKey ) ):
392+ if isinstance (key , self . _crypto_key_types ):
331393 return key
332394
333395 if not isinstance (key , (bytes , str )):
@@ -337,14 +399,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
337399
338400 try :
339401 if key_bytes .startswith (b"ssh-rsa" ):
340- return cast (RSAPublicKey , load_ssh_public_key (key_bytes ))
402+ public_key : PublicKeyTypes = load_ssh_public_key (key_bytes )
403+ self .check_crypto_key_type (public_key )
404+ return cast (RSAPublicKey , public_key )
341405 else :
342- return cast (
343- RSAPrivateKey , load_pem_private_key ( key_bytes , password = None )
406+ private_key : PrivateKeyTypes = load_pem_private_key (
407+ key_bytes , password = None
344408 )
409+ self .check_crypto_key_type (private_key )
410+ return cast (RSAPrivateKey , private_key )
345411 except ValueError :
346412 try :
347- return cast (RSAPublicKey , load_pem_public_key (key_bytes ))
413+ public_key = load_pem_public_key (key_bytes )
414+ self .check_crypto_key_type (public_key )
415+ return cast (RSAPublicKey , public_key )
348416 except (ValueError , UnsupportedAlgorithm ):
349417 raise InvalidKeyError (
350418 "Could not parse the provided public key."
@@ -493,11 +561,13 @@ class ECAlgorithm(Algorithm):
493561 SHA384 : ClassVar [type [hashes .HashAlgorithm ]] = hashes .SHA384
494562 SHA512 : ClassVar [type [hashes .HashAlgorithm ]] = hashes .SHA512
495563
564+ _crypto_key_types = ALLOWED_EC_KEY_TYPES
565+
496566 def __init__ (self , hash_alg : type [hashes .HashAlgorithm ]) -> None :
497567 self .hash_alg = hash_alg
498568
499569 def prepare_key (self , key : AllowedECKeys | str | bytes ) -> AllowedECKeys :
500- if isinstance (key , ( EllipticCurvePrivateKey , EllipticCurvePublicKey ) ):
570+ if isinstance (key , self . _crypto_key_types ):
501571 return key
502572
503573 if not isinstance (key , (bytes , str )):
@@ -510,21 +580,17 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
510580 # the Verifying Key first.
511581 try :
512582 if key_bytes .startswith (b"ecdsa-sha2-" ):
513- crypto_key = load_ssh_public_key (key_bytes )
583+ public_key : PublicKeyTypes = load_ssh_public_key (key_bytes )
514584 else :
515- crypto_key = load_pem_public_key (key_bytes ) # type: ignore[assignment]
516- except ValueError :
517- crypto_key = load_pem_private_key (key_bytes , password = None ) # type: ignore[assignment]
585+ public_key = load_pem_public_key (key_bytes )
518586
519- # Explicit check the key to prevent confusing errors from cryptography
520- if not isinstance (
521- crypto_key , (EllipticCurvePrivateKey , EllipticCurvePublicKey )
522- ):
523- raise InvalidKeyError (
524- "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
525- ) from None
526-
527- return crypto_key
587+ # Explicit check the key to prevent confusing errors from cryptography
588+ self .check_crypto_key_type (public_key )
589+ return cast (EllipticCurvePublicKey , public_key )
590+ except ValueError :
591+ private_key = load_pem_private_key (key_bytes , password = None )
592+ self .check_crypto_key_type (private_key )
593+ return cast (EllipticCurvePrivateKey , private_key )
528594
529595 def sign (self , msg : bytes , key : EllipticCurvePrivateKey ) -> bytes :
530596 der_sig = key .sign (msg , ECDSA (self .hash_alg ()))
@@ -715,31 +781,32 @@ class OKPAlgorithm(Algorithm):
715781 This class requires ``cryptography>=2.6`` to be installed.
716782 """
717783
784+ _crypto_key_types = ALLOWED_OKP_KEY_TYPES
785+
718786 def __init__ (self , ** kwargs : Any ) -> None :
719787 pass
720788
721789 def prepare_key (self , key : AllowedOKPKeys | str | bytes ) -> AllowedOKPKeys :
722- if isinstance (key , (bytes , str )):
723- key_str = key .decode ("utf-8" ) if isinstance (key , bytes ) else key
724- key_bytes = key .encode ("utf-8" ) if isinstance (key , str ) else key
725-
726- if "-----BEGIN PUBLIC" in key_str :
727- key = load_pem_public_key (key_bytes ) # type: ignore[assignment]
728- elif "-----BEGIN PRIVATE" in key_str :
729- key = load_pem_private_key (key_bytes , password = None ) # type: ignore[assignment]
730- elif key_str [0 :4 ] == "ssh-" :
731- key = load_ssh_public_key (key_bytes ) # type: ignore[assignment]
790+ if not isinstance (key , (str , bytes )):
791+ self .check_crypto_key_type (key )
792+ return cast ("AllowedOKPKeys" , key )
793+
794+ key_str = key .decode ("utf-8" ) if isinstance (key , bytes ) else key
795+ key_bytes = key .encode ("utf-8" ) if isinstance (key , str ) else key
796+
797+ loaded_key : PublicKeyTypes | PrivateKeyTypes
798+ if "-----BEGIN PUBLIC" in key_str :
799+ loaded_key = load_pem_public_key (key_bytes )
800+ elif "-----BEGIN PRIVATE" in key_str :
801+ loaded_key = load_pem_private_key (key_bytes , password = None )
802+ elif key_str [0 :4 ] == "ssh-" :
803+ loaded_key = load_ssh_public_key (key_bytes )
804+ else :
805+ raise InvalidKeyError ("Not a public or private key" )
732806
733807 # Explicit check the key to prevent confusing errors from cryptography
734- if not isinstance (
735- key ,
736- (Ed25519PrivateKey , Ed25519PublicKey , Ed448PrivateKey , Ed448PublicKey ),
737- ):
738- raise InvalidKeyError (
739- "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
740- )
741-
742- return key
808+ self .check_crypto_key_type (loaded_key )
809+ return cast ("AllowedOKPKeys" , loaded_key )
743810
744811 def sign (
745812 self , msg : str | bytes , key : Ed25519PrivateKey | Ed448PrivateKey
0 commit comments