Skip to content

Commit d5ca701

Browse files
Fixes jpadilla#964: Validate key against allowed types for Algorithm family (jpadilla#985)
* Fixes jpadilla#964: Validate key against allowed types for Algorithm family * fix mypy issues * fix mypy errors part 2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix OKPAlgorithm return value * fix tests * add changelog entry * tests, and change check_crypto_key_type to throw when used by non-crypto * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring * fix tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add mypy comment * remove TODO --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9b3f9ad commit d5ca701

File tree

5 files changed

+177
-56
lines changed

5 files changed

+177
-56
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ This project adheres to `Semantic Versioning <https://semver.org/>`__.
77
`Unreleased <https://github.com/jpadilla/pyjwt/compare/2.10.1...HEAD>`__
88
------------------------------------------------------------------------
99

10+
Fixed
11+
~~~~~
12+
- Validate key against allowed types for Algorithm family in `#964 <https://github.com/jpadilla/pyjwt/pull/964>`__
13+
1014
Added
1115
~~~~~
1216

jwt/algorithms.py

Lines changed: 111 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,56 @@
6565
load_ssh_public_key,
6666
)
6767

68+
# pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
69+
# in Py >= 3.10, we can replace this with the Union types below
70+
ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
71+
ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
72+
ALLOWED_OKP_KEY_TYPES = (
73+
Ed25519PrivateKey,
74+
Ed25519PublicKey,
75+
Ed448PrivateKey,
76+
Ed448PublicKey,
77+
)
78+
ALLOWED_KEY_TYPES = (
79+
ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
80+
)
81+
ALLOWED_PRIVATE_KEY_TYPES = (
82+
RSAPrivateKey,
83+
EllipticCurvePrivateKey,
84+
Ed25519PrivateKey,
85+
Ed448PrivateKey,
86+
)
87+
ALLOWED_PUBLIC_KEY_TYPES = (
88+
RSAPublicKey,
89+
EllipticCurvePublicKey,
90+
Ed25519PublicKey,
91+
Ed448PublicKey,
92+
)
93+
6894
has_crypto = True
6995
except ModuleNotFoundError:
7096
has_crypto = False
7197

7298

