Skip to content

Commit 27f4675

Browse files
committed
Harden revoke access token for password changes
1 parent faf92e8 commit 27f4675

File tree

5 files changed

+70
-46
lines changed

5 files changed

+70
-46
lines changed

rest_framework_simplejwt/authentication.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.contrib.auth import get_user_model
44
from django.contrib.auth.models import AbstractBaseUser
5+
from django.utils.crypto import constant_time_compare
56
from django.utils.translation import gettext_lazy as _
67
from rest_framework import HTTP_HEADER_ENCODING, authentication
78
from rest_framework.request import Request
@@ -10,7 +11,7 @@
1011
from .models import TokenUser
1112
from .settings import api_settings
1213
from .tokens import Token
13-
from .utils import get_md5_hash_password
14+
from .utils import get_fallback_token_auth_hash, get_token_auth_hash
1415

1516
AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES
1617

@@ -135,9 +136,17 @@ def get_user(self, validated_token: Token) -> AuthUser:
135136
raise AuthenticationFailed(_("User is inactive"), code="user_inactive")
136137

137138
if api_settings.CHECK_REVOKE_TOKEN:
138-
if validated_token.get(
139-
api_settings.REVOKE_TOKEN_CLAIM
140-
) != get_md5_hash_password(user.password):
139+
validation_claim = validated_token.get(api_settings.REVOKE_TOKEN_CLAIM)
140+
if (
141+
validation_claim is None
142+
or not constant_time_compare(
143+
validation_claim, get_token_auth_hash(user)
144+
)
145+
and not any(
146+
constant_time_compare(validation_claim, fallback_auth_hash)
147+
for fallback_auth_hash in get_fallback_token_auth_hash(user)
148+
)
149+
):
141150
raise AuthenticationFailed(
142151
_("The user's password has been changed."), code="password_changed"
143152
)

rest_framework_simplejwt/tokens.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
datetime_from_epoch,
1717
datetime_to_epoch,
1818
format_lazy,
19-
get_md5_hash_password,
19+
get_token_auth_hash,
2020
)
2121

2222
if TYPE_CHECKING:
@@ -208,9 +208,7 @@ def for_user(cls, user: AuthUser) -> "Token":
208208
token[api_settings.USER_ID_CLAIM] = user_id
209209

210210
if api_settings.CHECK_REVOKE_TOKEN:
211-
token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(
212-
user.password
213-
)
211+
token[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)
214212

215213
return token
216214

rest_framework_simplejwt/utils.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
1-
import hashlib
21
from calendar import timegm
32
from datetime import datetime, timezone
4-
from typing import Callable
3+
from typing import TYPE_CHECKING, Callable, TypeVar
54

65
from django.conf import settings
6+
from django.contrib.auth.models import AbstractBaseUser
7+
from django.utils.crypto import salted_hmac
78
from django.utils.functional import lazy
89
from django.utils.timezone import is_naive, make_aware
910

11+
if TYPE_CHECKING:
12+
from .models import TokenUser
1013

11-
def get_md5_hash_password(password: str) -> str:
14+
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
15+
16+
17+
def _get_token_auth_hash(user: "AuthUser", secret=None) -> str:
18+
key_salt = "rest_framework_simplejwt.utils.get_token_auth_hash"
19+
return salted_hmac(key_salt, user.password, secret=secret).hexdigest()
20+
21+
22+
def get_token_auth_hash(user: "AuthUser") -> str:
1223
"""
13-
Returns MD5 hash of the given password
24+
Return an HMAC of the given user password field.
1425
"""
15-
return hashlib.md5(password.encode()).hexdigest().upper()
26+
if hasattr(user, "get_session_auth_hash"):
27+
return user.get_session_auth_hash()
28+
return _get_token_auth_hash(user)
29+
30+
31+
def get_fallback_token_auth_hash(user: "AuthUser") -> str:
32+
"""
33+
Yields a sequence of fallback HMACs of the given user password field.
34+
"""
35+
if hasattr(user, "get_session_auth_fallback_hash"):
36+
yield from user.get_session_auth_fallback_hash()
37+
38+
fallback_keys = getattr(settings, "SECRET_KEY_FALLBACKS", [])
39+
yield from (
40+
_get_token_auth_hash(user, fallback_secret) for fallback_secret in fallback_keys
41+
)
1642

