Skip to content

feat: Various typing improvements #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rest_framework_simplejwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from importlib.metadata import PackageNotFoundError, version
from typing import Union

__version__: Union[str, None]

try:
__version__ = version("djangorestframework_simplejwt")
Expand Down
14 changes: 6 additions & 8 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import Optional

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
Expand All @@ -7,7 +7,7 @@
from rest_framework.request import Request

from .exceptions import AuthenticationFailed, InvalidToken, TokenError
from .models import TokenUser
from .models import TokenUserBase
from .settings import api_settings
from .tokens import Token
from .utils import get_md5_hash_password
Expand All @@ -21,8 +21,6 @@
h.encode(HTTP_HEADER_ENCODING) for h in AUTH_HEADER_TYPES
}

AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)


class JWTAuthentication(authentication.BaseAuthentication):
"""
Expand All @@ -37,7 +35,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.user_model = get_user_model()

def authenticate(self, request: Request) -> Optional[tuple[AuthUser, Token]]:
def authenticate(self, request: Request) -> Optional[tuple[TokenUserBase, Token]]:
header = self.get_header(request)
if header is None:
return None
Expand Down Expand Up @@ -117,7 +115,7 @@ def get_validated_token(self, raw_token: bytes) -> Token:
}
)

def get_user(self, validated_token: Token) -> AuthUser:
def get_user(self, validated_token: Token) -> AbstractBaseUser:
"""
Attempts to find and return a user using the given validated token.
"""
Expand Down Expand Up @@ -155,7 +153,7 @@ class JWTStatelessUserAuthentication(JWTAuthentication):
token provided in a request header without performing a database lookup to obtain a user instance.
"""

def get_user(self, validated_token: Token) -> AuthUser:
def get_user(self, validated_token: Token) -> AbstractBaseUser:
"""
Returns a stateless user object which is backed by the given validated
token.
Expand All @@ -171,7 +169,7 @@ def get_user(self, validated_token: Token) -> AuthUser:
JWTTokenUserAuthentication = JWTStatelessUserAuthentication


def default_user_authentication_rule(user: AuthUser) -> bool:
def default_user_authentication_rule(user: TokenUserBase) -> bool:
# Prior to Django 1.10, inactive users could be authenticated with the
# default `ModelBackend`. As of Django 1.10, the `ModelBackend`
# prevents inactive users from authenticating. App designers can still
Expand Down
11 changes: 9 additions & 2 deletions rest_framework_simplejwt/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union

from django.contrib.auth import models as auth_models
from django.db.models.manager import EmptyManager
Expand All @@ -7,10 +7,17 @@
from .settings import api_settings

if TYPE_CHECKING:
from django.contrib.auth.models import AbstractBaseUser

from .tokens import Token

TokenUserBase: TypeAlias = AbstractBaseUser

else:
TokenUserBase = object


class TokenUser:
class TokenUser(TokenUserBase):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this being introduced?

Is there a reason why TokenUser should be a subclass of AbstractBaseUser?

A dummy user class modeled after django.contrib.auth.models.AnonymousUser.
Used in conjunction with the `JWTStatelessUserAuthentication` backend to
Expand Down
30 changes: 18 additions & 12 deletions rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AbstractBaseUser, update_last_login
from django.contrib.auth.models import update_last_login
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.exceptions import AuthenticationFailed, ValidationError

from .models import TokenUser
from .models import TokenUserBase
from .settings import api_settings
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken

AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
TokenTypeVar = TypeVar("TokenTypeVar", bound=Token)

if api_settings.BLACKLIST_AFTER_ROTATION:
from .token_blacklist.models import BlacklistedToken
Expand All @@ -27,9 +27,9 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)


class TokenObtainSerializer(serializers.Serializer):
class TokenObtainSerializer(Generic[TokenTypeVar], serializers.Serializer):
username_field = get_user_model().USERNAME_FIELD
token_class: Optional[type[Token]] = None
token_class: type[TokenTypeVar]

default_error_messages = {
"no_active_account": _("No active account found with the given credentials")
Expand Down Expand Up @@ -62,39 +62,45 @@ def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
return {}

@classmethod
def get_token(cls, user: AuthUser) -> Token:
return cls.token_class.for_user(user) # type: ignore
def get_token(cls, user: TokenUserBase) -> TokenTypeVar:
return cls.token_class.for_user(user)


class TokenObtainPairSerializer(TokenObtainSerializer):
class TokenObtainPairSerializer(TokenObtainSerializer[RefreshToken]):
token_class = RefreshToken

def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
data = super().validate(attrs)

if TYPE_CHECKING:
assert self.user

refresh = self.get_token(self.user)

data["refresh"] = str(refresh)
data["access"] = str(refresh.access_token)

if api_settings.UPDATE_LAST_LOGIN:
update_last_login(None, self.user)
update_last_login(type(self.user), self.user)

return data


class TokenObtainSlidingSerializer(TokenObtainSerializer):
class TokenObtainSlidingSerializer(TokenObtainSerializer[SlidingToken]):
token_class = SlidingToken

def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
data = super().validate(attrs)

if TYPE_CHECKING:
assert self.user

token = self.get_token(self.user)

data["token"] = str(token)

if api_settings.UPDATE_LAST_LOGIN:
update_last_login(None, self.user)
update_last_login(type(self.user), self.user)

return data

Expand Down
34 changes: 18 additions & 16 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypeVar
from uuid import uuid4

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractBaseUser
from django.utils.module_loading import import_string
from django.utils.translation import gettext_lazy as _

Expand All @@ -14,7 +13,7 @@
TokenBackendExpiredToken,
TokenError,
)
from .models import TokenUser
from .models import TokenUserBase
from .settings import api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
from .utils import (
Expand All @@ -29,9 +28,12 @@
if TYPE_CHECKING:
from .backends import TokenBackend

T = TypeVar("T", bound="Token")
TokenBase: TypeAlias = "Token"
else:
TokenBase = object

AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)

T = TypeVar("T", bound=TokenBase)


class Token:
Expand Down Expand Up @@ -164,13 +166,13 @@ def set_exp(
See here:
https://tools.ietf.org/html/rfc7519#section-4.1.4
"""
if from_time is None:
from_time = self.current_time
from_time_datetime = from_time or self.current_time
lifetime_timedelta = lifetime or self.lifetime

