Skip to content

Commit 366ee13

Browse files
committed
feat: implement minimum key length validation for HMAC and RSA algorithms
for more information, see https://pre-commit.ci
1 parent f2d0ebe commit 366ee13

File tree

7 files changed

+712
-237
lines changed

7 files changed

+712
-237
lines changed

jwt/algorithms.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,18 @@ class HMACAlgorithm(Algorithm):
316316
def __init__(self, hash_alg: HashlibHash) -> None:
317317
self.hash_alg = hash_alg
318318

319+
def _get_min_key_length(self) -> int:
320+
"""Get minimum key length in bytes based on hash algorithm."""
321+
if self.hash_alg == hashlib.sha256:
322+
return 32 # 256 bits for HS256
323+
elif self.hash_alg == hashlib.sha384:
324+
return 48 # 384 bits for HS384
325+
elif self.hash_alg == hashlib.sha512:
326+
return 64 # 512 bits for HS512
327+
else:
328+
# For any other hash algorithm, require at least 32 bytes (256 bits)
329+
return 32
330+
319331
def prepare_key(self, key: str | bytes) -> bytes:
320332
key_bytes = force_bytes(key)
321333

@@ -325,6 +337,24 @@ def prepare_key(self, key: str | bytes) -> bytes:
325337
" should not be used as an HMAC secret."
326338
)
327339

340+
# Enforce minimum key lengths per RFC 7518 and NIST guidelines
341+
min_key_length = self._get_min_key_length()
342+
if len(key_bytes) < min_key_length:
343+
# Get algorithm name for error message
344+
alg_name = "HMAC"
345+
if self.hash_alg == hashlib.sha256:
346+
alg_name = "HS256"
347+
elif self.hash_alg == hashlib.sha384:
348+
alg_name = "HS384"
349+
elif self.hash_alg == hashlib.sha512:
350+
alg_name = "HS512"
351+
352+
raise InvalidKeyError(
353+
f"HMAC key must be at least {min_key_length * 8} bits "
354+
f"({min_key_length} bytes) for {alg_name} algorithm. "
355+
f"Key provided is {len(key_bytes) * 8} bits ({len(key_bytes)} bytes)."
356+
)
357+
328358
return key_bytes
329359

330360
@overload
@@ -366,7 +396,18 @@ def from_jwk(jwk: str | JWKDict) -> bytes:
366396
if obj.get("kty") != "oct":
367397
raise InvalidKeyError("Not an HMAC key")
368398

369-
return base64url_decode(obj["k"])
399+
key_bytes = base64url_decode(obj["k"])
400+
401+
# Validate key length - use a conservative minimum of 32 bytes (256 bits)
402+
min_key_length = 32 # 256 bits minimum
403+
if len(key_bytes) < min_key_length:
404+
raise InvalidKeyError(
405+
f"HMAC key must be at least {min_key_length * 8} bits "
406+
f"({min_key_length} bytes). Key provided is {len(key_bytes) * 8} "
407+
f"bits ({len(key_bytes)} bytes)."
408+
)
409+
410+
return key_bytes
370411

371412
def sign(self, msg: bytes, key: bytes) -> bytes:
372413
return hmac.new(key, msg, self.hash_alg).digest()
@@ -392,8 +433,32 @@ class RSAAlgorithm(Algorithm):
392433
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
393434
self.hash_alg = hash_alg
394435

436+
def _validate_rsa_key_size(self, key: AllowedRSAKeys) -> None:
437+
"""Validate RSA key size meets minimum security requirements."""
438+
key_size = key.key_size
439+
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117
440+
441+
if key_size < min_key_size:
442+
raise InvalidKeyError(
443+
f"RSA key must be at least {min_key_size} bits. "
444+
f"Key provided is {key_size} bits."
445+
)
446+
447+
@staticmethod
448+
def _validate_rsa_key_size_static(key: AllowedRSAKeys) -> None:
449+
"""Static version of RSA key size validation for use in static methods."""
450+
key_size = key.key_size
451+
min_key_size = 2048 # Minimum 2048 bits per RFC 7518 and NIST SP800-117
452+
453+
if key_size < min_key_size:
454+
raise InvalidKeyError(
455+
f"RSA key must be at least {min_key_size} bits. "
456+
f"Key provided is {key_size} bits."
457+
)
458+
395459
def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
396460
if isinstance(key, self._crypto_key_types):
461+
self._validate_rsa_key_size(key)
397462
return key
398463

399464
if not isinstance(key, (bytes, str)):
@@ -405,18 +470,24 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
405470
if key_bytes.startswith(b"ssh-rsa"):
406471
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
407472
self.check_crypto_key_type(public_key)
408-
return cast(RSAPublicKey, public_key)
473+
rsa_public_key = cast(RSAPublicKey, public_key)
474+
self._validate_rsa_key_size(rsa_public_key)
475+
return rsa_public_key
409476
else:
410477
private_key: PrivateKeyTypes = load_pem_private_key(
411478
key_bytes, password=None
412479
)
413480
self.check_crypto_key_type(private_key)
414-
return cast(RSAPrivateKey, private_key)
481+
rsa_private_key = cast(RSAPrivateKey, private_key)
482+
self._validate_rsa_key_size(rsa_private_key)
483+
return rsa_private_key
415484
except ValueError:
416485
try:
417486
public_key = load_pem_public_key(key_bytes)
418487
self.check_crypto_key_type(public_key)
419-
return cast(RSAPublicKey, public_key)
488+
rsa_public_key = cast(RSAPublicKey, public_key)
489+
self._validate_rsa_key_size(rsa_public_key)
490+
return rsa_public_key
420491
except (ValueError, UnsupportedAlgorithm):
421492
raise InvalidKeyError(
422493
"Could not parse the provided public key."
@@ -519,6 +590,9 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
519590
iqmp=from_base64url_uint(obj["qi"]),
520591
public_numbers=public_numbers,
521592
)
593+
private_key = numbers.private_key()
594+
RSAAlgorithm._validate_rsa_key_size_static(private_key)
595+
return private_key
522596
else:
523597
d = from_base64url_uint(obj["d"])
524598
p, q = rsa_recover_prime_factors(
@@ -535,13 +609,17 @@ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
535609
public_numbers=public_numbers,
536610
)
537611

538-
return numbers.private_key()
612+
private_key = numbers.private_key()
613+
RSAAlgorithm._validate_rsa_key_size_static(private_key)
614+
return private_key
539615
elif "n" in obj and "e" in obj:
540616
# Public key
541-
return RSAPublicNumbers(
617+
public_key = RSAPublicNumbers(
542618
from_base64url_uint(obj["e"]),
543619
from_base64url_uint(obj["n"]),
544620
).public_key()
621+
RSAAlgorithm._validate_rsa_key_size_static(public_key)
622+
return public_key
545623
else:
546624
raise InvalidKeyError("Not a public or private key")
547625

@@ -793,7 +871,7 @@ def __init__(self, **kwargs: Any) -> None:
793871
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
794872
if not isinstance(key, (str, bytes)):
795873
self.check_crypto_key_type(key)
796-
return cast("AllowedOKPKeys", key)
874+
return key
797875

798876
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
799877
key_bytes = key.encode("utf-8") if isinstance(key, str) else key

0 commit comments

Comments
 (0)