1743

1844
def make_utc(dt: datetime) -> datetime:

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ def pytest_configure():
1616
},
1717
SITE_ID=1,
1818
SECRET_KEY="not very secret in tests",
19+
SECRET_KEY_FALLBACKS=[
20+
"old not very secure secret",
21+
"other old not very secure secret",
22+
],
1923
USE_I18N=True,
2024
STATIC_URL="/static/",
2125
ROOT_URLCONF="tests.urls",

tests/test_authentication.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from rest_framework_simplejwt.models import TokenUser
1111
from rest_framework_simplejwt.settings import api_settings
1212
from rest_framework_simplejwt.tokens import AccessToken, SlidingToken
13-
from rest_framework_simplejwt.utils import get_md5_hash_password
13+
from rest_framework_simplejwt.utils import _get_token_auth_hash, get_token_auth_hash
1414

1515
from .utils import override_api_settings
1616

@@ -145,60 +145,47 @@ def test_get_user(self):
145145
with self.assertRaises(AuthenticationFailed):
146146
self.backend.get_user(payload)
147147

148-
u = User.objects.create_user(username="markhamill")
149-
u.is_active = False
150-
u.save()
148+
user = User.objects.create_user(username="markhamill", is_active=False)
151149

152-
payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)
150+
payload[api_settings.USER_ID_CLAIM] = getattr(user, api_settings.USER_ID_FIELD)
153151

154152
# Should raise exception if user is inactive
155153
with self.assertRaises(AuthenticationFailed):
156154
self.backend.get_user(payload)
157155

158-
u.is_active = True
159-
u.save()
156+
user.is_active = True
157+
user.save()
160158

161159
# Otherwise, should return correct user
162-
self.assertEqual(self.backend.get_user(payload).id, u.id)
160+
self.assertEqual(self.backend.get_user(payload).id, user.id)
163161

164162
@override_api_settings(
165163
CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim"
166164
)
167165
def test_get_user_with_check_revoke_token(self):
168-
payload = {"some_other_id": "foo"}
169-
170-
# Should raise error if no recognizable user identification
171-
with self.assertRaises(InvalidToken):
172-
self.backend.get_user(payload)
173-
174-
payload[api_settings.USER_ID_CLAIM] = 42
175-
176-
# Should raise exception if user not found
177-
with self.assertRaises(AuthenticationFailed):
178-
self.backend.get_user(payload)
179-
180-
u = User.objects.create_user(username="markhamill")
181-
u.is_active = False
182-
u.save()
166+
user = User.objects.create_user(username="markhamill")
167+
payload = {
168+
api_settings.USER_ID_CLAIM: getattr(user, api_settings.USER_ID_FIELD)
169+
}
183170

184-
payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)
185-
186-
# Should raise exception if user is inactive
171+
# Should raise exception if claim is missing
187172
with self.assertRaises(AuthenticationFailed):
188173
self.backend.get_user(payload)
189174

190-
u.is_active = True
191-
u.save()
192-
193-
# Should raise exception if hash password is different
175+
payload[api_settings.REVOKE_TOKEN_CLAIM] = "differenthash"
176+
# Should raise exception if claim is different
194177
with self.assertRaises(AuthenticationFailed):
195178
self.backend.get_user(payload)
196179

197-
if api_settings.CHECK_REVOKE_TOKEN:
198-
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password)
180+
payload[api_settings.REVOKE_TOKEN_CLAIM] = _get_token_auth_hash(
181+
user, "other old not very secure secret"
182+
)
183+
# Should return correct user if claim was signed with an old key
184+
self.assertEqual(self.backend.get_user(payload).id, user.id)
199185

186+
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)
200187
# Otherwise, should return correct user
201-
self.assertEqual(self.backend.get_user(payload).id, u.id)
188+
self.assertEqual(self.backend.get_user(payload).id, user.id)
202189

203190

204191
class TestJWTStatelessUserAuthentication(TestCase):

0 commit comments

Comments
 (0)