if lifetime is None:
lifetime = self.lifetime
if TYPE_CHECKING:
assert lifetime_timedelta

self.payload[claim] = datetime_to_epoch(from_time + lifetime)
self.payload[claim] = datetime_to_epoch(from_time_datetime + lifetime_timedelta)

def set_iat(self, claim: str = "iat", at_time: Optional[datetime] = None) -> None:
"""
Expand Down Expand Up @@ -213,15 +215,15 @@ def outstand(self) -> Optional[OutstandingToken]:
return None

@classmethod
def for_user(cls: type[T], user: AuthUser) -> T:
def for_user(cls: type[T], user: TokenUserBase) -> T:
"""
Returns an authorization token for the given user that will be provided
after authenticating the user's credentials.
"""

if hasattr(user, "is_active") and not user.is_active:
logger.warning(
f"Creating token for inactive user: {user.id}. If this is not intentional, consider checking the user's status before calling the `for_user` method."
f"Creating token for inactive user: {user.pk}. If this is not intentional, consider checking the user's status before calling the `for_user` method."
)

user_id = getattr(user, api_settings.USER_ID_FIELD)
Expand Down Expand Up @@ -253,7 +255,7 @@ def get_token_backend(self) -> "TokenBackend":
return self.token_backend


class BlacklistMixin(Generic[T]):
class BlacklistMixin(TokenBase):
"""
If the `rest_framework_simplejwt.token_blacklist` app was configured to be
used, tokens created from `BlacklistMixin` subclasses will insert
Expand Down Expand Up @@ -333,7 +335,7 @@ def outstand(self) -> Optional[OutstandingToken]:
)

@classmethod
def for_user(cls: type[T], user: AuthUser) -> T:
def for_user(cls: type[T], user: TokenUserBase) -> T:
"""
Adds this token to the outstanding token list.
"""
Expand All @@ -353,7 +355,7 @@ def for_user(cls: type[T], user: AuthUser) -> T:
return token


class SlidingToken(BlacklistMixin["SlidingToken"], Token):
class SlidingToken(BlacklistMixin, Token):
token_type = "sliding"
lifetime = api_settings.SLIDING_TOKEN_LIFETIME

Expand All @@ -374,7 +376,7 @@ class AccessToken(Token):
lifetime = api_settings.ACCESS_TOKEN_LIFETIME


class RefreshToken(BlacklistMixin["RefreshToken"], Token):
class RefreshToken(BlacklistMixin, Token):
token_type = "refresh"
lifetime = api_settings.REFRESH_TOKEN_LIFETIME
no_copy_claims = (
Expand Down
4 changes: 2 additions & 2 deletions rest_framework_simplejwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def datetime_from_epoch(ts: float) -> datetime:
return dt


def format_lazy(s: str, *args, **kwargs) -> str:
def _format_lazy(s: str, *args, **kwargs) -> str:
return s.format(*args, **kwargs)


format_lazy: Callable = lazy(format_lazy, str)
format_lazy: Callable = lazy(_format_lazy, str)

logger = logging.getLogger("rest_framework_simplejwt")
Loading