Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions spiffe/src/spiffe/svid/jwt_svid.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,13 @@ def parse_insecure(cls, token: str, audience: Set[str]) -> 'JwtSvid':
validator.validate_header(header_params)
claims = jwt.decode(token, options={'verify_signature': False})
validator.validate_claims(claims, audience)
spiffe_id = SpiffeId(claims['sub'])
sub_claim = claims.get('sub')
if not sub_claim:
raise InvalidTokenError('JWT token must contain a non-empty \'sub\' claim')
spiffe_id = SpiffeId(sub_claim)
return JwtSvid(spiffe_id, claims['aud'], claims['exp'], claims, token)
except PyJWTError as err:
raise InvalidTokenError(str(err))
raise InvalidTokenError(str(err)) from err

@classmethod
def parse_and_validate(
Expand Down Expand Up @@ -151,6 +154,9 @@ def parse_and_validate(
header_params = jwt.get_unverified_header(token)
validator = JwtSvidValidator()
validator.validate_header(header_params)
alg = header_params.get('alg')
if not alg:
raise ArgumentError('header alg cannot be empty')
key_id = header_params.get('kid')
signing_key = jwt_bundle.get_jwt_authority(key_id)
if not signing_key:
Expand All @@ -163,7 +169,7 @@ def parse_and_validate(

claims = jwt.decode(
token,
algorithms=header_params.get('alg'),
algorithms=[alg],
key=public_key_pem,
audience=audience,
options={
Expand All @@ -173,7 +179,10 @@ def parse_and_validate(
},
)

spiffe_id = SpiffeId(claims.get('sub', None))
sub_claim = claims.get('sub')
if not sub_claim:
raise InvalidTokenError('JWT token must contain a non-empty \'sub\' claim')
spiffe_id = SpiffeId(sub_claim)

return JwtSvid(spiffe_id, claims['aud'], claims['exp'], claims, token)
except PyJWTError as err:
Expand Down
14 changes: 11 additions & 3 deletions spiffe/src/spiffe/svid/jwt_svid_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ def validate_claims(self, payload: Dict[str, Any], expected_audience: Set[str])
if not payload.get(claim):
raise MissingClaimError(claim)

self._validate_exp(str(payload.get('exp')))
exp_value = payload.get('exp')
if exp_value is None:
raise MissingClaimError('exp')
try:
numeric_exp = float(exp_value)
except (TypeError, ValueError):
raise InvalidClaimError("exp claim must be a numeric value")
self._validate_exp(numeric_exp)

aud_claim = payload.get('aud')
if aud_claim is None:
aud_set = set()
Expand All @@ -121,13 +129,13 @@ def validate_claims(self, payload: Dict[str, Any], expected_audience: Set[str])
self._validate_aud(aud_set, expected_audience)

@staticmethod
def _validate_exp(expiration_date: str) -> None:
def _validate_exp(expiration_date: float) -> None:
"""Verifies expiration.

Note: If and when https://github.com/jpadilla/pyjwt/issues/599 is fixed, this can be simplified/removed.

Args:
expiration_date: Date to check if it is expired.
expiration_date: Date to check if it is expired (numeric timestamp).

Raises:
TokenExpiredError: In case it is expired.
Expand Down
69 changes: 56 additions & 13 deletions spiffe/src/spiffe/svid/x509_svid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
under the License.
"""

from spiffe.spiffe_id import spiffe_id

"""
This module manages X.509 SVID objects.
"""
Expand Down Expand Up @@ -255,14 +257,32 @@ def load(


def _extract_spiffe_id(cert: Certificate) -> SpiffeId:
ext = cert.extensions.get_extension_for_oid(x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
if isinstance(ext.value, x509.SubjectAlternativeName):
sans = ext.value.get_values_for_type(x509.UniformResourceIdentifier)
if len(sans) == 0:
try:
ext = cert.extensions.get_extension_for_oid(x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
except x509.ExtensionNotFound:
raise InvalidLeafCertificateError(
'Certificate does not contain a SubjectAlternativeName extension'
)

if not isinstance(ext.value, x509.SubjectAlternativeName):
raise InvalidLeafCertificateError(
'Certificate does not contain a SPIFFE ID in the URI SAN'
)

uri_sans = ext.value.get_values_for_type(x509.UniformResourceIdentifier)
spiffe_uris = [uri for uri in uri_sans if uri.startswith(spiffe_id.SCHEME_PREFIX)]

if len(spiffe_uris) == 0:
raise InvalidLeafCertificateError(
'Certificate does not contain a SPIFFE ID in the URI SAN'
)
return SpiffeId(sans[0])

if len(spiffe_uris) > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem like we have test coverage for this new validation. I think that it would be great to have it covered by tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test coverage.

raise InvalidLeafCertificateError(
'Certificate contains multiple SPIFFE IDs in the URI SAN'
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this validation logic checks only SPIFFE URIs, not all URI SANs. This is a problem because the spec says that "An X.509 SVID MUST contain exactly one URI SAN, and by extension, exactly one SPIFFE ID.".
The way I see this code, a certificate with spiffe://example.org/service + https://example.org would still be accepted by py-spiffe.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. Fixed


return SpiffeId(spiffe_uris[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could raise an exception if the SPIFFE URI is malformed. Maybe it could be wrapped in a try-catch to provide clearer error context?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done



def _validate_chain(cert_chain: List[Certificate]) -> None:
Expand All @@ -274,13 +294,23 @@ def _validate_chain(cert_chain: List[Certificate]) -> None:


def _validate_leaf_certificate(leaf: Certificate) -> None:
basic_constraints = leaf.extensions.get_extension_for_oid(
x509.ExtensionOID.BASIC_CONSTRAINTS
).value
try:
basic_constraints = leaf.extensions.get_extension_for_oid(
x509.ExtensionOID.BASIC_CONSTRAINTS
).value
except x509.ExtensionNotFound:
raise InvalidLeafCertificateError(
'Leaf certificate must have BasicConstraints extension'
)

if isinstance(basic_constraints, x509.BasicConstraints) and basic_constraints.ca:
raise InvalidLeafCertificateError('Leaf certificate must not have CA flag set to true')

key_usage = leaf.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value
try:
key_usage = leaf.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value
except x509.ExtensionNotFound:
raise InvalidLeafCertificateError('Leaf certificate must have KeyUsage extension')

if isinstance(key_usage, x509.KeyUsage) and not key_usage.digital_signature:
raise InvalidLeafCertificateError(
'Leaf certificate must have \'digitalSignature\' as key usage'
Expand All @@ -296,14 +326,27 @@ def _validate_leaf_certificate(leaf: Certificate) -> None:


def _validate_intermediate_certificate(cert: Certificate) -> None:
basic_constraints = cert.extensions.get_extension_for_oid(
x509.ExtensionOID.BASIC_CONSTRAINTS
).value
try:
basic_constraints = cert.extensions.get_extension_for_oid(
x509.ExtensionOID.BASIC_CONSTRAINTS
).value
except x509.ExtensionNotFound:
raise InvalidIntermediateCertificateError(
'Intermediate certificate must have BasicConstraints extension'
)

if isinstance(basic_constraints, x509.BasicConstraints) and not basic_constraints.ca:
raise InvalidIntermediateCertificateError(
'Signing certificate must have CA flag set to true'
)
key_usage = cert.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value

try:
key_usage = cert.extensions.get_extension_for_oid(x509.ExtensionOID.KEY_USAGE).value
except x509.ExtensionNotFound:
raise InvalidIntermediateCertificateError(
'Intermediate certificate must have KeyUsage extension'
)

if isinstance(key_usage, x509.KeyUsage) and not key_usage.key_cert_sign:
raise InvalidIntermediateCertificateError(
'Signing certificate must have \'keyCertSign\' as key usage'
Expand Down
2 changes: 1 addition & 1 deletion spiffe/src/spiffe/utils/certificate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def write_certificates_to_file(
cert_bytes = serialize_certificate(cert, encoding)
certs_file.write(cert_bytes)
except Exception as err:
raise StoreCertificateError(format(str(err))) from err
raise StoreCertificateError(str(err)) from err


def serialize_certificate(certificate: Certificate, encoding: serialization.Encoding) -> bytes:
Expand Down
8 changes: 4 additions & 4 deletions spiffe/src/spiffe/workloadapi/handle_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def wrapper(*args, **kw):
except ArgumentError as ae:
raise ae
except PySpiffeError as pe:
raise error_cls(str(pe))
raise error_cls(str(pe)) from pe
except grpc.RpcError as rpc_error:
if isinstance(rpc_error, grpc.Call):
raise error_cls(str(rpc_error.details()))
raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE)
raise error_cls(str(rpc_error.details())) from rpc_error
raise error_cls(DEFAULT_WL_API_ERROR_MESSAGE) from rpc_error
except Exception as e:
raise error_cls(str(e))
raise error_cls(str(e)) from e

return wrapper

Expand Down
3 changes: 2 additions & 1 deletion spiffe/tests/unit/svid/jwtsvid/test_jwt_svid.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ def test_parse_and_validate_invalid_missing_sub():

with pytest.raises(InvalidTokenError) as exception:
JwtSvid.parse_and_validate(token, JWT_BUNDLE, {'test'})
assert str(exception.value) == 'Invalid SPIFFE ID: cannot be empty'

assert "non-empty 'sub' claim" in str(exception.value)


def test_parse_and_validate_invalid_missing_kid():
Expand Down
111 changes: 111 additions & 0 deletions spiffe/tests/unit/svid/x509svid/test_x509_svid.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,117 @@ def test_load_from_pem_files():
assert _extract_spiffe_id(x509_svid.leaf) == expected_spiffe_id


def test_extract_spiffe_id_missing_san_extension(mocker):
"""Regression test: Missing SubjectAlternativeName extension should raise InvalidLeafCertificateError."""
from cryptography import x509

mock_cert = mocker.Mock()
mock_extensions = mocker.Mock()
mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound(
"SubjectAlternativeName extension not found",
x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
)
mock_cert.extensions = mock_extensions

with pytest.raises(InvalidLeafCertificateError) as exception:
_extract_spiffe_id(mock_cert)

assert 'SubjectAlternativeName extension' in str(exception.value)


def test_validate_leaf_missing_basic_constraints_extension(mocker):
"""Regression test: Missing BasicConstraints extension in leaf should raise InvalidLeafCertificateError."""
from cryptography import x509

mock_cert = mocker.Mock()
mock_extensions = mocker.Mock()
mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound(
"BasicConstraints extension not found", x509.ExtensionOID.BASIC_CONSTRAINTS
)
mock_cert.extensions = mock_extensions

from spiffe.svid.x509_svid import _validate_leaf_certificate

with pytest.raises(InvalidLeafCertificateError) as exception:
_validate_leaf_certificate(mock_cert)

assert 'BasicConstraints extension' in str(exception.value)


def test_validate_leaf_missing_key_usage_extension(mocker):
"""Regression test: Missing KeyUsage extension in leaf should raise InvalidLeafCertificateError."""
from cryptography import x509
from spiffe.svid.x509_svid import _validate_leaf_certificate

mock_cert = mocker.Mock()
mock_extensions = mocker.Mock()

# First call (BasicConstraints) succeeds, second call (KeyUsage) fails
basic_constraints = mocker.Mock()
basic_constraints.value = mocker.Mock()
basic_constraints.value.ca = False

def get_extension_side_effect(oid):
if oid == x509.ExtensionOID.BASIC_CONSTRAINTS:
return basic_constraints
elif oid == x509.ExtensionOID.KEY_USAGE:
raise x509.ExtensionNotFound("KeyUsage extension not found", oid)

mock_extensions.get_extension_for_oid.side_effect = get_extension_side_effect
mock_cert.extensions = mock_extensions

with pytest.raises(InvalidLeafCertificateError) as exception:
_validate_leaf_certificate(mock_cert)

assert 'KeyUsage extension' in str(exception.value)


def test_validate_intermediate_missing_basic_constraints_extension(mocker):
"""Regression test: Missing BasicConstraints extension in intermediate should raise InvalidIntermediateCertificateError."""
from cryptography import x509
from spiffe.svid.x509_svid import _validate_intermediate_certificate

mock_cert = mocker.Mock()
mock_extensions = mocker.Mock()
mock_extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound(
"BasicConstraints extension not found", x509.ExtensionOID.BASIC_CONSTRAINTS
)
mock_cert.extensions = mock_extensions

with pytest.raises(InvalidIntermediateCertificateError) as exception:
_validate_intermediate_certificate(mock_cert)

assert 'BasicConstraints extension' in str(exception.value)


def test_validate_intermediate_missing_key_usage_extension(mocker):
"""Regression test: Missing KeyUsage extension in intermediate should raise InvalidIntermediateCertificateError."""
from cryptography import x509
from spiffe.svid.x509_svid import _validate_intermediate_certificate

mock_cert = mocker.Mock()
mock_extensions = mocker.Mock()

# First call (BasicConstraints) succeeds, second call (KeyUsage) fails
basic_constraints = mocker.Mock()
basic_constraints.value = mocker.Mock()
basic_constraints.value.ca = True

def get_extension_side_effect(oid):
if oid == x509.ExtensionOID.BASIC_CONSTRAINTS:
return basic_constraints
elif oid == x509.ExtensionOID.KEY_USAGE:
raise x509.ExtensionNotFound("KeyUsage extension not found", oid)

mock_extensions.get_extension_for_oid.side_effect = get_extension_side_effect
mock_cert.extensions = mock_extensions

with pytest.raises(InvalidIntermediateCertificateError) as exception:
_validate_intermediate_certificate(mock_cert)

assert 'KeyUsage extension' in str(exception.value)


def test_load_from_der_files():
chain_path = TEST_CERTS_DIR / '1-chain.der'
key_path = TEST_CERTS_DIR / '1-key.der'
Expand Down