diff --git a/rest_framework_simplejwt/__init__.py b/rest_framework_simplejwt/__init__.py index b001cf5d3..0a8e9d57f 100644 --- a/rest_framework_simplejwt/__init__.py +++ b/rest_framework_simplejwt/__init__.py @@ -1,4 +1,7 @@ from importlib.metadata import PackageNotFoundError, version +from typing import Union + +__version__: Union[str, None] try: __version__ = version("djangorestframework_simplejwt") diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index 239d242a2..3da2f8625 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeVar +from typing import Optional from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractBaseUser @@ -7,7 +7,7 @@ from rest_framework.request import Request from .exceptions import AuthenticationFailed, InvalidToken, TokenError -from .models import TokenUser +from .models import TokenUserBase from .settings import api_settings from .tokens import Token from .utils import get_md5_hash_password @@ -21,8 +21,6 @@ h.encode(HTTP_HEADER_ENCODING) for h in AUTH_HEADER_TYPES } -AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) - class JWTAuthentication(authentication.BaseAuthentication): """ @@ -37,7 +35,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.user_model = get_user_model() - def authenticate(self, request: Request) -> Optional[tuple[AuthUser, Token]]: + def authenticate(self, request: Request) -> Optional[tuple[TokenUserBase, Token]]: header = self.get_header(request) if header is None: return None @@ -117,7 +115,7 @@ def get_validated_token(self, raw_token: bytes) -> Token: } ) - def get_user(self, validated_token: Token) -> AuthUser: + def get_user(self, validated_token: Token) -> AbstractBaseUser: """ Attempts to find and return a user using the given validated token. """ @@ -155,7 +153,7 @@ class JWTStatelessUserAuthentication(JWTAuthentication): token provided in a request header without performing a database lookup to obtain a user instance. """ - def get_user(self, validated_token: Token) -> AuthUser: + def get_user(self, validated_token: Token) -> AbstractBaseUser: """ Returns a stateless user object which is backed by the given validated token. @@ -171,7 +169,7 @@ def get_user(self, validated_token: Token) -> AuthUser: JWTTokenUserAuthentication = JWTStatelessUserAuthentication -def default_user_authentication_rule(user: AuthUser) -> bool: +def default_user_authentication_rule(user: TokenUserBase) -> bool: # Prior to Django 1.10, inactive users could be authenticated with the # default `ModelBackend`. As of Django 1.10, the `ModelBackend` # prevents inactive users from authenticating. App designers can still diff --git a/rest_framework_simplejwt/models.py b/rest_framework_simplejwt/models.py index a0e2c8345..872898aa1 100644 --- a/rest_framework_simplejwt/models.py +++ b/rest_framework_simplejwt/models.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union from django.contrib.auth import models as auth_models from django.db.models.manager import EmptyManager @@ -7,10 +7,17 @@ from .settings import api_settings if TYPE_CHECKING: + from django.contrib.auth.models import AbstractBaseUser + from .tokens import Token + TokenUserBase: TypeAlias = AbstractBaseUser + +else: + TokenUserBase = object + -class TokenUser: +class TokenUser(TokenUserBase): """ A dummy user class modeled after django.contrib.auth.models.AnonymousUser. Used in conjunction with the `JWTStatelessUserAuthentication` backend to diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index 45c5a771c..8cd8a625f 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -1,17 +1,17 @@ -from typing import Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from django.conf import settings from django.contrib.auth import authenticate, get_user_model -from django.contrib.auth.models import AbstractBaseUser, update_last_login +from django.contrib.auth.models import update_last_login from django.utils.translation import gettext_lazy as _ from rest_framework import exceptions, serializers from rest_framework.exceptions import AuthenticationFailed, ValidationError -from .models import TokenUser +from .models import TokenUserBase from .settings import api_settings from .tokens import RefreshToken, SlidingToken, Token, UntypedToken -AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) +TokenTypeVar = TypeVar("TokenTypeVar", bound=Token) if api_settings.BLACKLIST_AFTER_ROTATION: from .token_blacklist.models import BlacklistedToken @@ -27,9 +27,9 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) -class TokenObtainSerializer(serializers.Serializer): +class TokenObtainSerializer(Generic[TokenTypeVar], serializers.Serializer): username_field = get_user_model().USERNAME_FIELD - token_class: Optional[type[Token]] = None + token_class: type[TokenTypeVar] default_error_messages = { "no_active_account": _("No active account found with the given credentials") @@ -62,39 +62,45 @@ def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]: return {} @classmethod - def get_token(cls, user: AuthUser) -> Token: - return cls.token_class.for_user(user) # type: ignore + def get_token(cls, user: TokenUserBase) -> TokenTypeVar: + return cls.token_class.for_user(user) -class TokenObtainPairSerializer(TokenObtainSerializer): +class TokenObtainPairSerializer(TokenObtainSerializer[RefreshToken]): token_class = RefreshToken def validate(self, attrs: dict[str, Any]) -> dict[str, str]: data = super().validate(attrs) + if TYPE_CHECKING: + assert self.user + refresh = self.get_token(self.user) data["refresh"] = str(refresh) data["access"] = str(refresh.access_token) if api_settings.UPDATE_LAST_LOGIN: - update_last_login(None, self.user) + update_last_login(type(self.user), self.user) return data -class TokenObtainSlidingSerializer(TokenObtainSerializer): +class TokenObtainSlidingSerializer(TokenObtainSerializer[SlidingToken]): token_class = SlidingToken def validate(self, attrs: dict[str, Any]) -> dict[str, str]: data = super().validate(attrs) + if TYPE_CHECKING: + assert self.user + token = self.get_token(self.user) data["token"] = str(token) if api_settings.UPDATE_LAST_LOGIN: - update_last_login(None, self.user) + update_last_login(type(self.user), self.user) return data diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index d97b3de96..d16119639 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -1,10 +1,9 @@ from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypeVar from uuid import uuid4 from django.conf import settings from django.contrib.auth import get_user_model -from django.contrib.auth.models import AbstractBaseUser from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ @@ -14,7 +13,7 @@ TokenBackendExpiredToken, TokenError, ) -from .models import TokenUser +from .models import TokenUserBase from .settings import api_settings from .token_blacklist.models import BlacklistedToken, OutstandingToken from .utils import ( @@ -29,9 +28,12 @@ if TYPE_CHECKING: from .backends import TokenBackend -T = TypeVar("T", bound="Token") + TokenBase: TypeAlias = "Token" +else: + TokenBase = object -AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) + +T = TypeVar("T", bound=TokenBase) class Token: @@ -164,13 +166,13 @@ def set_exp( See here: https://tools.ietf.org/html/rfc7519#section-4.1.4 """ - if from_time is None: - from_time = self.current_time + from_time_datetime = from_time or self.current_time + lifetime_timedelta = lifetime or self.lifetime - if lifetime is None: - lifetime = self.lifetime + if TYPE_CHECKING: + assert lifetime_timedelta - self.payload[claim] = datetime_to_epoch(from_time + lifetime) + self.payload[claim] = datetime_to_epoch(from_time_datetime + lifetime_timedelta) def set_iat(self, claim: str = "iat", at_time: Optional[datetime] = None) -> None: """ @@ -213,7 +215,7 @@ def outstand(self) -> Optional[OutstandingToken]: return None @classmethod - def for_user(cls: type[T], user: AuthUser) -> T: + def for_user(cls: type[T], user: TokenUserBase) -> T: """ Returns an authorization token for the given user that will be provided after authenticating the user's credentials. @@ -221,7 +223,7 @@ def for_user(cls: type[T], user: AuthUser) -> T: if hasattr(user, "is_active") and not user.is_active: logger.warning( - f"Creating token for inactive user: {user.id}. If this is not intentional, consider checking the user's status before calling the `for_user` method." + f"Creating token for inactive user: {user.pk}. If this is not intentional, consider checking the user's status before calling the `for_user` method." ) user_id = getattr(user, api_settings.USER_ID_FIELD) @@ -253,7 +255,7 @@ def get_token_backend(self) -> "TokenBackend": return self.token_backend -class BlacklistMixin(Generic[T]): +class BlacklistMixin(TokenBase): """ If the `rest_framework_simplejwt.token_blacklist` app was configured to be used, tokens created from `BlacklistMixin` subclasses will insert @@ -333,7 +335,7 @@ def outstand(self) -> Optional[OutstandingToken]: ) @classmethod - def for_user(cls: type[T], user: AuthUser) -> T: + def for_user(cls: type[T], user: TokenUserBase) -> T: """ Adds this token to the outstanding token list. """ @@ -353,7 +355,7 @@ def for_user(cls: type[T], user: AuthUser) -> T: return token -class SlidingToken(BlacklistMixin["SlidingToken"], Token): +class SlidingToken(BlacklistMixin, Token): token_type = "sliding" lifetime = api_settings.SLIDING_TOKEN_LIFETIME @@ -374,7 +376,7 @@ class AccessToken(Token): lifetime = api_settings.ACCESS_TOKEN_LIFETIME -class RefreshToken(BlacklistMixin["RefreshToken"], Token): +class RefreshToken(BlacklistMixin, Token): token_type = "refresh" lifetime = api_settings.REFRESH_TOKEN_LIFETIME no_copy_claims = ( diff --git a/rest_framework_simplejwt/utils.py b/rest_framework_simplejwt/utils.py index 202f12e9e..d6e11eff4 100644 --- a/rest_framework_simplejwt/utils.py +++ b/rest_framework_simplejwt/utils.py @@ -42,10 +42,10 @@ def datetime_from_epoch(ts: float) -> datetime: return dt -def format_lazy(s: str, *args, **kwargs) -> str: +def _format_lazy(s: str, *args, **kwargs) -> str: return s.format(*args, **kwargs) -format_lazy: Callable = lazy(format_lazy, str) +format_lazy: Callable = lazy(_format_lazy, str) logger = logging.getLogger("rest_framework_simplejwt")