diff --git a/spiffe/src/spiffe/svid/jwt_svid.py b/spiffe/src/spiffe/svid/jwt_svid.py index 1a1567da..603f0e6a 100644 --- a/spiffe/src/spiffe/svid/jwt_svid.py +++ b/spiffe/src/spiffe/svid/jwt_svid.py @@ -109,10 +109,13 @@ def parse_insecure(cls, token: str, audience: Set[str]) -> 'JwtSvid': validator.validate_header(header_params) claims = jwt.decode(token, options={'verify_signature': False}) validator.validate_claims(claims, audience) - spiffe_id = SpiffeId(claims['sub']) + sub_claim = claims.get('sub') + if not sub_claim: + raise InvalidTokenError('JWT token must contain a non-empty \'sub\' claim') + spiffe_id = SpiffeId(sub_claim) return JwtSvid(spiffe_id, claims['aud'], claims['exp'], claims, token) except PyJWTError as err: - raise InvalidTokenError(str(err)) + raise InvalidTokenError(str(err)) from err @classmethod def parse_and_validate( @@ -151,6 +154,9 @@ def parse_and_validate( header_params = jwt.get_unverified_header(token) validator = JwtSvidValidator() validator.validate_header(header_params) + alg = header_params.get('alg') + if not alg: + raise ArgumentError('header alg cannot be empty') key_id = header_params.get('kid') signing_key = jwt_bundle.get_jwt_authority(key_id) if not signing_key: @@ -163,7 +169,7 @@ def parse_and_validate( claims = jwt.decode( token, - algorithms=header_params.get('alg'), + algorithms=[alg], key=public_key_pem, audience=audience, options={ @@ -173,7 +179,10 @@ def parse_and_validate( }, ) - spiffe_id = SpiffeId(claims.get('sub', None)) + sub_claim = claims.get('sub') + if not sub_claim: + raise InvalidTokenError('JWT token must contain a non-empty \'sub\' claim') + spiffe_id = SpiffeId(sub_claim) return JwtSvid(spiffe_id, claims['aud'], claims['exp'], claims, token) except PyJWTError as err: diff --git a/spiffe/src/spiffe/svid/jwt_svid_validator.py b/spiffe/src/spiffe/svid/jwt_svid_validator.py index a843f4ca..5778b15d 100644 --- a/spiffe/src/spiffe/svid/jwt_svid_validator.py +++ b/spiffe/src/spiffe/svid/jwt_svid_validator.py @@ -108,7 +108,15 @@ def validate_claims(self, payload: Dict[str, Any], expected_audience: Set[str]) if not payload.get(claim): raise MissingClaimError(claim) - self._validate_exp(str(payload.get('exp'))) + exp_value = payload.get('exp') + if exp_value is None: + raise MissingClaimError('exp') + try: + numeric_exp = float(exp_value) + except (TypeError, ValueError): + raise InvalidClaimError("exp claim must be a numeric value") + self._validate_exp(numeric_exp) + aud_claim = payload.get('aud') if aud_claim is None: aud_set = set() @@ -121,13 +129,13 @@ def validate_claims(self, payload: Dict[str, Any], expected_audience: Set[str]) self._validate_aud(aud_set, expected_audience) @staticmethod - def _validate_exp(expiration_date: str) -> None: + def _validate_exp(expiration_date: float) -> None: """Verifies expiration. Note: If and when https://github.com/jpadilla/pyjwt/issues/599 is fixed, this can be simplified/removed. Args: - expiration_date: Date to check if it is expired. + expiration_date: Date to check if it is expired (numeric timestamp). Raises: TokenExpiredError: In case it is expired. diff --git a/spiffe/src/spiffe/svid/x509_svid.py b/spiffe/src/spiffe/svid/x509_svid.py index 4b5a1277..14844dd6 100644 --- a/spiffe/src/spiffe/svid/x509_svid.py +++ b/spiffe/src/spiffe/svid/x509_svid.py @@ -14,6 +14,10 @@ under the License. """ +from cryptography.x509.oid import ExtensionOID + +from spiffe.spiffe_id import spiffe_id + """ This module manages X.509 SVID objects. """ @@ -254,15 +258,44 @@ def load( ) -def _extract_spiffe_id(cert: Certificate) -> SpiffeId: - ext = cert.extensions.get_extension_for_oid(x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME) - if isinstance(ext.value, x509.SubjectAlternativeName): - sans = ext.value.get_values_for_type(x509.UniformResourceIdentifier) - if len(sans) == 0: +def _extract_spiffe_id(cert: x509.Certificate) -> SpiffeId: + try: + ext = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + except x509.ExtensionNotFound as e: + raise InvalidLeafCertificateError( + "Certificate does not contain a SubjectAlternativeName extension" + ) from e + + san_value = ext.value + if not isinstance(san_value, x509.SubjectAlternativeName): raise InvalidLeafCertificateError( - 'Certificate does not contain a SPIFFE ID in the URI SAN' + "Certificate does not contain a valid SubjectAlternativeName extension" ) - return SpiffeId(sans[0]) + + san = san_value + uri_sans = san.get_values_for_type(x509.UniformResourceIdentifier) + + # SPIFFE X.509-SVID: MUST contain exactly one URI SAN, and it MUST be a SPIFFE ID. + if len(uri_sans) == 0: + raise InvalidLeafCertificateError( + "Certificate does not contain a URI SAN (expected exactly one SPIFFE ID)" + ) + + if len(uri_sans) != 1: + raise InvalidLeafCertificateError( + "Certificate contains multiple URI SAN entries (expected exactly one SPIFFE ID)" + ) + + uri = uri_sans[0] + if not uri.startswith(spiffe_id.SCHEME_PREFIX): + raise InvalidLeafCertificateError("Certificate URI SAN is not a SPIFFE ID") + + try: + return SpiffeId(uri) + except ArgumentError as e: + raise InvalidLeafCertificateError( + f"Certificate contains a malformed SPIFFE ID in the URI SAN: {uri!r}" + ) from e def _validate_chain(cert_chain: List[Certificate]) -> None: @@ -274,13 +307,23 @@ def _validate_chain(cert_chain: List[Certificate]) -> None: def _validate_leaf_certificate(leaf: Certificate) -> None: - basic_constraints = leaf.extensions.get_extension_for_oid( - x509.ExtensionOID.BASIC_CONSTRAINTS - ).value + try: + basic_constraints = leaf.extensions.get_extension_for_oid( + x509.ExtensionOID.BASIC_CONSTRAINTS + ).value + except x509.ExtensionNotFound: + raise InvalidLeafCertificateError( + 'Leaf certificate must have BasicConstraints extension' + ) + if isinstance(basic_constraints, x509.BasicConstraints) and basic_constraints.ca: raise InvalidLeafCertificateError('Leaf certificate must not have CA flag set to true') - key_usage = leaf.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value + try: + key_usage = leaf.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value + except x509.ExtensionNotFound: + raise InvalidLeafCertificateError('Leaf certificate must have KeyUsage extension') + if isinstance(key_usage, x509.KeyUsage) and not key_usage.digital_signature: raise InvalidLeafCertificateError( 'Leaf certificate must have \'digitalSignature\' as key usage' @@ -296,14 +339,27 @@ def _validate_leaf_certificate(leaf: Certificate) -> None: def _validate_intermediate_certificate(cert: Certificate) -> None: - basic_constraints = cert.extensions.get_extension_for_oid( - x509.ExtensionOID.BASIC_CONSTRAINTS - ).value + try: + basic_constraints = cert.extensions.get_extension_for_oid( + x509.ExtensionOID.BASIC_CONSTRAINTS + ).value + except x509.ExtensionNotFound: + raise InvalidIntermediateCertificateError( + 'Intermediate certificate must have BasicConstraints extension' + ) + if isinstance(basic_constraints, x509.BasicConstraints) and not basic_constraints.ca: raise InvalidIntermediateCertificateError( 'Signing certificate must have CA flag set to true' ) - key_usage = cert.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value + + try: + key_usage = cert.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value + except x509.ExtensionNotFound: + raise InvalidIntermediateCertificateError( + 'Intermediate certificate must have KeyUsage extension' + ) + if isinstance(key_usage, x509.KeyUsage) and not key_usage.key_cert_sign: raise InvalidIntermediateCertificateError( 'Signing certificate must have \'keyCertSign\' as key usage' diff --git a/spiffe/src/spiffe/utils/certificate_utils.py b/spiffe/src/spiffe/utils/certificate_utils.py index e59bb2da..3a80ed2c 100644 --- a/spiffe/src/spiffe/utils/certificate_utils.py +++ b/spiffe/src/spiffe/utils/certificate_utils.py @@ -159,7 +159,7 @@ def write_certificates_to_file( cert_bytes = serialize_certificate(cert, encoding) certs_file.write(cert_bytes) except Exception as err: - raise StoreCertificateError(format(str(err))) from err + raise StoreCertificateError(str(err)) from err def serialize_certificate(certificate: Certificate, encoding: serialization.Encoding) -> bytes: diff --git a/spiffe/src/spiffe/workloadapi/handle_error.py b/spiffe/src/spiffe/workloadapi/handle_error.py index 07ee090f..f8549147 100644 --- a/spiffe/src/spiffe/workloadapi/handle_error.py +++ b/spiffe/src/spiffe/workloadapi/handle_error.py @@ -36,13 +36,13 @@ def wrapper(*args, **kw): except ArgumentError as ae: raise ae except PySpiffeError as pe: - raise error_cls(str(pe)) + raise error_cls(str(pe)) from pe except grpc.RpcError as rpc_error: if isinstance(rpc_error, grpc.Call): - raise error_cls(str(rpc_error.details())) - raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE) + raise error_cls(str(rpc_error.details())) from rpc_error + raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE) from rpc_error except Exception as e: - raise error_cls(str(e)) + raise error_cls(str(e)) from e return wrapper diff --git a/spiffe/tests/unit/svid/jwtsvid/test_jwt_svid.py b/spiffe/tests/unit/svid/jwtsvid/test_jwt_svid.py index 2d103ef7..18c21a42 100644 --- a/spiffe/tests/unit/svid/jwtsvid/test_jwt_svid.py +++ b/spiffe/tests/unit/svid/jwtsvid/test_jwt_svid.py @@ -269,7 +269,8 @@ def test_parse_and_validate_invalid_missing_sub(): with pytest.raises(InvalidTokenError) as exception: JwtSvid.parse_and_validate(token, JWT_BUNDLE, {'test'}) - assert str(exception.value) == 'Invalid SPIFFE ID: cannot be empty' + + assert "non-empty 'sub' claim" in str(exception.value) def test_parse_and_validate_invalid_missing_kid(): diff --git a/spiffe/tests/unit/svid/x509svid/test_x509_svid.py b/spiffe/tests/unit/svid/x509svid/test_x509_svid.py index aeaadada..218edc20 100644 --- a/spiffe/tests/unit/svid/x509svid/test_x509_svid.py +++ b/spiffe/tests/unit/svid/x509svid/test_x509_svid.py @@ -14,10 +14,12 @@ under the License. """ +import datetime import os import pytest -from cryptography.hazmat.primitives import serialization +from cryptography import x509 +from cryptography.hazmat.primitives import serialization, hashes from cryptography.x509 import Certificate from spiffe.spiffe_id.spiffe_id import SpiffeId @@ -34,7 +36,12 @@ ParseCertificateError, ParsePrivateKeyError, ) -from spiffe.svid.x509_svid import X509Svid, _extract_spiffe_id +from spiffe.svid.x509_svid import ( + X509Svid, + _extract_spiffe_id, + _validate_leaf_certificate, + _validate_intermediate_certificate, +) from cryptography.hazmat.primitives.asymmetric import ec, rsa from testutils.certs import TEST_CERTS_DIR @@ -72,14 +79,14 @@ def test_create_x509_svid_no_spiffe_id(mocker): with pytest.raises(ArgumentError) as exc_info: X509Svid(spiffe_id=None, cert_chain=[mocker.Mock()], private_key=mocker.Mock()) - assert str(exc_info.value) == 'spiffe_id cannot be None' + assert str(exc_info.value) == "spiffe_id cannot be None" def test_create_x509_svid_no_cert_chain(mocker): with pytest.raises(ArgumentError) as exc_info: X509Svid(spiffe_id=mocker.Mock(), cert_chain=[], private_key=mocker.Mock()) - assert str(exc_info.value) == 'cert_chain cannot be empty' + assert str(exc_info.value) == "cert_chain cannot be empty" def test_create_x509_svid_no_private_key(mocker): @@ -190,7 +197,7 @@ def test_parse_raw_corrupted_certificate(): assert ( str(exception.value) - == 'Error parsing certificate: Unable to parse DER X.509 certificate' + == "Error parsing certificate: Unable to parse DER X.509 certificate" ) @@ -239,7 +246,7 @@ def test_parse_invalid_spiffe_id(): assert ( str(exception.value) - == 'Invalid leaf certificate: Certificate does not contain a SPIFFE ID in the URI SAN' + == 'Invalid leaf certificate: Certificate does not contain a URI SAN (expected exactly one SPIFFE ID)' ) @@ -336,6 +343,104 @@ def test_load_from_pem_files(): assert _extract_spiffe_id(x509_svid.leaf) == expected_spiffe_id +def test_extract_spiffe_id_missing_san_extension(mocker): + """Regression test: Missing SubjectAlternativeName extension should raise InvalidLeafCertificateError.""" + mock_cert = mocker.Mock() + mock_extensions = mocker.Mock() + mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "SubjectAlternativeName extension not found", + x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + ) + mock_cert.extensions = mock_extensions + + with pytest.raises(InvalidLeafCertificateError) as exception: + _extract_spiffe_id(mock_cert) + + assert 'SubjectAlternativeName extension' in str(exception.value) + + +def test_validate_leaf_missing_basic_constraints_extension(mocker): + """Regression test: Missing BasicConstraints extension in leaf should raise InvalidLeafCertificateError.""" + mock_cert = mocker.Mock() + mock_extensions = mocker.Mock() + mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "BasicConstraints extension not found", x509.ExtensionOID.BASIC_CONSTRAINTS + ) + mock_cert.extensions = mock_extensions + + with pytest.raises(InvalidLeafCertificateError) as exception: + _validate_leaf_certificate(mock_cert) + + assert 'BasicConstraints extension' in str(exception.value) + + +def test_validate_leaf_missing_key_usage_extension(mocker): + """Regression test: Missing KeyUsage extension in leaf should raise InvalidLeafCertificateError.""" + mock_cert = mocker.Mock() + mock_extensions = mocker.Mock() + + # First call (BasicConstraints) succeeds, second call (KeyUsage) fails + basic_constraints = mocker.Mock() + basic_constraints.value = mocker.Mock() + basic_constraints.value.ca = False + + def get_extension_side_effect(oid): + if oid == x509.ExtensionOID.BASIC_CONSTRAINTS: + return basic_constraints + if oid == x509.ExtensionOID.KEY_USAGE: + raise x509.ExtensionNotFound("KeyUsage extension not found", oid) + raise AssertionError(f"Unexpected oid: {oid}") + + mock_extensions.get_extension_for_oid.side_effect = get_extension_side_effect + mock_cert.extensions = mock_extensions + + with pytest.raises(InvalidLeafCertificateError) as exception: + _validate_leaf_certificate(mock_cert) + + assert 'KeyUsage extension' in str(exception.value) + + +def test_validate_intermediate_missing_basic_constraints_extension(mocker): + """Regression test: Missing BasicConstraints extension in intermediate should raise InvalidIntermediateCertificateError.""" + mock_cert = mocker.Mock() + mock_extensions = mocker.Mock() + mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "BasicConstraints extension not found", x509.ExtensionOID.BASIC_CONSTRAINTS + ) + mock_cert.extensions = mock_extensions + + with pytest.raises(InvalidIntermediateCertificateError) as exception: + _validate_intermediate_certificate(mock_cert) + + assert 'BasicConstraints extension' in str(exception.value) + + +def test_validate_intermediate_missing_key_usage_extension(mocker): + """Regression test: Missing KeyUsage extension in intermediate should raise InvalidIntermediateCertificateError.""" + mock_cert = mocker.Mock() + mock_extensions = mocker.Mock() + + # First call (BasicConstraints) succeeds, second call (KeyUsage) fails + basic_constraints = mocker.Mock() + basic_constraints.value = mocker.Mock() + basic_constraints.value.ca = True + + def get_extension_side_effect(oid): + if oid == x509.ExtensionOID.BASIC_CONSTRAINTS: + return basic_constraints + if oid == x509.ExtensionOID.KEY_USAGE: + raise x509.ExtensionNotFound("KeyUsage extension not found", oid) + raise AssertionError(f"Unexpected oid: {oid}") + + mock_extensions.get_extension_for_oid.side_effect = get_extension_side_effect + mock_cert.extensions = mock_extensions + + with pytest.raises(InvalidIntermediateCertificateError) as exception: + _validate_intermediate_certificate(mock_cert) + + assert 'KeyUsage extension' in str(exception.value) + + def test_load_from_der_files(): chain_path = TEST_CERTS_DIR / '1-chain.der' key_path = TEST_CERTS_DIR / '1-key.der' @@ -534,7 +639,129 @@ def test_get_chain_returns_a_copy(): assert x509_svid.cert_chain is not x509_svid._cert_chain -def read_bytes(filename): +def test_extract_spiffe_id_rejects_multiple_uri_sans(): + """ + SPIFFE X.509-SVID profile: MUST contain exactly one URI SAN total. + Reject when there are multiple URI SANs even if exactly one is SPIFFE. + """ + cert, _key = _make_cert( + uri_sans=["spiffe://example.org/service", "https://example.org/"], + dns_sans=[], + ) + with pytest.raises(InvalidLeafCertificateError) as exc: + _extract_spiffe_id(cert) + + assert ( + 'Invalid leaf certificate: Certificate contains multiple URI SAN entries (expected exactly one SPIFFE ID)' + in str(exc.value) + ) + + +def test_extract_spiffe_id_rejects_single_uri_san_non_spiffe(): + """ + Exactly one URI SAN is present, but it's not a SPIFFE ID. + """ + cert, _key = _make_cert( + uri_sans=["https://example.org/"], + dns_sans=[], + ) + with pytest.raises(InvalidLeafCertificateError) as exc: + _extract_spiffe_id(cert) + + assert "SPIFFE ID" in str(exc.value) + + +def test_extract_spiffe_id_allows_dns_sans_with_single_spiffe_uri_san(): + """ + DNS SANs are not URI SANs; allow them in addition to the single SPIFFE URI SAN. + """ + cert, _key = _make_cert( + uri_sans=["spiffe://example.org/service"], + dns_sans=["example.org", "workload.example.org"], + ) + + assert _extract_spiffe_id(cert) == SpiffeId("spiffe://example.org/service") + + +def test_parse_rejects_multiple_uri_sans_even_if_one_is_spiffe(): + """ + End-to-end parse path should enforce the same URI SAN cardinality rule. + """ + cert, key = _make_cert( + uri_sans=["spiffe://example.org/service", "https://example.org/"], + dns_sans=[], + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + with pytest.raises(InvalidLeafCertificateError) as exc: + X509Svid.parse(cert_pem, key_pem) + + assert "URI SAN" in str(exc.value) + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +def _make_cert(*, uri_sans: list[str], dns_sans: list[str]): + """ + Generates a self-signed leaf certificate with SAN entries. This is only for tests. + """ + key = ec.generate_private_key(ec.SECP256R1()) + + subject = issuer = x509.Name( + [ + x509.NameAttribute(x509.oid.NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(x509.oid.NameOID.ORGANIZATION_NAME, "test"), + x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, "leaf"), + ] + ) + + san_entries: list[x509.GeneralName] = [] + for u in uri_sans: + san_entries.append(x509.UniformResourceIdentifier(u)) + for d in dns_sans: + san_entries.append(x509.DNSName(d)) + + now = datetime.datetime.now(datetime.timezone.utc) + + builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - datetime.timedelta(minutes=1)) + .not_valid_after(now + datetime.timedelta(hours=1)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=True, # common for ECDSA leafs + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .add_extension(x509.SubjectAlternativeName(san_entries), critical=False) + ) + + cert = builder.sign(private_key=key, algorithm=hashes.SHA256()) + return cert, key + + +def read_bytes(filename: str) -> bytes: path = TEST_CERTS_DIR / filename - with open(path, 'rb') as file: + with open(path, "rb") as file: return file.read()