Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 85 additions & 7 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,18 @@ class HMACAlgorithm(Algorithm):
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg

def _get_min_key_length(self) -> int:
"""Get minimum key length in bytes based on hash algorithm."""
if self.hash_alg == hashlib.sha256:
return 32 # 256 bits for HS256
elif self.hash_alg == hashlib.sha384:
return 48 # 384 bits for HS384
elif self.hash_alg == hashlib.sha512:
return 64 # 512 bits for HS512
else:
# For any other hash algorithm, require at least 32 bytes (256 bits)
return 32

def prepare_key(self, key: str | bytes) -> bytes:
key_bytes = force_bytes(key)

Expand All @@ -325,6 +337,24 @@ def prepare_key(self, key: str | bytes) -> bytes:
" should not be used as an HMAC secret."
)

# Enforce minimum key lengths per RFC 7518 and NIST guidelines
min_key_length = self._get_min_key_length()
if len(key_bytes) < min_key_length:
# Get algorithm name for error message
alg_name = "HMAC"
if self.hash_alg == hashlib.sha256:
alg_name = "HS256"
elif self.hash_alg == hashlib.sha384:
alg_name = "HS384"
elif self.hash_alg == hashlib.sha512:
alg_name = "HS512"

raise InvalidKeyError(
f"HMAC key must be at least {min_key_length * 8} bits "
f"({min_key_length} bytes) for {alg_name} algorithm. "
f"Key provided is {len(key_bytes) * 8} bits ({len(key_bytes)} bytes)."
)

return key_bytes

@overload
Expand Down Expand Up @@ -366,7 +396,18 @@ def from_jwk(jwk: str | JWKDict) -> bytes:
if obj.get("kty") != "oct":
raise InvalidKeyError("Not an HMAC key")

return base64url_decode(obj["k"])
key_bytes = base64url_decode(obj["k"])

# Validate key length - use a conservative minimum of 32 bytes (256 bits)
min_key_length = 32 # 256 bits minimum
if len(key_bytes) < min_key_length:
raise InvalidKeyError(
f"HMAC key must be at least {min_key_length * 8} bits "
f"({min_key_length} bytes). Key provided is {len(key_bytes) * 8} "
f"bits ({len(key_bytes)} bytes)."
)

return key_bytes

def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
Expand All @@ -392,8 +433,32 @@ class RSAAlgorithm(Algorithm):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def _validate_rsa_key_size(self, key: AllowedRSAKeys) -> None:
"""Validate RSA key size meets minimum security requirements."""
key_size = key.key_size
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117

if key_size < min_key_size:
raise InvalidKeyError(
f"RSA key must be at least {min_key_size} bits. "
f"Key provided is {key_size} bits."
)

@staticmethod
def _validate_rsa_key_size_static(key: AllowedRSAKeys) -> None:
"""Static version of RSA key size validation for use in static methods."""
key_size = key.key_size
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117

if key_size < min_key_size:
raise InvalidKeyError(
f"RSA key must be at least {min_key_size} bits. "
f"Key provided is {key_size} bits."
)

def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, self._crypto_key_types):
self._validate_rsa_key_size(key)
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -405,18 +470,24 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if key_bytes.startswith(b"ssh-rsa"):
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
rsa_public_key = cast(RSAPublicKey, public_key)
self._validate_rsa_key_size(rsa_public_key)
return rsa_public_key
else:
private_key: PrivateKeyTypes = load_pem_private_key(
key_bytes, password=None
)
self.check_crypto_key_type(private_key)
return cast(RSAPrivateKey, private_key)
rsa_private_key = cast(RSAPrivateKey, private_key)
self._validate_rsa_key_size(rsa_private_key)
return rsa_private_key
except ValueError:
try:
public_key = load_pem_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
rsa_public_key = cast(RSAPublicKey, public_key)
self._validate_rsa_key_size(rsa_public_key)
return rsa_public_key
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError(
"Could not parse the provided public key."
Expand Down Expand Up @@ -519,6 +590,9 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
iqmp=from_base64url_uint(obj["qi"]),
public_numbers=public_numbers,
)
private_key = numbers.private_key()
RSAAlgorithm._validate_rsa_key_size_static(private_key)
return private_key
else:
d = from_base64url_uint(obj["d"])
p, q = rsa_recover_prime_factors(
Expand All @@ -535,13 +609,17 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
public_numbers=public_numbers,
)

return numbers.private_key()
private_key = numbers.private_key()
RSAAlgorithm._validate_rsa_key_size_static(private_key)
return private_key
elif "n" in obj and "e" in obj:
# Public key
return RSAPublicNumbers(
public_key = RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
).public_key()
RSAAlgorithm._validate_rsa_key_size_static(public_key)
return public_key
else:
raise InvalidKeyError("Not a public or private key")

Expand Down Expand Up @@ -793,7 +871,7 @@ def __init__(self, **kwargs: Any) -> None:
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if not isinstance(key, (str, bytes)):
self.check_crypto_key_type(key)
return cast("AllowedOKPKeys", key)
return cast("AllowedOKPKeys", key) # type: ignore[redundant-cast] # Explicit for clarity

key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
Expand Down
126 changes: 123 additions & 3 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_hmac_should_reject_nonstring_key(self):
def test_hmac_should_accept_unicode_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)

algo.prepare_key("awesome")
algo.prepare_key("awesome" * 5) # 35 characters > 32 bytes minimum

