Skip to content

Commit 3ab412d

Browse files
committed
fix: Various typing improvements
1 parent acacec8 commit 3ab412d

File tree

6 files changed

+56
-40
lines changed

6 files changed

+56
-40
lines changed

rest_framework_simplejwt/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from importlib.metadata import PackageNotFoundError, version
2+
from typing import Union
3+
4+
__version__: Union[str, None]
25

36
try:
47
__version__ = version("djangorestframework_simplejwt")

rest_framework_simplejwt/authentication.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, TypeVar
1+
from typing import Optional
22

33
from django.contrib.auth import get_user_model
44
from django.contrib.auth.models import AbstractBaseUser
@@ -7,7 +7,7 @@
77
from rest_framework.request import Request
88

99
from .exceptions import AuthenticationFailed, InvalidToken, TokenError
10-
from .models import TokenUser
10+
from .models import TokenUserBase
1111
from .settings import api_settings
1212
from .tokens import Token
1313
from .utils import get_md5_hash_password
@@ -21,8 +21,6 @@
2121
h.encode(HTTP_HEADER_ENCODING) for h in AUTH_HEADER_TYPES
2222
}
2323

24-
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
25-
2624

2725
class JWTAuthentication(authentication.BaseAuthentication):
2826
"""
@@ -37,7 +35,7 @@ def __init__(self, *args, **kwargs) -> None:
3735
super().__init__(*args, **kwargs)
3836
self.user_model = get_user_model()
3937

40-
def authenticate(self, request: Request) -> Optional[tuple[AuthUser, Token]]:
38+
def authenticate(self, request: Request) -> Optional[tuple[TokenUserBase, Token]]:
4139
header = self.get_header(request)
4240
if header is None:
4341
return None
@@ -117,7 +115,7 @@ def get_validated_token(self, raw_token: bytes) -> Token:
117115
}
118116
)
119117

120-
def get_user(self, validated_token: Token) -> AuthUser:
118+
def get_user(self, validated_token: Token) -> AbstractBaseUser:
121119
"""
122120
Attempts to find and return a user using the given validated token.
123121
"""
@@ -155,7 +153,7 @@ class JWTStatelessUserAuthentication(JWTAuthentication):
155153
token provided in a request header without performing a database lookup to obtain a user instance.
156154
"""
157155

158-
def get_user(self, validated_token: Token) -> AuthUser:
156+
def get_user(self, validated_token: Token) -> AbstractBaseUser:
159157
"""
160158
Returns a stateless user object which is backed by the given validated
161159
token.
@@ -171,7 +169,7 @@ def get_user(self, validated_token: Token) -> AuthUser:
171169
JWTTokenUserAuthentication = JWTStatelessUserAuthentication
172170

173171

174-
def default_user_authentication_rule(user: AuthUser) -> bool:
172+
def default_user_authentication_rule(user: TokenUserBase) -> bool:
175173
# Prior to Django 1.10, inactive users could be authenticated with the
176174
# default `ModelBackend`. As of Django 1.10, the `ModelBackend`
177175
# prevents inactive users from authenticating. App designers can still

rest_framework_simplejwt/models.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Union
1+
from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union
22

33
from django.contrib.auth import models as auth_models
44
from django.db.models.manager import EmptyManager
@@ -7,10 +7,17 @@
77
from .settings import api_settings
88

99
if TYPE_CHECKING:
10+
from django.contrib.auth.models import AbstractBaseUser
11+
1012
from .tokens import Token
1113

14+
TokenUserBase: TypeAlias = AbstractBaseUser
15+
16+
else:
17+
TokenUserBase = object
18+
1219

13-
class TokenUser:
20+
class TokenUser(TokenUserBase):
1421
"""
1522
A dummy user class modeled after django.contrib.auth.models.AnonymousUser.
1623
Used in conjunction with the `JWTStatelessUserAuthentication` backend to

rest_framework_simplejwt/serializers.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import Any, Optional, TypeVar
1+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
22

