Skip to content
This repository was archived by the owner on Apr 25, 2021. It is now read-only.

Better token creation support #1

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
163 changes: 142 additions & 21 deletions django_ariadne_jwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,164 @@
"""GraphQL auth backends module"""
import datetime
from django.contrib.auth import get_user_model
from .exceptions import JSONWebTokenError
from .utils import decode_jwt
from django.conf import settings
from django.utils import timezone
from django.utils.module_loading import import_string
from django.utils.translation import ugettext_lazy as _
import jwt
from jwt.exceptions import DecodeError, ExpiredSignatureError

from .exceptions import (
AuthenticatedUserRequiredError,
ExpiredTokenError,
InvalidTokenError,
JSONWebTokenError,
MaximumTokenLifeReachedError,
)


def load_backend():
return import_string(
getattr(
settings,
"JWT_BACKEND",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of thinking this might should be DJANGO_ARIADNE_JWT_BACKEND to avoid any possible confusion.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I guess we should prefix JWT_EXPIRATION_DELTA and JWT_ALGORITHM as well to keep it consistent.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, I guess let's leave it since everything is using JWT_ as the prefix. Seems like a candidate for supporting both with a DeprecationWarning if the old version is detected.

"django_ariadne_jwt.backends.JSONWebTokenBackend",
)
)()
Comment on lines +20 to +27
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this code interact with the value from settings.AUTHENTICATION_BACKENDS? Wouldn't it allow a configuration issue where two different backends could be specified?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't interact with the settings.AUTHENTICATION_BACKENDS code in any way other than having the one object that's pulling double duty and that I'm using the same pattern for loading backends. I could potentially see some issues if someone's got a backend that either:

  • Does a lot of heavy lifting in the __init__
  • Requires that only one instance of the backend class was instantiated

Both of those feel like anti-patterns, so I'm not sure it'd worth coding around that.


All that said, I do wonder if it's confusing to folks to see the same backend configuration being used in multiple different areas. I could definitely see making two different parts, but things like loading users and such are shared in both concerns. That was my original thought for combining them.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming there is no case for using two different JSONWebTokenBackend classes in settings.AUTHENTICATION_BACKENDS, maybe we could make load_backend detect the JSONWebTokenBackend class or subclass referenced by the settings in order to avoid the duplicity.



class JSONWebTokenBackend(object):
"""Authenticates against a JSON Web Token"""

def authenticate(self, request, token=None, **kwargs):
"""Performs authentication"""
user = None
DEFAULT_JWT_ALGORITHM = "HS256"
ORIGINAL_IAT_CLAIM = "orig_iat"
HTTP_AUTHORIZATION_HEADER = "HTTP_AUTHORIZATION"
AUTHORIZATION_HEADER_PREFIX = "Token"
DEFAULT_JWT_ALGORITHM = "HS256"

if token is not None:
token_data = None
def get_token_from_http_header(self, request):
"""Retrieves the http authorization header from the request"""
header = request.META.get(self.HTTP_AUTHORIZATION_HEADER, False)
if header is False:
return None

try:
token_data = decode_jwt(token)
prefix, token = header.split()
if prefix.lower() != self.AUTHORIZATION_HEADER_PREFIX.lower():
return None

except JSONWebTokenError:
pass
return token

if token_data is not None:
User = get_user_model()
credentials = {User.USERNAME_FIELD: token_data["user"]}
def authenticate(self, request, token=None, **kwargs):
"""Performs authentication"""
if token is None:
return

try:
user = User.objects.get(**credentials)
try:
token_data = self.decode(token)

except User.DoesNotExist:
pass
except JSONWebTokenError:
return

return user
return self.get_user(**self.get_user_kwargs(token_data))

def get_user(self, user_id):
def get_user(self, user_id=None, **kwargs):
"""Gets a user from its id"""
User = get_user_model()
if user_id is not None:
kwargs["pk"] = user_id

try:
return User.objects.get(pk=user_id)
return User.objects.get(**kwargs)

except User.DoesNotExist:
return None

def get_user_kwargs(self, token_data):
User = get_user_model()
return {User.USERNAME_FIELD: token_data["user"]}

def generate_token_payload(self, user, extra_payload=None):
"""Return a dictionary containing the JWT payload"""
if extra_payload is None:
extra_payload = {}
expiration_delta = getattr(
settings, "JWT_EXPIRATION_DELTA", datetime.timedelta(minutes=5)
)

now = timezone.localtime()

return {
**extra_payload,
"user": user.username,
"iat": int(now.timestamp()),
"exp": int((now + expiration_delta).timestamp()),
}

def create(self, user, extra_payload=None):
"""Creates a JWT for an authenticated user"""
if not user.is_authenticated:
raise AuthenticatedUserRequiredError(
"JWT generationr requires an authenticated user"
)

