From 49b84ed8f04fdd59d3638092cde0aed0f41a2f2a Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Wed, 8 Jan 2025 09:39:41 -0700 Subject: [PATCH] Fix user_id type mismatch when user claim is not pk Regarding changes made at https://github.com/jazzband/djangorestframework-simplejwt/pull/806/files We're using a USER_ID_CLAIM that is neither the primary key field nor is it the same type as the primary key, and these previous changes fail at this point when attempting to create an OutstandingToken, because it assumes that the ID pulled out of the token claims is usable as the database key for a user. So to mitigate this gets the user from the database using the USER_ID_FIELD setting and uses that in the get_or_create call. Also include a test of handling the case where the user is deleted when the token is blacklisted. --- rest_framework_simplejwt/tokens.py | 8 +++++++- tests/test_token_blacklist.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 9e9c3b9df..744f04fb1 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -3,6 +3,7 @@ from uuid import uuid4 from django.conf import settings +from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractBaseUser from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ @@ -266,12 +267,17 @@ def blacklist(self) -> BlacklistedToken: jti = self.payload[api_settings.JTI_CLAIM] exp = self.payload["exp"] user_id = self.payload.get(api_settings.USER_ID_CLAIM) + User = get_user_model() + try: + user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id}) + except User.DoesNotExist: + user = None # Ensure outstanding token exists with given jti token, _ = OutstandingToken.objects.get_or_create( jti=jti, defaults={ - "user_id": user_id, + "user": user, "created_at": self.current_time, "token": str(self), "expires_at": datetime_from_epoch(exp), diff --git a/tests/test_token_blacklist.py b/tests/test_token_blacklist.py index fc45adf2e..5fc60f29d 100644 --- a/tests/test_token_blacklist.py +++ b/tests/test_token_blacklist.py @@ -6,7 +6,6 @@ from django.db.models import BigAutoField from django.test import TestCase from django.utils import timezone - from rest_framework_simplejwt.exceptions import TokenError from rest_framework_simplejwt.serializers import TokenVerifySerializer from rest_framework_simplejwt.settings import api_settings @@ -160,6 +159,35 @@ def test_outstanding_token_and_blacklisted_token_user(self): outstanding_token = OutstandingToken.objects.get(token=token) self.assertEqual(outstanding_token.user, self.user) + @override_api_settings(USER_ID_FIELD="email", USER_ID_CLAIM="email") + def test_outstanding_token_and_blacklisted_token_created_at_with_modified_user_id_field( + self, + ): + token = RefreshToken.for_user(self.user) + + token.blacklist() + outstanding_token = OutstandingToken.objects.get(token=token) + self.assertEqual(outstanding_token.created_at, token.current_time) + + @override_api_settings(USER_ID_FIELD="email", USER_ID_CLAIM="email") + def test_outstanding_token_and_blacklisted_token_user_with_modifed_user_id_field( + self, + ): + token = RefreshToken.for_user(self.user) + + token.blacklist() + outstanding_token = OutstandingToken.objects.get(token=token) + self.assertEqual(outstanding_token.user, self.user) + + + @override_api_settings(USER_ID_FIELD="email", USER_ID_CLAIM="email") + def test_outstanding_token_with_deleted_user_and_modifed_user_id_field(self): + self.assertFalse(BlacklistedToken.objects.exists()) + token = RefreshToken.for_user(self.user) + self.user.delete() + token.blacklist() + self.assertTrue(BlacklistedToken.objects.count(), 1) + class TestTokenBlacklistFlushExpiredTokens(TestCase): def setUp(self):