7399
if TYPE_CHECKING:
100+
from typing import TypeAlias
101+
102+
from cryptography.hazmat.primitives.asymmetric.types import (
103+
PrivateKeyTypes,
104+
PublicKeyTypes,
105+
)
106+
74107
# Type aliases for convenience in algorithms method signatures
75-
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
76-
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
77-
AllowedOKPKeys = (
108+
AllowedRSAKeys: TypeAlias = RSAPrivateKey | RSAPublicKey
109+
AllowedECKeys: TypeAlias = EllipticCurvePrivateKey | EllipticCurvePublicKey
110+
AllowedOKPKeys: TypeAlias = (
78111
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
79112
)
80-
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
81-
AllowedPrivateKeys = (
113+
AllowedKeys: TypeAlias = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
114+
AllowedPrivateKeys: TypeAlias = (
82115
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
83116
)
84-
AllowedPublicKeys = (
117+
AllowedPublicKeys: TypeAlias = (
85118
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
86119
)
87120

@@ -141,6 +174,9 @@ class Algorithm(ABC):
141174
The interface for an algorithm used to sign and verify tokens.
142175
"""
143176

177+
# pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
178+
_crypto_key_types: tuple[type[AllowedKeys], ...] | None = None
179+
144180
def compute_hash_digest(self, bytestr: bytes) -> bytes:
145181
"""
146182
Compute a hash digest using the specified algorithm's hash algorithm.
@@ -163,6 +199,30 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
163199
else:
164200
return bytes(hash_alg(bytestr).digest())
165201

202+
def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes):
203+
"""Check that the key belongs to the right cryptographic family.
204+
205+
Note that this method only works when `cryptography` is installed.
206+
207+
Args:
208+
key (Any): Potentially a cryptography key
209+
Raises:
210+
ValueError: if `cryptography` is not installed, or this method is called by a non-cryptography algorithm
211+
InvalidKeyError: if the key doesn't match the expected key classes
212+
"""
213+
if not has_crypto or self._crypto_key_types is None:
214+
raise ValueError(
215+
"This method requires the cryptography library, and should only be used by cryptography-based algorithms."
216+
)
217+
218+
if not isinstance(key, self._crypto_key_types):
219+
valid_classes = (cls.__name__ for cls in self._crypto_key_types)
220+
actual_class = key.__class__.__name__
221+
self_class = self.__class__.__name__
222+
raise InvalidKeyError(
223+
f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
224+
)
225+
166226
@abstractmethod
167227
def prepare_key(self, key: Any) -> Any:
168228
"""
@@ -323,11 +383,13 @@ class RSAAlgorithm(Algorithm):
323383
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
324384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
325385

386+
_crypto_key_types = ALLOWED_RSA_KEY_TYPES
387+
326388
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
327389
self.hash_alg = hash_alg
328390

329391
def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
330-
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
392+
if isinstance(key, self._crypto_key_types):
331393
return key
332394

333395
if not isinstance(key, (bytes, str)):
@@ -337,14 +399,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
337399

338400
try:
339401
if key_bytes.startswith(b"ssh-rsa"):
340-
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
402+
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
403+
self.check_crypto_key_type(public_key)
404+
return cast(RSAPublicKey, public_key)
341405
else:
342-
return cast(
343-
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
406+
private_key: PrivateKeyTypes = load_pem_private_key(
407+
key_bytes, password=None
344408
)
409+
self.check_crypto_key_type(private_key)
410+
return cast(RSAPrivateKey, private_key)
345411
except ValueError:
346412
try:
347-
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
413+
public_key = load_pem_public_key(key_bytes)
414+
self.check_crypto_key_type(public_key)
415+
return cast(RSAPublicKey, public_key)
348416
except (ValueError, UnsupportedAlgorithm):
349417
raise InvalidKeyError(
350418
"Could not parse the provided public key."
@@ -493,11 +561,13 @@ class ECAlgorithm(Algorithm):
493561
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
494562
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
495563

564+
_crypto_key_types = ALLOWED_EC_KEY_TYPES
565+
496566
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
497567
self.hash_alg = hash_alg
498568

499569
def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
500-
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
570+
if isinstance(key, self._crypto_key_types):
501571
return key
502572

503573
if not isinstance(key, (bytes, str)):
@@ -510,21 +580,17 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
510580
# the Verifying Key first.
511581
try:
512582
if key_bytes.startswith(b"ecdsa-sha2-"):
513-
crypto_key = load_ssh_public_key(key_bytes)
583+
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
514584
else:
515-
crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
516-
except ValueError:
517-
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
585+
public_key = load_pem_public_key(key_bytes)
518586

519-
# Explicit check the key to prevent confusing errors from cryptography
520-
if not isinstance(
521-
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
522-
):
523-
raise InvalidKeyError(
524-
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
525-
) from None
526-
527-
return crypto_key
587+
# Explicit check the key to prevent confusing errors from cryptography
588+
self.check_crypto_key_type(public_key)
589+
return cast(EllipticCurvePublicKey, public_key)
590+
except ValueError:
591+
private_key = load_pem_private_key(key_bytes, password=None)
592+
self.check_crypto_key_type(private_key)
593+
return cast(EllipticCurvePrivateKey, private_key)
528594

529595
def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
530596
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
@@ -715,31 +781,32 @@ class OKPAlgorithm(Algorithm):
715781
This class requires ``cryptography>=2.6`` to be installed.
716782
"""
717783

784+
_crypto_key_types = ALLOWED_OKP_KEY_TYPES
785+
718786
def __init__(self, **kwargs: Any) -> None:
719787
pass
720788

721789
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
722-
if isinstance(key, (bytes, str)):
723-
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
724-
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
725-
726-
if "-----BEGIN PUBLIC" in key_str:
727-
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
728-
elif "-----BEGIN PRIVATE" in key_str:
729-
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
730-
elif key_str[0:4] == "ssh-":
731-
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
790+
if not isinstance(key, (str, bytes)):
791+
self.check_crypto_key_type(key)
792+
return cast("AllowedOKPKeys", key)
793+
794+
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
795+
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
796+
797+
loaded_key: PublicKeyTypes | PrivateKeyTypes
798+
if "-----BEGIN PUBLIC" in key_str:
799+
loaded_key = load_pem_public_key(key_bytes)
800+
elif "-----BEGIN PRIVATE" in key_str:
801+
loaded_key = load_pem_private_key(key_bytes, password=None)
802+
elif key_str[0:4] == "ssh-":
803+
loaded_key = load_ssh_public_key(key_bytes)
804+
else:
805+
raise InvalidKeyError("Not a public or private key")
732806

733807
# Explicit check the key to prevent confusing errors from cryptography
734-
if not isinstance(
735-
key,
736-
(Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
737-
):
738-
raise InvalidKeyError(
739-
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
740-
)
741-
742-
return key
808+
self.check_crypto_key_type(loaded_key)
809+
return cast("AllowedOKPKeys", loaded_key)
743810

744811
def sign(
745812
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey

tests/keys/testkey_ed25519.pem

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-----BEGIN PRIVATE KEY-----
2+
MC4CAQAwBQYDK2VwBCIEIJb2MBNIWqpJ2zwLlbw8JkHNPIBkFCv/g127aQI7dQ1Q
3+
-----END PRIVATE KEY-----

tests/keys/testkey_ed25519.pub.pem

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-----BEGIN PUBLIC KEY-----
2+
MCowBQYDK2VwAyEASmyuOjH4q3bPqsOwf61G4jBH5L2g9kWnCDOp/7IOHKg=
3+
-----END PUBLIC KEY-----

tests/test_algorithms.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@
3333

3434

3535
class TestAlgorithms:
36+
def test_check_crypto_key_type_should_fail_when_not_using_crypto(self):
37+
"""If has_crypto is False, or if _crypto_key_types is None, then this method should throw."""
38+
39+
algo = NoneAlgorithm()
40+
with pytest.raises(ValueError):
41+
algo.check_crypto_key_type("key") # type: ignore[arg-type]
42+
3643
def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
3744
algo = NoneAlgorithm()
3845

@@ -811,6 +818,7 @@ def test_ec_verify_should_return_true_for_test_vector(self):
811818
@crypto_required
812819
class TestOKPAlgorithms:
813820
hello_world_sig = b"Qxa47mk/azzUgmY2StAOguAd4P7YBLpyCfU3JdbaiWnXM4o4WibXwmIHvNYgN3frtE2fcyd8OYEaOiD/KiwkCg=="
821+
hello_world_sig_pem = b"9ueQE7PT8uudHIQb2zZZ7tB7k1X3jeTnIfOVvGCINZejrqQbru1EXPeuMlGcQEZrGkLVcfMmr99W/+byxfppAg=="
814822
hello_world = b"Hello World!"
815823

816824
def test_okp_ed25519_should_reject_non_string_key(self):
@@ -825,58 +833,94 @@ def test_okp_ed25519_should_reject_non_string_key(self):
825833
with open(key_path("testkey_ed25519.pub")) as keyfile:
826834
algo.prepare_key(keyfile.read())
827835

828-
def test_okp_ed25519_sign_should_generate_correct_signature_value(self):
836+
@pytest.mark.parametrize(
837+
"private_key_file,public_key_file,sig_attr",
838+
[
839+
("testkey_ed25519", "testkey_ed25519.pub", "hello_world_sig"),
840+
("testkey_ed25519.pem", "testkey_ed25519.pub.pem", "hello_world_sig_pem"),
841+
],
842+
)
843+
def test_okp_ed25519_sign_should_generate_correct_signature_value(
844+
self, private_key_file, public_key_file, sig_attr
845+
):
829846
algo = OKPAlgorithm()
830847

831848
jwt_message = self.hello_world
832849

833-
expected_sig = base64.b64decode(self.hello_world_sig)
850+
expected_sig = base64.b64decode(getattr(self, sig_attr))
834851

835-
with open(key_path("testkey_ed25519")) as keyfile:
852+
with open(key_path(private_key_file)) as keyfile:
836853
jwt_key = cast(Ed25519PrivateKey, algo.prepare_key(keyfile.read()))
837854

838-
with open(key_path("testkey_ed25519.pub")) as keyfile:
855+
with open(key_path(public_key_file)) as keyfile:
839856
jwt_pub_key = cast(Ed25519PublicKey, algo.prepare_key(keyfile.read()))
840857

841858
algo.sign(jwt_message, jwt_key)
842859
result = algo.verify(jwt_message, jwt_pub_key, expected_sig)
843860
assert result
844861

845-
def test_okp_ed25519_verify_should_return_false_if_signature_invalid(self):
862+
@pytest.mark.parametrize(
863+
"public_key_file,sig_attr",
864+
[
865+
("testkey_ed25519.pub", "hello_world_sig"),
866+
("testkey_ed25519.pub.pem", "hello_world_sig_pem"),
867+
],
868+
)
869+
def test_okp_ed25519_verify_should_return_false_if_signature_invalid(
870+
self, public_key_file, sig_attr
871+
):
846872
algo = OKPAlgorithm()
847873

848874
jwt_message = self.hello_world
849-
jwt_sig = base64.b64decode(self.hello_world_sig)
875+
jwt_sig = base64.b64decode(getattr(self, sig_attr))
850876

851877
jwt_sig += b"123" # Signature is now invalid
852878

853-
with open(key_path("testkey_ed25519.pub")) as keyfile:
879+
with open(key_path(public_key_file)) as keyfile:
854880
jwt_pub_key = algo.prepare_key(keyfile.read())
855881

856882
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
857883
assert not result
858884

859-
def test_okp_ed25519_verify_should_return_true_if_signature_valid(self):
885+
@pytest.mark.parametrize(
886+
"public_key_file,sig_attr",
887+
[
888+
("testkey_ed25519.pub", "hello_world_sig"),
889+
("testkey_ed25519.pub.pem", "hello_world_sig_pem"),
890+
],
891+
)
892+
def test_okp_ed25519_verify_should_return_true_if_signature_valid(
893+
self, public_key_file, sig_attr
894+
):
860895
algo = OKPAlgorithm()
861896

862897
jwt_message = self.hello_world
863-
jwt_sig = base64.b64decode(self.hello_world_sig)
898+
jwt_sig = base64.b64decode(getattr(self, sig_attr))
864899

865-
with open(key_path("testkey_ed25519.pub")) as keyfile:
900+
with open(key_path(public_key_file)) as keyfile:
866901
jwt_pub_key = algo.prepare_key(keyfile.read())
867902

868903
result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
869904
assert result
870905

871-
def test_okp_ed25519_prepare_key_should_be_idempotent(self):
906+
@pytest.mark.parametrize(
907+
"public_key_file", ("testkey_ed25519.pub", "testkey_ed25519.pub.pem")
908+
)
909+
def test_okp_ed25519_prepare_key_should_be_idempotent(self, public_key_file):
872910
algo = OKPAlgorithm()
873911

874-
with open(key_path("testkey_ed25519.pub")) as keyfile:
912+
with open(key_path(public_key_file)) as keyfile:
875913
jwt_pub_key_first = algo.prepare_key(keyfile.read())
876914
jwt_pub_key_second = algo.prepare_key(jwt_pub_key_first)
877915

878916
assert jwt_pub_key_first == jwt_pub_key_second
879917

918+
def test_okp_ed25519_prepare_key_should_reject_invalid_key(self):
919+
algo = OKPAlgorithm()
920+
921+
with pytest.raises(InvalidKeyError):
922+
algo.prepare_key("not a valid key")
923+
880924
def test_okp_ed25519_jwk_private_key_should_parse_and_verify(self):
881925
algo = OKPAlgorithm()
882926

0 commit comments

Comments
 (0)