return jwt.encode(
self.generate_token_payload(user, extra_payload=extra_payload),
settings.SECRET_KEY,
algorithm=getattr(
settings, "JWT_ALGORITHM", self.DEFAULT_JWT_ALGORITHM
),
).decode("utf-8")

def refresh(self, token):
"""Refreshes a JWT if possible"""
decoded = self.decode(token)

if self.is_token_end_of_life(decoded):
raise MaximumTokenLifeReachedError()

user = self.get_user(**self.get_user_kwargs(decoded))
if user is None:
raise InvalidTokenError(_("User not found"))

return self.create(user, {self.ORIGINAL_IAT_CLAIM: decoded["iat"]})

def is_token_end_of_life(self, token_data):
return self.has_reached_end_of_life(
token_data.get(self.ORIGINAL_IAT_CLAIM, token_data.get("iat"))
)

def decode(self, token):
"""Decodes a JWT"""
try:
decoded = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=getattr(
settings, "JWT_ALGORITHMS", self.DEFAULT_JWT_ALGORITHM
),
)

except ExpiredSignatureError:
raise ExpiredTokenError()

except DecodeError:
raise InvalidTokenError()

return decoded

def has_reached_end_of_life(self, oldest_iat_claim):
"""Checks if the token has reached its end of life"""
expiration_delta = getattr(
settings,
"JWT_REFRESH_EXPIRATION_DELTA",
datetime.timedelta(days=7),
)

now = timezone.localtime()
original_issue_time = timezone.make_aware(
datetime.datetime.fromtimestamp(int(oldest_iat_claim))
)

end_of_life = original_issue_time + expiration_delta

return now > end_of_life
4 changes: 2 additions & 2 deletions django_ariadne_jwt/middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""ariadne_django_jwt middleware module"""
from django.contrib.auth import authenticate
from django.contrib.auth.models import AnonymousUser
from .utils import get_token_from_http_header
from .backends import load_backend

__all__ = ["JSONWebTokenMiddleware"]

Expand All @@ -13,7 +13,7 @@ def resolve(self, next, root, info, **kwargs):
"""Performs the middleware relevant operations"""
request = info.context

token = get_token_from_http_header(request)
token = load_backend().get_token_from_http_header(request)

if token is not None:
user = getattr(request, "user", None)
Expand Down
67 changes: 41 additions & 26 deletions django_ariadne_jwt/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""ariadne_django_jwt resolvers module"""
from ariadne import gql
from django.contrib.auth import authenticate
from .backends import load_backend
from .exceptions import (
ExpiredTokenError,
InvalidTokenError,
MaximumTokenLifeReachedError,
)
from .utils import create_jwt, decode_jwt, refresh_jwt


auth_token_definition = gql(
"""
Expand All @@ -27,39 +26,55 @@
)


def resolve_token_auth(parent, info, **credentials):
"""Resolves the token auth mutation"""
token = None
user = authenticate(info.context, **credentials)
class BaseTokenResolver:
def get_token(self):
raise NotImplementedError()

def get_payload(self):
return {"token": self.get_token()}


class TokenAuthResolver(BaseTokenResolver):
def get_token(self):
return load_backend().create(self.user) if self.user else None

def __call__(self, parent, info, **credentials):
self.user = authenticate(info.context, **credentials)
return self.get_payload()

if user is not None:
token = create_jwt(user)

return {"token": token}
# TODO Add DeprecationWarning?
resolve_token_auth = TokenAuthResolver()


def resolve_refresh_token(parent, info, token):
"""Resolves the resfresh token mutaiton"""
class RefreshTokenResolver(BaseTokenResolver):
def get_token(self):
try:
return load_backend().refresh(self.token)
except (InvalidTokenError, MaximumTokenLifeReachedError):
pass

try:
token = refresh_jwt(token)
def __call__(self, parent, info, token):
"""Resolves the resfresh token mutaiton"""
self.token = token
return self.get_payload()

except (InvalidTokenError, MaximumTokenLifeReachedError):
token = None

return {"token": token}
resolve_refresh_token = RefreshTokenResolver()


def resolve_verify_token(parent, info, token: str):
"""Resolves the verify token mutation"""
token_verification = {}
class VerifyTokenResolver:
def get_payload(self):
try:
decoded = load_backend().decode(self.token)
return {"valid": True, "user": decoded.get("user")}
except (InvalidTokenError, ExpiredTokenError):
return {"valid": False}

try:
decoded = decode_jwt(token)
token_verification["valid"] = True
token_verification["user"] = decoded.get("user")
def __call__(self, parent, info, token: str):
"""Resolves the verify token mutation"""
self.token = token
return self.get_payload()

except (InvalidTokenError, ExpiredTokenError):
token_verification["valid"] = False

return token_verification
resolve_verify_token = VerifyTokenResolver()
Loading