@pytest.mark.parametrize(
"key",
Expand Down Expand Up @@ -101,12 +101,12 @@ def test_hmac_jwk_should_parse_and_verify(self):
@pytest.mark.parametrize("as_dict", (False, True))
def test_hmac_to_jwk_returns_correct_values(self, as_dict):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
key: Any = algo.to_jwk("secret", as_dict=as_dict)
key: Any = algo.to_jwk("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", as_dict=as_dict)

if not as_dict:
key = json.loads(key)

assert key == {"kty": "oct", "k": "c2VjcmV0"}
assert key == {"kty": "oct", "k": "YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE"}

def test_hmac_from_jwk_should_raise_exception_if_not_hmac_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
Expand All @@ -122,6 +122,57 @@ def test_hmac_from_jwk_should_raise_exception_if_empty_json(self):
with pytest.raises(InvalidKeyError):
algo.from_jwk(keyfile.read())

# CVE-2025-45768: Test minimum key length enforcement
@pytest.mark.parametrize(
"hash_alg,min_length,weak_key",
[
(HMACAlgorithm.SHA256, 32, b"short"), # 5 bytes, too short for HS256
(HMACAlgorithm.SHA256, 32, b"a" * 31), # 31 bytes, just under minimum
(HMACAlgorithm.SHA384, 48, b"b" * 47), # 47 bytes, just under minimum
(HMACAlgorithm.SHA512, 64, b"c" * 63), # 63 bytes, just under minimum
],
)
def test_hmac_should_reject_weak_keys(self, hash_alg, min_length, weak_key):
"""Test that HMAC keys below minimum length are rejected (CVE-2025-45768)"""
algo = HMACAlgorithm(hash_alg)

with pytest.raises(InvalidKeyError) as excinfo:
algo.prepare_key(weak_key)

error_msg = str(excinfo.value)
assert f"at least {min_length * 8} bits" in error_msg
assert f"Key provided is {len(weak_key) * 8} bits" in error_msg

@pytest.mark.parametrize(
"hash_alg,adequate_key",
[
(HMACAlgorithm.SHA256, b"a" * 32), # 32 bytes for HS256
(HMACAlgorithm.SHA384, b"b" * 48), # 48 bytes for HS384
(HMACAlgorithm.SHA512, b"c" * 64), # 64 bytes for HS512
],
)
def test_hmac_should_accept_adequate_keys(self, hash_alg, adequate_key):
"""Test that HMAC keys at or above minimum length are accepted"""
algo = HMACAlgorithm(hash_alg)

# Should not raise an exception
prepared_key = algo.prepare_key(adequate_key)
assert prepared_key == adequate_key

def test_hmac_from_jwk_should_reject_weak_keys(self):
"""Test that weak HMAC keys are rejected when loaded from JWK (CVE-2025-45768)"""
algo = HMACAlgorithm(HMACAlgorithm.SHA256)

# Create a JWK with a weak key (5 bytes)
weak_jwk = {"kty": "oct", "k": "c2hvcnQ"} # base64url("short") - only 5 bytes

with pytest.raises(InvalidKeyError) as excinfo:
algo.from_jwk(weak_jwk)

error_msg = str(excinfo.value)
assert "at least 256 bits" in error_msg
assert "40 bits" in error_msg # 5 bytes * 8 = 40 bits

@crypto_required
def test_rsa_should_parse_pem_public_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
Expand Down Expand Up @@ -173,6 +224,75 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self):
result = algo.verify(message, pub_key, sig)
assert not result

# CVE-2025-45768: Test RSA minimum key size enforcement
@crypto_required
def test_rsa_should_reject_weak_keys(self):
"""Test that RSA keys below 2048 bits are rejected (CVE-2025-45768)"""
from cryptography.hazmat.primitives.asymmetric import rsa

algo = RSAAlgorithm(RSAAlgorithm.SHA256)

# Generate a weak 1024-bit RSA key
weak_private_key = rsa.generate_private_key(
public_exponent=65537, key_size=1024
)
weak_public_key = weak_private_key.public_key()

# Test with private key
with pytest.raises(InvalidKeyError) as excinfo:
algo.prepare_key(weak_private_key)

error_msg = str(excinfo.value)
assert "at least 2048 bits" in error_msg
assert "1024 bits" in error_msg

# Test with public key
with pytest.raises(InvalidKeyError) as excinfo:
algo.prepare_key(weak_public_key)

error_msg = str(excinfo.value)
assert "at least 2048 bits" in error_msg
assert "1024 bits" in error_msg

@crypto_required
def test_rsa_should_accept_adequate_keys(self):
"""Test that RSA keys at or above 2048 bits are accepted"""
from cryptography.hazmat.primitives.asymmetric import rsa

algo = RSAAlgorithm(RSAAlgorithm.SHA256)

# Generate a strong 2048-bit RSA key
strong_private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
strong_public_key = strong_private_key.public_key()

# Should not raise exceptions
prepared_private = algo.prepare_key(strong_private_key)
prepared_public = algo.prepare_key(strong_public_key)

assert prepared_private == strong_private_key
assert prepared_public == strong_public_key

@crypto_required
def test_rsa_from_jwk_should_reject_weak_keys(self):
"""Test that weak RSA keys are rejected when loaded from JWK (CVE-2025-45768)"""
from cryptography.hazmat.primitives.asymmetric import rsa

# Generate a weak 1024-bit RSA key and convert to JWK
weak_key = rsa.generate_private_key(public_exponent=65537, key_size=1024)

# Convert to JWK format (this will work since to_jwk doesn't validate)
weak_jwk = RSAAlgorithm.to_jwk(weak_key, as_dict=True)

# Now try to load it back - should fail
with pytest.raises(InvalidKeyError) as excinfo:
RSAAlgorithm.from_jwk(weak_jwk)

error_msg = str(excinfo.value)
assert "at least 2048 bits" in error_msg
assert "1024 bits" in error_msg

@crypto_required
def test_ec_jwk_public_and_private_keys_should_parse_and_verify(self):
tests = {
Expand Down
Loading
Loading