From 52ee5917dd824ba19230f4acec25c118eed5535d Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 14:35:57 -0500 Subject: [PATCH 01/11] Remove all of the warnings from within the codebase --- django_ariadne_jwt/backends.py | 1 + django_ariadne_jwt/utils.py | 15 +++++++++++++-- tests/test_middleware.py | 4 ++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 67d0628..96f5d96 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -1,5 +1,6 @@ """GraphQL auth backends module""" from django.contrib.auth import get_user_model +from django.conf import settings from .exceptions import JSONWebTokenError from .utils import decode_jwt diff --git a/django_ariadne_jwt/utils.py b/django_ariadne_jwt/utils.py index f1f9aa6..d1eac79 100644 --- a/django_ariadne_jwt/utils.py +++ b/django_ariadne_jwt/utils.py @@ -16,6 +16,7 @@ ORIGINAL_IAT_CLAIM = "orig_iat" HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" AUTHORIZATION_HEADER_PREFIX = "Token" +DEFAULT_JWT_ALGORITHM = "HS256" def get_token_from_http_header(request): @@ -76,7 +77,11 @@ def create_jwt(user, extra_payload={}): "exp": int((now + expiration_delta).timestamp()), } - return jwt.encode(payload, settings.SECRET_KEY).decode("utf-8") + return jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=getattr(settings, "JWT_ALGORITHM", DEFAULT_JWT_ALGORITHM), + ).decode("utf-8") def refresh_jwt(token): @@ -104,7 +109,13 @@ def refresh_jwt(token): def decode_jwt(token): """Decodes a JWT""" try: - decoded = jwt.decode(token, settings.SECRET_KEY) + decoded = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=getattr( + settings, "JWT_ALGORITHMS", DEFAULT_JWT_ALGORITHM + ), + ) except ExpiredSignatureError: raise ExpiredTokenError() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 1b0244d..13c7b1f 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -51,7 +51,7 @@ def test_without_user_and_with_valid_token(self): def next(root, info, **kwargs): self.assertTrue(hasattr(info.context, "user")) - self.assertEquals(info.context.user, self.user) + self.assertEqual(info.context.user, self.user) next = Mock(wraps=next) @@ -80,7 +80,7 @@ def test_with_user_and_valid_token(self): def next(root, info, **kwargs): self.assertTrue(hasattr(info.context, "user")) - self.assertEquals(info.context.user, self.user) + self.assertEqual(info.context.user, self.user) settings = { "AUTHENTICATION_BACKENDS": ( From 0e4b7c6057f199022cff79476a1bdf2ba9f54b6f Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 14:43:41 -0500 Subject: [PATCH 02/11] Reduce this into a simpler case --- django_ariadne_jwt/utils.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/django_ariadne_jwt/utils.py b/django_ariadne_jwt/utils.py index d1eac79..bba11cd 100644 --- a/django_ariadne_jwt/utils.py +++ b/django_ariadne_jwt/utils.py @@ -21,22 +21,13 @@ def get_token_from_http_header(request): """Retrieves the http authorization header from the request""" - token = None + header = request.META.get(HTTP_AUTHORIZATION_HEADER, False) + if header is False: + return None - try: - header = request.META.get(HTTP_AUTHORIZATION_HEADER, "") - - except AttributeError: - header = "" - - try: - prefix, payload = header.split() - - except ValueError: - prefix = "-" - - if prefix.lower() == AUTHORIZATION_HEADER_PREFIX.lower(): - token = payload + prefix, token = header.split() + if prefix.lower() != AUTHORIZATION_HEADER_PREFIX.lower(): + return None return token From 7a5572bc5099c1253845141fcab84ee8480246bd Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 14:58:53 -0500 Subject: [PATCH 03/11] Add a few flex points to the JSONWebTokenBackend This moves all loading of users to `get_user()` and expands it to taking kwargs. Developers can override the values that used for finding the user by overriding `get_user_kwargs`. Note: if a `user_id` is provided to `get_user`, it is translated to `pk` and will overwrite any `pk` provided. --- django_ariadne_jwt/backends.py | 36 ++++++++++++++-------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 96f5d96..8c66303 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -10,35 +10,29 @@ class JSONWebTokenBackend(object): def authenticate(self, request, token=None, **kwargs): """Performs authentication""" - user = None + if token is None: + return - if token is not None: - token_data = None - - try: - token_data = decode_jwt(token) - - except JSONWebTokenError: - pass - - if token_data is not None: - User = get_user_model() - credentials = {User.USERNAME_FIELD: token_data["user"]} - - try: - user = User.objects.get(**credentials) + try: + token_data = decode_jwt(token) - except User.DoesNotExist: - pass + except JSONWebTokenError: + return - return user + return self.get_user(**self.get_user_kwargs(token_data)) - def get_user(self, user_id): + def get_user(self, user_id=None, **kwargs): """Gets a user from its id""" User = get_user_model() + if user_id is not None: + kwargs["pk"] = user_id try: - return User.objects.get(pk=user_id) + return User.objects.get(**kwargs) except User.DoesNotExist: return None + + def get_user_kwargs(self, token_data): + User = get_user_model() + return {User.USERNAME_FIELD: token_data["user"]} From 62b93b62b9bfa59d7fffbe92e618dc08b41a086c Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 15:43:49 -0500 Subject: [PATCH 04/11] Rework token utilities to be configurable as part of the backend This allows developers to control what's put into the token payload without having to write their own resolvers. --- django_ariadne_jwt/backends.py | 133 ++++++++++++++++++++++++++++++- django_ariadne_jwt/middleware.py | 4 +- django_ariadne_jwt/resolvers.py | 14 +--- django_ariadne_jwt/utils.py | 117 --------------------------- tests/test_backends.py | 6 +- tests/test_decorators.py | 4 +- tests/test_middleware.py | 8 +- tests/test_resolvers.py | 13 +-- tests/test_utils.py | 60 +++++++++----- 9 files changed, 191 insertions(+), 168 deletions(-) delete mode 100644 django_ariadne_jwt/utils.py diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 8c66303..1ee5f46 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -1,20 +1,60 @@ """GraphQL auth backends module""" +import datetime from django.contrib.auth import get_user_model from django.conf import settings -from .exceptions import JSONWebTokenError -from .utils import decode_jwt +from django.utils import timezone +from django.utils.module_loading import import_string +from django.utils.translation import ugettext_lazy as _ +import jwt +from jwt.exceptions import DecodeError, ExpiredSignatureError + +from .exceptions import ( + AuthenticatedUserRequiredError, + ExpiredTokenError, + InvalidTokenError, + JSONWebTokenError, + MaximumTokenLifeReachedError, +) + + +def load_backend(): + return import_string( + getattr( + settings, + "JWT_BACKEND", + "django_ariadne_jwt.backends.JSONWebTokenBackend", + ) + )() class JSONWebTokenBackend(object): """Authenticates against a JSON Web Token""" + DEFAULT_JWT_ALGORITHM = "HS256" + ORIGINAL_IAT_CLAIM = "orig_iat" + HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" + AUTHORIZATION_HEADER_PREFIX = "Token" + DEFAULT_JWT_ALGORITHM = "HS256" + + def get_token_from_http_header(self, request): + """Retrieves the http authorization header from the request""" + header = request.META.get(self.HTTP_AUTHORIZATION_HEADER, False) + if header is False: + return None + + prefix, token = header.split() + if prefix.lower() != self.AUTHORIZATION_HEADER_PREFIX.lower(): + return None + + return token + def authenticate(self, request, token=None, **kwargs): """Performs authentication""" if token is None: return try: - token_data = decode_jwt(token) + token_data = self.decode(token) except JSONWebTokenError: return @@ -36,3 +76,90 @@ def get_user(self, user_id=None, **kwargs): def get_user_kwargs(self, token_data): User = get_user_model() return {User.USERNAME_FIELD: token_data["user"]} + + def create(self, user, extra_payload={}): + """Creates a JWT for an authenticated user""" + if not user.is_authenticated: + raise AuthenticatedUserRequiredError( + "JWT generationr requires an authenticated user" + ) + + expiration_delta = getattr( + settings, "JWT_EXPIRATION_DELTA", datetime.timedelta(minutes=5) + ) + + now = timezone.localtime() + + payload = { + **extra_payload, + "user": user.username, + "iat": int(now.timestamp()), + "exp": int((now + expiration_delta).timestamp()), + } + + return jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=getattr( + settings, "JWT_ALGORITHM", self.DEFAULT_JWT_ALGORITHM + ), + ).decode("utf-8") + + def refresh(self, token): + """Refreshes a JWT if possible""" + decoded = self.decode(token) + + oldest_iat_claim = decoded.get( + self.ORIGINAL_IAT_CLAIM, decoded.get("iat") + ) + + if self.has_reached_end_of_life(oldest_iat_claim): + raise MaximumTokenLifeReachedError() + + User = get_user_model() + + credentials = {User.USERNAME_FIELD: decoded["user"]} + + try: + user = User.objects.get(**credentials) + + except User.DoesNotExist: + raise InvalidTokenError(_("User not found")) + + return self.create(user, {self.ORIGINAL_IAT_CLAIM: decoded["iat"]}) + + def decode(self, token): + """Decodes a JWT""" + try: + decoded = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=getattr( + settings, "JWT_ALGORITHMS", self.DEFAULT_JWT_ALGORITHM + ), + ) + + except ExpiredSignatureError: + raise ExpiredTokenError() + + except DecodeError: + raise InvalidTokenError() + + return decoded + + def has_reached_end_of_life(self, oldest_iat_claim): + """Checks if the token has reached its end of life""" + expiration_delta = getattr( + settings, + "JWT_REFRESH_EXPIRATION_DELTA", + datetime.timedelta(days=7), + ) + + now = timezone.localtime() + original_issue_time = timezone.make_aware( + datetime.datetime.fromtimestamp(int(oldest_iat_claim)) + ) + + end_of_life = original_issue_time + expiration_delta + + return now > end_of_life diff --git a/django_ariadne_jwt/middleware.py b/django_ariadne_jwt/middleware.py index 2ad54e7..a38b560 100644 --- a/django_ariadne_jwt/middleware.py +++ b/django_ariadne_jwt/middleware.py @@ -1,7 +1,7 @@ """ariadne_django_jwt middleware module""" from django.contrib.auth import authenticate from django.contrib.auth.models import AnonymousUser -from .utils import get_token_from_http_header +from .backends import load_backend __all__ = ["JSONWebTokenMiddleware"] @@ -13,7 +13,7 @@ def resolve(self, next, root, info, **kwargs): """Performs the middleware relevant operations""" request = info.context - token = get_token_from_http_header(request) + token = load_backend().get_token_from_http_header(request) if token is not None: user = getattr(request, "user", None) diff --git a/django_ariadne_jwt/resolvers.py b/django_ariadne_jwt/resolvers.py index ba2012b..2b54cff 100644 --- a/django_ariadne_jwt/resolvers.py +++ b/django_ariadne_jwt/resolvers.py @@ -1,13 +1,12 @@ """ariadne_django_jwt resolvers module""" from ariadne import gql from django.contrib.auth import authenticate +from .backends import load_backend from .exceptions import ( ExpiredTokenError, InvalidTokenError, MaximumTokenLifeReachedError, ) -from .utils import create_jwt, decode_jwt, refresh_jwt - auth_token_definition = gql( """ @@ -29,20 +28,15 @@ def resolve_token_auth(parent, info, **credentials): """Resolves the token auth mutation""" - token = None user = authenticate(info.context, **credentials) - - if user is not None: - token = create_jwt(user) - - return {"token": token} + return {"token": load_backend().create(user) if user else None} def resolve_refresh_token(parent, info, token): """Resolves the resfresh token mutaiton""" try: - token = refresh_jwt(token) + token = load_backend().refresh(token) except (InvalidTokenError, MaximumTokenLifeReachedError): token = None @@ -55,7 +49,7 @@ def resolve_verify_token(parent, info, token: str): token_verification = {} try: - decoded = decode_jwt(token) + decoded = load_backend().decode(token) token_verification["valid"] = True token_verification["user"] = decoded.get("user") diff --git a/django_ariadne_jwt/utils.py b/django_ariadne_jwt/utils.py deleted file mode 100644 index bba11cd..0000000 --- a/django_ariadne_jwt/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -"""ariadne_django_jwt utils module""" -import datetime -import jwt -from django.contrib.auth import get_user_model -from django.conf import settings -from django.utils import timezone -from django.utils.translation import ugettext_lazy as _ -from jwt.exceptions import DecodeError, ExpiredSignatureError -from .exceptions import ( - AuthenticatedUserRequiredError, - ExpiredTokenError, - MaximumTokenLifeReachedError, - InvalidTokenError, -) - -ORIGINAL_IAT_CLAIM = "orig_iat" -HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" -AUTHORIZATION_HEADER_PREFIX = "Token" -DEFAULT_JWT_ALGORITHM = "HS256" - - -def get_token_from_http_header(request): - """Retrieves the http authorization header from the request""" - header = request.META.get(HTTP_AUTHORIZATION_HEADER, False) - if header is False: - return None - - prefix, token = header.split() - if prefix.lower() != AUTHORIZATION_HEADER_PREFIX.lower(): - return None - - return token - - -def has_reached_end_of_life(oldest_iat_claim): - """Checks if the token has reached its end of life""" - expiration_delta = getattr( - settings, "JWT_REFRESH_EXPIRATION_DELTA", datetime.timedelta(days=7) - ) - - now = timezone.localtime() - original_issue_time = timezone.make_aware( - datetime.datetime.fromtimestamp(int(oldest_iat_claim)) - ) - - end_of_life = original_issue_time + expiration_delta - - return now > end_of_life - - -def create_jwt(user, extra_payload={}): - """Creates a JWT for an authenticated user""" - if not user.is_authenticated: - raise AuthenticatedUserRequiredError( - "JWT generationr requires an authenticated user" - ) - - expiration_delta = getattr( - settings, "JWT_EXPIRATION_DELTA", datetime.timedelta(minutes=5) - ) - - now = timezone.localtime() - - payload = { - **extra_payload, - "user": user.username, - "iat": int(now.timestamp()), - "exp": int((now + expiration_delta).timestamp()), - } - - return jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=getattr(settings, "JWT_ALGORITHM", DEFAULT_JWT_ALGORITHM), - ).decode("utf-8") - - -def refresh_jwt(token): - """Refreshes a JWT if possible""" - decoded = decode_jwt(token) - - oldest_iat_claim = decoded.get(ORIGINAL_IAT_CLAIM, decoded.get("iat")) - - if has_reached_end_of_life(oldest_iat_claim): - raise MaximumTokenLifeReachedError() - - User = get_user_model() - - credentials = {User.USERNAME_FIELD: decoded["user"]} - - try: - user = User.objects.get(**credentials) - - except User.DoesNotExist: - raise InvalidTokenError(_("User not found")) - - return create_jwt(user, {ORIGINAL_IAT_CLAIM: decoded["iat"]}) - - -def decode_jwt(token): - """Decodes a JWT""" - try: - decoded = jwt.decode( - token, - settings.SECRET_KEY, - algorithms=getattr( - settings, "JWT_ALGORITHMS", DEFAULT_JWT_ALGORITHM - ), - ) - - except ExpiredSignatureError: - raise ExpiredTokenError() - - except DecodeError: - raise InvalidTokenError() - - return decoded diff --git a/tests/test_backends.py b/tests/test_backends.py index 5dcfd82..9eac77c 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -2,7 +2,7 @@ from django.contrib.auth import authenticate, get_user_model from django.http import HttpRequest from django.test import TestCase -from django_ariadne_jwt import backends, utils +from django_ariadne_jwt import backends class BackendTestCase(TestCase): @@ -12,7 +12,7 @@ def setUp(self): User = get_user_model() self.user_data = { - User.USERNAME_FIELD: 'test_user', + User.USERNAME_FIELD: "test_user", "password": "lame_password", } @@ -22,7 +22,7 @@ def setUp(self): def test_authentication_with_valid_token(self): """Tests the authentication of a user from a valid token""" - token = utils.create_jwt(self.user) + token = backends.JSONWebTokenBackend().create(self.user) request = HttpRequest() settings = { diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 4970ae3..6651766 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -5,9 +5,9 @@ from django.http import HttpRequest from django.test import TestCase from unittest.mock import Mock +from django_ariadne_jwt.backends import JSONWebTokenBackend from django_ariadne_jwt.decorators import login_required from django_ariadne_jwt.middleware import JSONWebTokenMiddleware -from django_ariadne_jwt.utils import create_jwt HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" @@ -62,7 +62,7 @@ def resolve_test(_, info): middleware = [JSONWebTokenMiddleware()] - token = create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) request = HttpRequest() request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {token}" diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 13c7b1f..a77bc43 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -5,8 +5,8 @@ from django.http import HttpRequest from django.test import TestCase from unittest.mock import Mock, patch +from django_ariadne_jwt.backends import JSONWebTokenBackend from django_ariadne_jwt.middleware import JSONWebTokenMiddleware -from django_ariadne_jwt.utils import create_jwt HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" @@ -42,7 +42,7 @@ def setUp(self): def test_without_user_and_with_valid_token(self): """Tests resolving with a valid token on a request without user""" - token = create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) request = HttpRequest() request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {token}" @@ -70,7 +70,7 @@ def next(root, info, **kwargs): def test_with_user_and_valid_token(self): """Tests that the middleware respects the already authenticated user""" - token = create_jwt(self.other_user) + token = JSONWebTokenBackend().create(self.other_user) request = HttpRequest() request.user = self.user @@ -125,7 +125,7 @@ def resolve_test(_, info): middleware = JSONWebTokenMiddleware() - token = create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) request = HttpRequest() request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {token}" diff --git a/tests/test_resolvers.py b/tests/test_resolvers.py index 39d7ac6..54734ca 100644 --- a/tests/test_resolvers.py +++ b/tests/test_resolvers.py @@ -4,7 +4,8 @@ from django.contrib.auth.models import User from django.http import HttpRequest from django.test import TestCase -from django_ariadne_jwt import resolvers, utils +from django_ariadne_jwt import resolvers +from django_ariadne_jwt.backends import JSONWebTokenBackend @dataclass @@ -63,7 +64,7 @@ def setUp(self): def test_refreshing_for_valid_token(self): """Test refreshing a valid token""" info = InfoObject(context=HttpRequest()) - token = utils.create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) resolved_data = resolvers.resolve_refresh_token(None, info, token) @@ -82,7 +83,7 @@ def test_refreshing_token_at_end_of_life(self): } with self.settings(**settings): - token = utils.create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) resolved_data = resolvers.resolve_refresh_token(None, info, token) self.assertIsNotNone(resolved_data) @@ -100,7 +101,7 @@ def test_refreshing_token_not_at_end_of_life(self): } with self.settings(**settings): - token = utils.create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) resolved_data = resolvers.resolve_refresh_token(None, info, token) self.assertIsNotNone(resolved_data) @@ -139,7 +140,7 @@ def test_verification_for_expired_token(self): settings = {"JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=-10)} with self.settings(**settings): - token = utils.create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) resolved_data = resolvers.resolve_verify_token(None, info, token) self.assertIsNotNone(resolved_data) @@ -155,7 +156,7 @@ def test_verification_for_valid_token(self): settings = {"JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=2)} with self.settings(**settings): - token = utils.create_jwt(self.user) + token = JSONWebTokenBackend().create(self.user) resolved_data = resolvers.resolve_verify_token(None, info, token) self.assertIsNotNone(resolved_data) diff --git a/tests/test_utils.py b/tests/test_utils.py index 03f2129..05ff854 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,8 @@ from django.http import HttpRequest from django.test import TestCase from django.utils import timezone -from django_ariadne_jwt import exceptions, utils +from django_ariadne_jwt import exceptions +from django_ariadne_jwt.backends import JSONWebTokenBackend HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" @@ -19,7 +20,7 @@ def test_http_header_retrieval(self): request = HttpRequest() request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {expected_token}" - token = utils.get_token_from_http_header(request) + token = JSONWebTokenBackend().get_token_from_http_header(request) self.assertEqual(expected_token, token) @@ -27,16 +28,19 @@ def test_http_header_retrieval(self): class JWTCreationTestCase(TestCase): """Tests the creation of JWTs""" + def setUp(self): + self.backend = JSONWebTokenBackend() + def test_jwt_creation_for_non_authenticated_user(self): """Tests the creation of a JWT for a non-authenticated user""" with self.assertRaises(exceptions.AuthenticatedUserRequiredError): user = AnonymousUser() - utils.create_jwt(user) + self.backend.create(user) def test_jwt_creation_for_authenticated_user(self): """Tests the creation of a JWT for an authenticated user""" user = User(username="test_user") - token = utils.create_jwt(user) + token = self.backend.create(user) self.assertIsNotNone(token) self.assertIsInstance(token, str) @@ -48,19 +52,22 @@ def test_jwt_creation_for_authenticated_user(self): class JWTDecodingTestCase(TestCase): """Tests the decoding of JWTs""" + def setUp(self): + self.backend = JSONWebTokenBackend() + def test_invalid_jwt_decoding(self): """Tests decoding of an invalid JWT""" with self.assertRaises(exceptions.InvalidTokenError): token = "SOME.FABRICATED.JWT" - utils.decode_jwt(token) + self.backend.decode(token) def test_valid_jwt_decoding(self): """Tests decoding of a valid JWT""" expected_username = "test_user" user = User(username=expected_username) - token = utils.create_jwt(user) - data = utils.decode_jwt(token) + token = self.backend.create(user) + data = self.backend.decode(token) self.assertIn("user", data) self.assertEqual(data["user"], expected_username) @@ -73,15 +80,18 @@ def test_expired_jwt_decoding(self): settings = {"JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=-10)} with self.settings(**settings): - token = utils.create_jwt(user) + token = self.backend.create(user) with self.assertRaises(exceptions.ExpiredTokenError): - utils.decode_jwt(token) + self.backend.decode(token) class JWTRefreshingTestCase(TestCase): """Tests the refreshing of JWTs""" + def setUp(self): + self.backend = JSONWebTokenBackend() + def test_token_not_at_end_of_life_detection(self): """Tests the detection of a token which is at its end of life""" now = timezone.localtime() @@ -96,7 +106,9 @@ def test_token_not_at_end_of_life_detection(self): } with self.settings(**settings): - self.assertFalse(utils.has_reached_end_of_life(original_iat_claim)) + self.assertFalse( + self.backend.has_reached_end_of_life(original_iat_claim) + ) def test_token_at_end_of_life_detection(self): """Tests the detection of a token which isn't yet at its end of life""" @@ -112,7 +124,9 @@ def test_token_at_end_of_life_detection(self): } with self.settings(**settings): - self.assertTrue(utils.has_reached_end_of_life(original_iat_claim)) + self.assertTrue( + self.backend.has_reached_end_of_life(original_iat_claim) + ) def test_refreshing_jwt_not_at_end_of_life(self): """Tests refreshing a JWT for token at its end of life""" @@ -124,15 +138,19 @@ def test_refreshing_jwt_not_at_end_of_life(self): } with self.settings(**settings): - first_token = utils.create_jwt(user) - decoded_first_token = utils.decode_jwt(first_token) - second_token = utils.refresh_jwt(first_token) # Refresh the token - decoded_second_token = utils.decode_jwt(second_token) + first_token = self.backend.create(user) + decoded_first_token = self.backend.decode(first_token) + second_token = self.backend.refresh( + first_token + ) # Refresh the token + decoded_second_token = self.backend.decode(second_token) self.assertIsNotNone(second_token) - self.assertIn(utils.ORIGINAL_IAT_CLAIM, decoded_second_token) + self.assertIn( + self.backend.ORIGINAL_IAT_CLAIM, decoded_second_token + ) self.assertEqual( - decoded_second_token[utils.ORIGINAL_IAT_CLAIM], + decoded_second_token[self.backend.ORIGINAL_IAT_CLAIM], decoded_first_token["iat"], ) @@ -146,10 +164,10 @@ def test_refreshing_jwt_at_end_of_life(self): } with self.settings(**settings): - token = utils.create_jwt(user) + token = self.backend.create(user) with self.assertRaises(exceptions.MaximumTokenLifeReachedError): - utils.refresh_jwt(token) # Refresh the token + self.backend.refresh(token) # Refresh the token def test_jwt_with_non_existent_user(self): """Tests refreshing a JWT for a user that doesn't exist""" @@ -162,7 +180,7 @@ def test_jwt_with_non_existent_user(self): } with self.settings(**settings): - token = utils.create_jwt(user) + token = self.backend.create(user) with self.assertRaises(exceptions.InvalidTokenError): - utils.refresh_jwt(token) + self.backend.refresh(token) From 2705703a6b9f2509921d00e0196dac611f0a1e36 Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 15:54:10 -0500 Subject: [PATCH 05/11] Consolidate test_utils and test_backends --- tests/test_backends.py | 190 +++++++++++++++++++++++++++++++++++++++-- tests/test_utils.py | 186 ---------------------------------------- 2 files changed, 183 insertions(+), 193 deletions(-) delete mode 100644 tests/test_utils.py diff --git a/tests/test_backends.py b/tests/test_backends.py index 9eac77c..d49630b 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,11 +1,23 @@ """django_ariadne_jwt_auth backends tests""" +import datetime from django.contrib.auth import authenticate, get_user_model +from django.contrib.auth.models import User, AnonymousUser from django.http import HttpRequest from django.test import TestCase -from django_ariadne_jwt import backends +from django.utils import timezone +from django_ariadne_jwt import exceptions +from django_ariadne_jwt.backends import JSONWebTokenBackend +HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" -class BackendTestCase(TestCase): + +class BaseBackendTestCase(TestCase): + def setUp(self): + super().setUp() + self.backend = JSONWebTokenBackend() + + +class BackendTestCase(BaseBackendTestCase): """Tests for the JWT backend""" def setUp(self): @@ -22,7 +34,7 @@ def setUp(self): def test_authentication_with_valid_token(self): """Tests the authentication of a user from a valid token""" - token = backends.JSONWebTokenBackend().create(self.user) + token = self.backend.create(self.user) request = HttpRequest() settings = { @@ -38,12 +50,176 @@ def test_authentication_with_valid_token(self): def test_existing_user_retrieval(self): """Tests the retrieval of an existing user""" - backend = backends.JSONWebTokenBackend() - user = backend.get_user(self.user.pk) + user = self.backend.get_user(self.user.pk) self.assertEqual(user, self.user) def test_non_existing_user_retrieval(self): """Tests the retrieval of a non existing user""" - backend = backends.JSONWebTokenBackend() - user = backend.get_user(-1) + user = self.backend.get_user(-1) self.assertIsNone(user) + + +class HttpHeaderRetrievalTestCase(BaseBackendTestCase): + """Tests the retrieval of a token from http headers""" + + def test_http_header_retrieval(self): + """Tests the retrieval of a token from http headers""" + expected_token = "EXPECTED_TOKEN_VALUE" + request = HttpRequest() + request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {expected_token}" + + token = self.backend.get_token_from_http_header(request) + + self.assertEqual(expected_token, token) + + +class JWTCreationTestCase(BaseBackendTestCase): + """Tests the creation of JWTs""" + + def test_jwt_creation_for_non_authenticated_user(self): + """Tests the creation of a JWT for a non-authenticated user""" + with self.assertRaises(exceptions.AuthenticatedUserRequiredError): + user = AnonymousUser() + self.backend.create(user) + + def test_jwt_creation_for_authenticated_user(self): + """Tests the creation of a JWT for an authenticated user""" + user = User(username="test_user") + token = self.backend.create(user) + + self.assertIsNotNone(token) + self.assertIsInstance(token, str) + + parts = token.split(".") + self.assertEqual(len(parts), 3) + + +class JWTDecodingTestCase(BaseBackendTestCase): + """Tests the decoding of JWTs""" + + def test_invalid_jwt_decoding(self): + """Tests decoding of an invalid JWT""" + with self.assertRaises(exceptions.InvalidTokenError): + token = "SOME.FABRICATED.JWT" + self.backend.decode(token) + + def test_valid_jwt_decoding(self): + """Tests decoding of a valid JWT""" + expected_username = "test_user" + user = User(username=expected_username) + + token = self.backend.create(user) + data = self.backend.decode(token) + + self.assertIn("user", data) + self.assertEqual(data["user"], expected_username) + + def test_expired_jwt_decoding(self): + """Tests decoding of an expired JWT""" + expected_username = "test_user" + user = User(username=expected_username) + + settings = {"JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=-10)} + + with self.settings(**settings): + token = self.backend.create(user) + + with self.assertRaises(exceptions.ExpiredTokenError): + self.backend.decode(token) + + +class JWTRefreshingTestCase(BaseBackendTestCase): + """Tests the refreshing of JWTs""" + + def test_token_not_at_end_of_life_detection(self): + """Tests the detection of a token which is at its end of life""" + now = timezone.localtime() + delta = 5 + + original_iat_claim = ( + now - datetime.timedelta(minutes=delta, seconds=-1) + ).timestamp() + + settings = { + "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(minutes=delta) + } + + with self.settings(**settings): + self.assertFalse( + self.backend.has_reached_end_of_life(original_iat_claim) + ) + + def test_token_at_end_of_life_detection(self): + """Tests the detection of a token which isn't yet at its end of life""" + now = timezone.localtime() + delta = 5 + + original_iat_claim = ( + now - datetime.timedelta(minutes=delta, seconds=1) + ).timestamp() + + settings = { + "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(minutes=delta) + } + + with self.settings(**settings): + self.assertTrue( + self.backend.has_reached_end_of_life(original_iat_claim) + ) + + def test_refreshing_jwt_not_at_end_of_life(self): + """Tests refreshing a JWT for token at its end of life""" + user = User.objects.create(username="test_user") + + settings = { + "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=10), + "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=10), + } + + with self.settings(**settings): + first_token = self.backend.create(user) + decoded_first_token = self.backend.decode(first_token) + second_token = self.backend.refresh( + first_token + ) # Refresh the token + decoded_second_token = self.backend.decode(second_token) + + self.assertIsNotNone(second_token) + self.assertIn( + self.backend.ORIGINAL_IAT_CLAIM, decoded_second_token + ) + self.assertEqual( + decoded_second_token[self.backend.ORIGINAL_IAT_CLAIM], + decoded_first_token["iat"], + ) + + def test_refreshing_jwt_at_end_of_life(self): + """Tests refreshing a JWT for token at its end of life""" + user = User.objects.create(username="test_user") + + settings = { + "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=3), + "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=0), + } + + with self.settings(**settings): + token = self.backend.create(user) + + with self.assertRaises(exceptions.MaximumTokenLifeReachedError): + self.backend.refresh(token) # Refresh the token + + def test_jwt_with_non_existent_user(self): + """Tests refreshing a JWT for a user that doesn't exist""" + expected_username = "test_user" + user = User(username=expected_username) + + settings = { + "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=3), + "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=3), + } + + with self.settings(**settings): + token = self.backend.create(user) + + with self.assertRaises(exceptions.InvalidTokenError): + self.backend.refresh(token) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 05ff854..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,186 +0,0 @@ -"""django_ariadne_jwt_auth resolvers tests""" -import datetime -from django.contrib.auth.models import User, AnonymousUser -from django.http import HttpRequest -from django.test import TestCase -from django.utils import timezone -from django_ariadne_jwt import exceptions -from django_ariadne_jwt.backends import JSONWebTokenBackend - - -HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION" - - -class HttpHeaderRetrievalTestCase(TestCase): - """Tests the retrieval of a token from http headers""" - - def test_http_header_retrieval(self): - """Tests the retrieval of a token from http headers""" - expected_token = "EXPECTED_TOKEN_VALUE" - request = HttpRequest() - request.META[HTTP_AUTHORIZATION_HEADER] = f"Token {expected_token}" - - token = JSONWebTokenBackend().get_token_from_http_header(request) - - self.assertEqual(expected_token, token) - - -class JWTCreationTestCase(TestCase): - """Tests the creation of JWTs""" - - def setUp(self): - self.backend = JSONWebTokenBackend() - - def test_jwt_creation_for_non_authenticated_user(self): - """Tests the creation of a JWT for a non-authenticated user""" - with self.assertRaises(exceptions.AuthenticatedUserRequiredError): - user = AnonymousUser() - self.backend.create(user) - - def test_jwt_creation_for_authenticated_user(self): - """Tests the creation of a JWT for an authenticated user""" - user = User(username="test_user") - token = self.backend.create(user) - - self.assertIsNotNone(token) - self.assertIsInstance(token, str) - - parts = token.split(".") - self.assertEqual(len(parts), 3) - - -class JWTDecodingTestCase(TestCase): - """Tests the decoding of JWTs""" - - def setUp(self): - self.backend = JSONWebTokenBackend() - - def test_invalid_jwt_decoding(self): - """Tests decoding of an invalid JWT""" - with self.assertRaises(exceptions.InvalidTokenError): - token = "SOME.FABRICATED.JWT" - self.backend.decode(token) - - def test_valid_jwt_decoding(self): - """Tests decoding of a valid JWT""" - expected_username = "test_user" - user = User(username=expected_username) - - token = self.backend.create(user) - data = self.backend.decode(token) - - self.assertIn("user", data) - self.assertEqual(data["user"], expected_username) - - def test_expired_jwt_decoding(self): - """Tests decoding of an expired JWT""" - expected_username = "test_user" - user = User(username=expected_username) - - settings = {"JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=-10)} - - with self.settings(**settings): - token = self.backend.create(user) - - with self.assertRaises(exceptions.ExpiredTokenError): - self.backend.decode(token) - - -class JWTRefreshingTestCase(TestCase): - """Tests the refreshing of JWTs""" - - def setUp(self): - self.backend = JSONWebTokenBackend() - - def test_token_not_at_end_of_life_detection(self): - """Tests the detection of a token which is at its end of life""" - now = timezone.localtime() - delta = 5 - - original_iat_claim = ( - now - datetime.timedelta(minutes=delta, seconds=-1) - ).timestamp() - - settings = { - "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(minutes=delta) - } - - with self.settings(**settings): - self.assertFalse( - self.backend.has_reached_end_of_life(original_iat_claim) - ) - - def test_token_at_end_of_life_detection(self): - """Tests the detection of a token which isn't yet at its end of life""" - now = timezone.localtime() - delta = 5 - - original_iat_claim = ( - now - datetime.timedelta(minutes=delta, seconds=1) - ).timestamp() - - settings = { - "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(minutes=delta) - } - - with self.settings(**settings): - self.assertTrue( - self.backend.has_reached_end_of_life(original_iat_claim) - ) - - def test_refreshing_jwt_not_at_end_of_life(self): - """Tests refreshing a JWT for token at its end of life""" - user = User.objects.create(username="test_user") - - settings = { - "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=10), - "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=10), - } - - with self.settings(**settings): - first_token = self.backend.create(user) - decoded_first_token = self.backend.decode(first_token) - second_token = self.backend.refresh( - first_token - ) # Refresh the token - decoded_second_token = self.backend.decode(second_token) - - self.assertIsNotNone(second_token) - self.assertIn( - self.backend.ORIGINAL_IAT_CLAIM, decoded_second_token - ) - self.assertEqual( - decoded_second_token[self.backend.ORIGINAL_IAT_CLAIM], - decoded_first_token["iat"], - ) - - def test_refreshing_jwt_at_end_of_life(self): - """Tests refreshing a JWT for token at its end of life""" - user = User.objects.create(username="test_user") - - settings = { - "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=3), - "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=0), - } - - with self.settings(**settings): - token = self.backend.create(user) - - with self.assertRaises(exceptions.MaximumTokenLifeReachedError): - self.backend.refresh(token) # Refresh the token - - def test_jwt_with_non_existent_user(self): - """Tests refreshing a JWT for a user that doesn't exist""" - expected_username = "test_user" - user = User(username=expected_username) - - settings = { - "JWT_EXPIRATION_DELTA": datetime.timedelta(seconds=3), - "JWT_REFRESH_EXPIRATION_DELTA": datetime.timedelta(seconds=3), - } - - with self.settings(**settings): - token = self.backend.create(user) - - with self.assertRaises(exceptions.InvalidTokenError): - self.backend.refresh(token) From 2b039b0992fb4bfd8f33351cf4f4a42b7eeaf6c5 Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 16:00:27 -0500 Subject: [PATCH 06/11] =?UTF-8?q?Swore=20I=20ran=20this=20prior=20to=20pus?= =?UTF-8?q?hing=E2=80=A6=20:-(?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_backends.py b/tests/test_backends.py index d49630b..bc81ea1 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -21,6 +21,7 @@ class BackendTestCase(BaseBackendTestCase): """Tests for the JWT backend""" def setUp(self): + super().setUp() User = get_user_model() self.user_data = { From ae26f791ad683942809a011849eb9f8d957cafec Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 16:04:49 -0500 Subject: [PATCH 07/11] Swap out to using get_user() instead of interacting with the model directly --- django_ariadne_jwt/backends.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 1ee5f46..21fca3d 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -116,14 +116,8 @@ def refresh(self, token): if self.has_reached_end_of_life(oldest_iat_claim): raise MaximumTokenLifeReachedError() - User = get_user_model() - - credentials = {User.USERNAME_FIELD: decoded["user"]} - - try: - user = User.objects.get(**credentials) - - except User.DoesNotExist: + user = self.get_user(**self.get_user_kwargs(decoded)) + if user is None: raise InvalidTokenError(_("User not found")) return self.create(user, {self.ORIGINAL_IAT_CLAIM: decoded["iat"]}) From aeef469992dbca33ebc7c79c1af36a8fe3bd7faf Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 16:08:24 -0500 Subject: [PATCH 08/11] Extract generation of payload into its own method --- django_ariadne_jwt/backends.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 21fca3d..839ee7d 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -77,28 +77,32 @@ def get_user_kwargs(self, token_data): User = get_user_model() return {User.USERNAME_FIELD: token_data["user"]} - def create(self, user, extra_payload={}): - """Creates a JWT for an authenticated user""" - if not user.is_authenticated: - raise AuthenticatedUserRequiredError( - "JWT generationr requires an authenticated user" - ) - + def generate_token_payload(self, user, extra_payload=None): + """Return a dictionary containing the JWT payload""" + if extra_payload is None: + extra_payload = {} expiration_delta = getattr( settings, "JWT_EXPIRATION_DELTA", datetime.timedelta(minutes=5) ) now = timezone.localtime() - payload = { + return { **extra_payload, "user": user.username, "iat": int(now.timestamp()), "exp": int((now + expiration_delta).timestamp()), } + def create(self, user, extra_payload=None): + """Creates a JWT for an authenticated user""" + if not user.is_authenticated: + raise AuthenticatedUserRequiredError( + "JWT generationr requires an authenticated user" + ) + return jwt.encode( - payload, + self.generate_token_payload(user, extra_payload=extra_payload), settings.SECRET_KEY, algorithm=getattr( settings, "JWT_ALGORITHM", self.DEFAULT_JWT_ALGORITHM From f857c2a8a4dc23b2970485d3bdbb40afddf03dc2 Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Thu, 10 Oct 2019 16:15:11 -0500 Subject: [PATCH 09/11] Extract out EOL check to simplify refresh --- django_ariadne_jwt/backends.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/django_ariadne_jwt/backends.py b/django_ariadne_jwt/backends.py index 839ee7d..d56eaf2 100644 --- a/django_ariadne_jwt/backends.py +++ b/django_ariadne_jwt/backends.py @@ -113,11 +113,7 @@ def refresh(self, token): """Refreshes a JWT if possible""" decoded = self.decode(token) - oldest_iat_claim = decoded.get( - self.ORIGINAL_IAT_CLAIM, decoded.get("iat") - ) - - if self.has_reached_end_of_life(oldest_iat_claim): + if self.is_token_end_of_life(decoded): raise MaximumTokenLifeReachedError() user = self.get_user(**self.get_user_kwargs(decoded)) @@ -126,6 +122,11 @@ def refresh(self, token): return self.create(user, {self.ORIGINAL_IAT_CLAIM: decoded["iat"]}) + def is_token_end_of_life(self, token_data): + return self.has_reached_end_of_life( + token_data.get(self.ORIGINAL_IAT_CLAIM, token_data.get("iat")) + ) + def decode(self, token): """Decodes a JWT""" try: From bcff9406590d9d9fc2e3adeb3f58755da9ef7db4 Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Fri, 11 Oct 2019 12:31:42 -0500 Subject: [PATCH 10/11] Refactor class-based resolvers --- django_ariadne_jwt/resolvers.py | 57 +++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/django_ariadne_jwt/resolvers.py b/django_ariadne_jwt/resolvers.py index 2b54cff..2d9a26c 100644 --- a/django_ariadne_jwt/resolvers.py +++ b/django_ariadne_jwt/resolvers.py @@ -26,34 +26,49 @@ ) -def resolve_token_auth(parent, info, **credentials): - """Resolves the token auth mutation""" - user = authenticate(info.context, **credentials) - return {"token": load_backend().create(user) if user else None} +class TokenAuthResolver: + def get_payload(self, user): + return {"token": load_backend().create(user) if user else None} + def __call__(self, parent, info, **credentials): + user = authenticate(info.context, **credentials) + return self.get_payload(user) -def resolve_refresh_token(parent, info, token): - """Resolves the resfresh token mutaiton""" - try: - token = load_backend().refresh(token) +# TODO Add DeprecationWarning? +resolve_token_auth = TokenAuthResolver() - except (InvalidTokenError, MaximumTokenLifeReachedError): - token = None - return {"token": token} +class RefreshTokenResolver: + def __call__(self, parent, info, token): + """Resolves the resfresh token mutaiton""" + try: + token = load_backend().refresh(token) -def resolve_verify_token(parent, info, token: str): - """Resolves the verify token mutation""" - token_verification = {} + except (InvalidTokenError, MaximumTokenLifeReachedError): + token = None - try: - decoded = load_backend().decode(token) - token_verification["valid"] = True - token_verification["user"] = decoded.get("user") + return {"token": token} - except (InvalidTokenError, ExpiredTokenError): - token_verification["valid"] = False - return token_verification +resolve_refresh_token = RefreshTokenResolver() + + +class VerifyTokenResolver: + def __call__(self, parent, info, token: str): + """Resolves the verify token mutation""" + token_verification = {} + + try: + decoded = load_backend().decode(token) + token_verification["valid"] = True + token_verification["user"] = decoded.get("user") + + except (InvalidTokenError, ExpiredTokenError): + token_verification["valid"] = False + + return token_verification + + +resolve_verify_token = VerifyTokenResolver() From 0e53ec93f7d5d1284b3e82bb99ea56759bbb45cd Mon Sep 17 00:00:00 2001 From: Travis Swicegood Date: Fri, 11 Oct 2019 12:43:21 -0500 Subject: [PATCH 11/11] Refactor to allow overriding the payload in all of the resolvers --- django_ariadne_jwt/resolvers.py | 52 ++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/django_ariadne_jwt/resolvers.py b/django_ariadne_jwt/resolvers.py index 2d9a26c..b9119eb 100644 --- a/django_ariadne_jwt/resolvers.py +++ b/django_ariadne_jwt/resolvers.py @@ -26,49 +26,55 @@ ) -class TokenAuthResolver: - def get_payload(self, user): - return {"token": load_backend().create(user) if user else None} +class BaseTokenResolver: + def get_token(self): + raise NotImplementedError() + + def get_payload(self): + return {"token": self.get_token()} + + +class TokenAuthResolver(BaseTokenResolver): + def get_token(self): + return load_backend().create(self.user) if self.user else None def __call__(self, parent, info, **credentials): - user = authenticate(info.context, **credentials) - return self.get_payload(user) + self.user = authenticate(info.context, **credentials) + return self.get_payload() # TODO Add DeprecationWarning? resolve_token_auth = TokenAuthResolver() -class RefreshTokenResolver: - def __call__(self, parent, info, token): - """Resolves the resfresh token mutaiton""" - +class RefreshTokenResolver(BaseTokenResolver): + def get_token(self): try: - token = load_backend().refresh(token) - + return load_backend().refresh(self.token) except (InvalidTokenError, MaximumTokenLifeReachedError): - token = None + pass - return {"token": token} + def __call__(self, parent, info, token): + """Resolves the resfresh token mutaiton""" + self.token = token + return self.get_payload() resolve_refresh_token = RefreshTokenResolver() class VerifyTokenResolver: - def __call__(self, parent, info, token: str): - """Resolves the verify token mutation""" - token_verification = {} - + def get_payload(self): try: - decoded = load_backend().decode(token) - token_verification["valid"] = True - token_verification["user"] = decoded.get("user") - + decoded = load_backend().decode(self.token) + return {"valid": True, "user": decoded.get("user")} except (InvalidTokenError, ExpiredTokenError): - token_verification["valid"] = False + return {"valid": False} - return token_verification + def __call__(self, parent, info, token: str): + """Resolves the verify token mutation""" + self.token = token + return self.get_payload() resolve_verify_token = VerifyTokenResolver()