Skip to content

Commit

Permalink
[AAP-38778] chore: improve jwt expired exception logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Dostonbek1 committed Feb 18, 2025
1 parent 3fac5f3 commit b44b3a7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
42 changes: 32 additions & 10 deletions ansible_base/jwt_consumer/common/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from datetime import datetime
from typing import Optional, Tuple

import jwt
Expand All @@ -13,6 +14,7 @@

from ansible_base.jwt_consumer.common.cache import JWTCache
from ansible_base.jwt_consumer.common.cert import JWTCert, JWTCertException
from ansible_base.jwt_consumer.common.exceptions import InvalidTokenException
from ansible_base.lib.logging.runtime import log_excess_runtime
from ansible_base.lib.utils.auth import get_user_by_ansible_id
from ansible_base.lib.utils.translations import translatableConditionally as _
Expand Down Expand Up @@ -65,6 +67,7 @@ def parse_jwt_token(self, request):
return

token_from_header = request.headers.get("X-DAB-JW-TOKEN", None)
request_id = request.headers.get("X-Request-Id")
if not token_from_header:
logger.debug("X-DAB-JW-TOKEN header not set for JWT authentication")
return
Expand All @@ -81,7 +84,7 @@ def parse_jwt_token(self, request):
return None, None

try:
self.token = self.validate_token(token_from_header, cert_object.key)
self.token = self.validate_token(token_from_header, cert_object.key, request_id)
except jwt.exceptions.DecodeError as de:
# This exception means the decryption key failed... maybe it was because the cache is bad.
if not cert_object.cached:
Expand All @@ -99,7 +102,7 @@ def parse_jwt_token(self, request):
self.log_and_raise(_("JWT decoding failed: %(e)s, cached key was correct; check your key and generated token"), {"e": de})
# Since we got a new key, lets go ahead and try to validate the token again.
# If it fails this time we can just raise whatever
self.token = self.validate_token(token_from_header, cert_object.key)
self.token = self.validate_token(token_from_header, cert_object.key, request_id)

# Let's see if we have the same user info in the cache already
is_cached, user_defaults = self.cache.check_user_in_cache(self.token)
Expand Down Expand Up @@ -139,9 +142,13 @@ def parse_jwt_token(self, request):

logger.info(f"User {self.user.username} authenticated from JWT auth")

def log_and_raise(self, conditional_translate_object, expand_values={}):
def log_and_raise(self, conditional_translate_object, expand_values={}, error_code=None):
logger.error(conditional_translate_object.not_translated() % expand_values)
raise AuthenticationFailed(conditional_translate_object.translated() % expand_values)
translated_error_message = conditional_translate_object.translated() % expand_values
if error_code == 498:
raise InvalidTokenException(translated_error_message)
else:
raise AuthenticationFailed(translated_error_message)

def map_user_fields(self):
if self.token is None or self.user is None:
Expand All @@ -162,26 +169,31 @@ def map_user_fields(self):
logger.info(f"Saving user {self.user.username}")
self.user.save()

def validate_token(self, unencrypted_token, decryption_key):
def validate_token(self, unencrypted_token, decryption_key, request_id=None):
validated_body = None

local_required_field = ["sub", "user_data", "exp", "objects", "object_roles", "global_roles", "version"]

# Decrypt the token
try:
logger.info("Decrypting token")
validated_body = jwt.decode(
validated_body = self.decode_jwt_token(
unencrypted_token,
decryption_key,
audience="ansible-services",
options={"require": local_required_field},
issuer="ansible-issuer",
algorithms=["RS256"],
)
except jwt.exceptions.DecodeError as e:
raise e # This will be handled higher up
except jwt.exceptions.ExpiredSignatureError:
self.log_and_raise(_("JWT has expired"))
expired_token = self.decode_jwt_token(
unencrypted_token,
decryption_key,
options={"require": local_required_field, "verify_exp": False},
)
expired_time = expired_token.get("exp")
now = datetime.now().timestamp()
time_diff = int(now - expired_time)
self.log_and_raise(_(f"JWT expired {time_diff} seconds ago - check for clock skew. Request ID: {request_id}"), error_code=498)
except jwt.exceptions.InvalidAudienceError:
self.log_and_raise(_("JWT did not come for the correct audience"))
except jwt.exceptions.InvalidIssuerError:
Expand All @@ -205,6 +217,16 @@ def validate_token(self, unencrypted_token, decryption_key):

return validated_body

def decode_jwt_token(self, unencrypted_token, decryption_key, options):
return jwt.decode(
unencrypted_token,
decryption_key,
audience="ansible-services",
options=options,
issuer="ansible-issuer",
algorithms=["RS256"],
)

def get_role_definition(self, name: str) -> Optional[Model]:
"""Simply get the RoleDefinition from the database if it exists and handler corner cases
Expand Down
10 changes: 10 additions & 0 deletions ansible_base/jwt_consumer/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from rest_framework.exceptions import APIException


class InvalidService(Exception):
def __init__(self, service):
super().__init__(f"This authentication class requires {service}.")


class InvalidTokenException(APIException):
status_code = 498
status_text = "Invalid Token"
default_detail = "Invalid or expired token."
default_code = "invalid_token"
7 changes: 6 additions & 1 deletion test_app/tests/jwt_consumer/common/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ansible_base.jwt_consumer.common.auth import JWTAuthentication, JWTCommonAuth, default_mapped_user_fields
from ansible_base.jwt_consumer.common.cert import JWTCert, JWTCertException
from ansible_base.jwt_consumer.common.exceptions import InvalidTokenException
from ansible_base.lib.utils.translations import translatableConditionally as _
from ansible_base.rbac.models import RoleDefinition, RoleUserAssignment
from ansible_base.rbac.permission_registry import permission_registry
Expand Down Expand Up @@ -179,9 +180,13 @@ def test_validate_token_expired_token(self, jwt_token, test_encryption_public_ke
jwt_token.unencrypted_token['exp'] = datetime.now() + timedelta(minutes=-10)
# Test the function
common_auth = JWTCommonAuth()
with pytest.raises(AuthenticationFailed, match="JWT has expired"):
with pytest.raises(InvalidTokenException) as excinfo:
common_auth.validate_token(jwt_token.encrypt_token(), test_encryption_public_key)

assert "JWT expired" in str(excinfo.value)
assert "check for clock skew" in str(excinfo.value)
assert "Request ID" in str(excinfo.value)

@pytest.mark.django_db
@pytest.mark.parametrize(
"item,exception",
Expand Down

0 comments on commit b44b3a7

Please sign in to comment.