Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 85 additions & 7 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,18 @@ class HMACAlgorithm(Algorithm):
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg

def _get_min_key_length(self) -> int:
"""Get minimum key length in bytes based on hash algorithm."""
if self.hash_alg == hashlib.sha256:
return 32 # 256 bits for HS256
elif self.hash_alg == hashlib.sha384:
return 48 # 384 bits for HS384
elif self.hash_alg == hashlib.sha512:
return 64 # 512 bits for HS512
else:
# For any other hash algorithm, require at least 32 bytes (256 bits)
return 32

def prepare_key(self, key: str | bytes) -> bytes:
key_bytes = force_bytes(key)

Expand All @@ -325,6 +337,24 @@ def prepare_key(self, key: str | bytes) -> bytes:
" should not be used as an HMAC secret."
)

# Enforce minimum key lengths per RFC 7518 and NIST guidelines
min_key_length = self._get_min_key_length()
if len(key_bytes) < min_key_length:
# Get algorithm name for error message
alg_name = "HMAC"
if self.hash_alg == hashlib.sha256:
alg_name = "HS256"
elif self.hash_alg == hashlib.sha384:
alg_name = "HS384"
elif self.hash_alg == hashlib.sha512:
alg_name = "HS512"

raise InvalidKeyError(
f"HMAC key must be at least {min_key_length * 8} bits "
f"({min_key_length} bytes) for {alg_name} algorithm. "
f"Key provided is {len(key_bytes) * 8} bits ({len(key_bytes)} bytes)."
)

return key_bytes

@overload
Expand Down Expand Up @@ -366,7 +396,18 @@ def from_jwk(jwk: str | JWKDict) -> bytes:
if obj.get("kty") != "oct":
raise InvalidKeyError("Not an HMAC key")

return base64url_decode(obj["k"])
key_bytes = base64url_decode(obj["k"])

# Validate key length - use a conservative minimum of 32 bytes (256 bits)
min_key_length = 32 # 256 bits minimum
if len(key_bytes) < min_key_length:
raise InvalidKeyError(
f"HMAC key must be at least {min_key_length * 8} bits "
f"({min_key_length} bytes). Key provided is {len(key_bytes) * 8} "
f"bits ({len(key_bytes)} bytes)."
)

return key_bytes

def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
Expand All @@ -392,8 +433,32 @@ class RSAAlgorithm(Algorithm):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def _validate_rsa_key_size(self, key: AllowedRSAKeys) -> None:
"""Validate RSA key size meets minimum security requirements."""
key_size = key.key_size
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117

if key_size < min_key_size:
raise InvalidKeyError(
f"RSA key must be at least {min_key_size} bits. "
f"Key provided is {key_size} bits."
)

@staticmethod
def _validate_rsa_key_size_static(key: AllowedRSAKeys) -> None:
"""Static version of RSA key size validation for use in static methods."""
key_size = key.key_size
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117

if key_size < min_key_size:
raise InvalidKeyError(
f"RSA key must be at least {min_key_size} bits. "
f"Key provided is {key_size} bits."
)

def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, self._crypto_key_types):
self._validate_rsa_key_size(key)
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -405,18 +470,24 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if key_bytes.startswith(b"ssh-rsa"):
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
rsa_public_key = cast(RSAPublicKey, public_key)
self._validate_rsa_key_size(rsa_public_key)
return rsa_public_key
else:
private_key: PrivateKeyTypes = load_pem_private_key(
key_bytes, password=None
)
self.check_crypto_key_type(private_key)
return cast(RSAPrivateKey, private_key)
rsa_private_key = cast(RSAPrivateKey, private_key)
self._validate_rsa_key_size(rsa_private_key)
return rsa_private_key
except ValueError:
try:
public_key = load_pem_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
rsa_public_key = cast(RSAPublicKey, public_key)
self._validate_rsa_key_size(rsa_public_key)
return rsa_public_key
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError(
"Could not parse the provided public key."
Expand Down Expand Up @@ -519,6 +590,9 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
iqmp=from_base64url_uint(obj["qi"]),
public_numbers=public_numbers,
)
private_key = numbers.private_key()
RSAAlgorithm._validate_rsa_key_size_static(private_key)
return private_key
else:
d = from_base64url_uint(obj["d"])
p, q = rsa_recover_prime_factors(
Expand All @@ -535,13 +609,17 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
public_numbers=public_numbers,
)

return numbers.private_key()
private_key = numbers.private_key()
RSAAlgorithm._validate_rsa_key_size_static(private_key)
return private_key
elif "n" in obj and "e" in obj:
# Public key
return RSAPublicNumbers(
public_key = RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
).public_key()
RSAAlgorithm._validate_rsa_key_size_static(public_key)
return public_key
else:
raise InvalidKeyError("Not a public or private key")

Expand Down Expand Up @@ -793,7 +871,7 @@ def __init__(self, **kwargs: Any) -> None:
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if not isinstance(key, (str, bytes)):
self.check_crypto_key_type(key)
return cast("AllowedOKPKeys", key)
return key

key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
Expand Down
Loading
Loading