33
from django.conf import settings
44
from django.contrib.auth import authenticate, get_user_model
5-
from django.contrib.auth.models import AbstractBaseUser, update_last_login
5+
from django.contrib.auth.models import update_last_login
66
from django.utils.translation import gettext_lazy as _
77
from rest_framework import exceptions, serializers
88
from rest_framework.exceptions import AuthenticationFailed, ValidationError
99

10-
from .models import TokenUser
10+
from .models import TokenUserBase
1111
from .settings import api_settings
1212
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken
1313

14-
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
14+
TokenTypeVar = TypeVar("TokenTypeVar", bound=Token)
1515

1616
if api_settings.BLACKLIST_AFTER_ROTATION:
1717
from .token_blacklist.models import BlacklistedToken
@@ -27,9 +27,9 @@ def __init__(self, *args, **kwargs) -> None:
2727
super().__init__(*args, **kwargs)
2828

2929

30-
class TokenObtainSerializer(serializers.Serializer):
30+
class TokenObtainSerializer(Generic[TokenTypeVar], serializers.Serializer):
3131
username_field = get_user_model().USERNAME_FIELD
32-
token_class: Optional[type[Token]] = None
32+
token_class: type[TokenTypeVar]
3333

3434
default_error_messages = {
3535
"no_active_account": _("No active account found with the given credentials")
@@ -62,39 +62,45 @@ def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
6262
return {}
6363

6464
@classmethod
65-
def get_token(cls, user: AuthUser) -> Token:
66-
return cls.token_class.for_user(user) # type: ignore
65+
def get_token(cls, user: TokenUserBase) -> TokenTypeVar:
66+
return cls.token_class.for_user(user)
6767

6868

69-
class TokenObtainPairSerializer(TokenObtainSerializer):
69+
class TokenObtainPairSerializer(TokenObtainSerializer[RefreshToken]):
7070
token_class = RefreshToken
7171

7272
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
7373
data = super().validate(attrs)
7474

75+
if TYPE_CHECKING:
76+
assert self.user
77+
7578
refresh = self.get_token(self.user)
7679

7780
data["refresh"] = str(refresh)
7881
data["access"] = str(refresh.access_token)
7982

8083
if api_settings.UPDATE_LAST_LOGIN:
81-
update_last_login(None, self.user)
84+
update_last_login(type(self.user), self.user)
8285

8386
return data
8487

8588

86-
class TokenObtainSlidingSerializer(TokenObtainSerializer):
89+
class TokenObtainSlidingSerializer(TokenObtainSerializer[SlidingToken]):
8790
token_class = SlidingToken
8891

8992
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
9093
data = super().validate(attrs)
9194

95+
if TYPE_CHECKING:
96+
assert self.user
97+
9298
token = self.get_token(self.user)
9399

94100
data["token"] = str(token)
95101

96102
if api_settings.UPDATE_LAST_LOGIN:
97-
update_last_login(None, self.user)
103+
update_last_login(type(self.user), self.user)
98104

99105
return data
100106

rest_framework_simplejwt/tokens.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from datetime import datetime, timedelta
2-
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
2+
from typing import TYPE_CHECKING, Any, Optional, TypeAlias, TypeVar
33
from uuid import uuid4
44

55
from django.conf import settings
66
from django.contrib.auth import get_user_model
7-
from django.contrib.auth.models import AbstractBaseUser
87
from django.utils.module_loading import import_string
98
from django.utils.translation import gettext_lazy as _
109

@@ -14,7 +13,7 @@
1413
TokenBackendExpiredToken,
1514
TokenError,
1615
)
17-
from .models import TokenUser
16+
from .models import TokenUserBase
1817
from .settings import api_settings
1918
from .token_blacklist.models import BlacklistedToken, OutstandingToken
2019
from .utils import (
@@ -29,9 +28,12 @@
2928
if TYPE_CHECKING:
3029
from .backends import TokenBackend
3130

32-
T = TypeVar("T", bound="Token")
31+
TokenBase: TypeAlias = "Token"
32+
else:
33+
TokenBase = object
3334

34-
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
35+
36+
T = TypeVar("T", bound=TokenBase)
3537

3638

3739
class Token:
@@ -164,13 +166,13 @@ def set_exp(
164166
See here:
165167
https://tools.ietf.org/html/rfc7519#section-4.1.4
166168
"""
167-
if from_time is None:
168-
from_time = self.current_time
169+
from_time_datetime = from_time or self.current_time
170+
lifetime_timedelta = lifetime or self.lifetime
169171

170-
if lifetime is None:
171-
lifetime = self.lifetime
172+
if TYPE_CHECKING:
173+
assert lifetime_timedelta
172174

173-
self.payload[claim] = datetime_to_epoch(from_time + lifetime)
175+
self.payload[claim] = datetime_to_epoch(from_time_datetime + lifetime_timedelta)
174176

175177
def set_iat(self, claim: str = "iat", at_time: Optional[datetime] = None) -> None:
176178
"""
@@ -213,15 +215,15 @@ def outstand(self) -> Optional[OutstandingToken]:
213215
return None
214216

215217
@classmethod
216-
def for_user(cls: type[T], user: AuthUser) -> T:
218+
def for_user(cls: type[T], user: TokenUserBase) -> T:
217219
"""
218220
Returns an authorization token for the given user that will be provided
219221
after authenticating the user's credentials.
220222
"""
221223

222224
if hasattr(user, "is_active") and not user.is_active:
223225
logger.warning(
224-
f"Creating token for inactive user: {user.id}. If this is not intentional, consider checking the user's status before calling the `for_user` method."
226+
f"Creating token for inactive user: {user.pk}. If this is not intentional, consider checking the user's status before calling the `for_user` method."
225227
)
226228

227229
user_id = getattr(user, api_settings.USER_ID_FIELD)
@@ -253,7 +255,7 @@ def get_token_backend(self) -> "TokenBackend":
253255
return self.token_backend
254256

255257

256-
class BlacklistMixin(Generic[T]):
258+
class BlacklistMixin(TokenBase):
257259
"""
258260
If the `rest_framework_simplejwt.token_blacklist` app was configured to be
259261
used, tokens created from `BlacklistMixin` subclasses will insert
@@ -333,7 +335,7 @@ def outstand(self) -> Optional[OutstandingToken]:
333335
)
334336

335337
@classmethod
336-
def for_user(cls: type[T], user: AuthUser) -> T:
338+
def for_user(cls: type[T], user: TokenUserBase) -> T:
337339
"""
338340
Adds this token to the outstanding token list.
339341
"""
@@ -353,7 +355,7 @@ def for_user(cls: type[T], user: AuthUser) -> T:
353355
return token
354356

355357

356-
class SlidingToken(BlacklistMixin["SlidingToken"], Token):
358+
class SlidingToken(BlacklistMixin, Token):
357359
token_type = "sliding"
358360
lifetime = api_settings.SLIDING_TOKEN_LIFETIME
359361

@@ -374,7 +376,7 @@ class AccessToken(Token):
374376
lifetime = api_settings.ACCESS_TOKEN_LIFETIME
375377

376378

377-
class RefreshToken(BlacklistMixin["RefreshToken"], Token):
379+
class RefreshToken(BlacklistMixin, Token):
378380
token_type = "refresh"
379381
lifetime = api_settings.REFRESH_TOKEN_LIFETIME
380382
no_copy_claims = (

rest_framework_simplejwt/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def datetime_from_epoch(ts: float) -> datetime:
4242
return dt
4343

4444

45-
def format_lazy(s: str, *args, **kwargs) -> str:
45+
def _format_lazy(s: str, *args, **kwargs) -> str:
4646
return s.format(*args, **kwargs)
4747

4848

49-
format_lazy: Callable = lazy(format_lazy, str)
49+
format_lazy: Callable = lazy(_format_lazy, str)
5050

5151
logger = logging.getLogger("rest_framework_simplejwt")

0 commit comments

Comments
 (0)