Skip to content

Commit 9dc2682

Browse files
Update obtain_auth_token to authenticate using 'USERNAME_FIELD' and 'password' instead of 'username' and 'password' for both the built-in and custom User models
1 parent ac50cec commit 9dc2682

File tree

3 files changed

+42
-20
lines changed

3 files changed

+42
-20
lines changed

docs/api-guide/authentication.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ The `obtain_auth_token` view will return a JSON response when valid `username` a
220220

221221
{ 'token' : '9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b' }
222222

223-
Note that the default `obtain_auth_token` view explicitly uses JSON requests and responses, rather than using default renderer and parser classes in your settings.
223+
Note that the default `obtain_auth_token` view explicitly uses JSON requests and responses, rather than using default renderer and parser classes in your settings. If you use a `custom User` model as `AUTH_USER_MODEL` in `settings.py`, authentication will use the `USERNAME_FIELD` and `password` defined in your custom model.
224224

225225
By default, there are no permissions or throttling applied to the `obtain_auth_token` view. If you do wish to apply throttling you'll need to override the view class,
226226
and include them using the `throttle_classes` attribute.

rest_framework/authtoken/serializers.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,50 @@
1-
from django.contrib.auth import authenticate
1+
from django.contrib.auth import authenticate, get_user_model
22
from django.utils.translation import gettext_lazy as _
33

44
from rest_framework import serializers
55

6+
USER_MODEL = get_user_model()
7+
68

79
class AuthTokenSerializer(serializers.Serializer):
8-
username = serializers.CharField(
9-
label=_("Username"),
10-
write_only=True
11-
)
12-
password = serializers.CharField(
13-
label=_("Password"),
14-
style={'input_type': 'password'},
15-
trim_whitespace=False,
16-
write_only=True
17-
)
10+
def __init__(self, instance=None, data=None, **kwargs):
11+
super().__init__(instance, data=data, **kwargs)
12+
self.identifier_fiend_name = USER_MODEL.USERNAME_FIELD
13+
if USER_MODEL.get_email_field_name() == self.identifier_fiend_name:
14+
self.fields[self.identifier_fiend_name] = serializers.EmailField(
15+
label=_(self.identifier_fiend_name.title()),
16+
write_only=True
17+
)
18+
else:
19+
self.fields[self.identifier_fiend_name] = serializers.CharField(
20+
label=_(self.identifier_fiend_name.title()),
21+
write_only=True
22+
)
23+
self.fields["password"] = serializers.CharField(
24+
label=_("Password"),
25+
style={'input_type': 'password'},
26+
trim_whitespace=False,
27+
write_only=True
28+
)
29+
1830
token = serializers.CharField(
1931
label=_("Token"),
2032
read_only=True
2133
)
2234

2335
def validate(self, attrs):
24-
username = attrs.get('username')
36+
identifier_value = attrs.get(self.identifier_fiend_name)
2537
password = attrs.get('password')
2638

27-
if username and password:
28-
user = authenticate(request=self.context.get('request'),
29-
username=username, password=password)
39+
if identifier_value and password:
40+
credentials = {
41+
self.identifier_fiend_name: identifier_value,
42+
"password": password,
43+
}
44+
user = authenticate(
45+
request=self.context.get('request'),
46+
**credentials,
47+
)
3048

3149
# The authenticate call simply returns None for is_active=False
3250
# users. (Assuming the default ModelBackend authentication
@@ -35,7 +53,7 @@ def validate(self, attrs):
3553
msg = _('Unable to log in with provided credentials.')
3654
raise serializers.ValidationError(msg, code='authorization')
3755
else:
38-
msg = _('Must include "username" and "password".')
56+
msg = _(f'Must include "{self.identifier_fiend_name}" and "password".')
3957
raise serializers.ValidationError(msg, code='authorization')
4058

4159
attrs['user'] = user

rest_framework/authtoken/views.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from django.contrib.auth import get_user_model
2+
13
from rest_framework import parsers, renderers
24
from rest_framework.authtoken.models import Token
35
from rest_framework.authtoken.serializers import AuthTokenSerializer
@@ -16,15 +18,17 @@ class ObtainAuthToken(APIView):
1618
serializer_class = AuthTokenSerializer
1719

1820
if coreapi_schema.is_enabled():
21+
USER_MODEL = get_user_model()
22+
identifier_field_name = USER_MODEL.USERNAME_FIELD
1923
schema = ManualSchema(
2024
fields=[
2125
coreapi.Field(
22-
name="username",
26+
name=identifier_field_name,
2327
required=True,
2428
location='form',
2529
schema=coreschema.String(
26-
title="Username",
27-
description="Valid username for authentication",
30+
title=identifier_field_name.title(),
31+
description=f"Valid {identifier_field_name} for authentication",
2832
),
2933
),
3034
coreapi.Field(

0 commit comments

Comments
 (0)