Skip to content

Commit ddab703

Browse files
authored
Merge pull request #74 from akx/drf-field
Django REST Framework support
2 parents e9ec7fd + 72c2b11 commit ddab703

File tree

5 files changed

+155
-0
lines changed

5 files changed

+155
-0
lines changed

enumfields/drf/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .fields import EnumField
2+
from .serializers import EnumSupportSerializerMixin

enumfields/drf/fields.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from django.utils import six
2+
from django.utils.encoding import force_text
3+
from rest_framework.fields import ChoiceField
4+
5+
6+
class EnumField(ChoiceField):
7+
def __init__(self, enum, lenient=False, ints_as_names=False, **kwargs):
8+
"""
9+
:param enum: The enumeration class.
10+
:param lenient: Whether to allow lenient parsing (case-insensitive, by value or name)
11+
:type lenient: bool
12+
:param ints_as_names: Whether to serialize integer-valued enums by their name, not the integer value
13+
:type ints_as_names: bool
14+
"""
15+
self.enum = enum
16+
self.lenient = lenient
17+
self.ints_as_names = ints_as_names
18+
kwargs['choices'] = tuple((e.value, getattr(e, 'label', e.name)) for e in self.enum)
19+
super(EnumField, self).__init__(**kwargs)
20+
21+
def to_representation(self, instance):
22+
if instance in ('', u'', None):
23+
return instance
24+
try:
25+
if not isinstance(instance, self.enum):
26+
instance = self.enum(instance) # Try to cast it
27+
if self.ints_as_names and isinstance(instance.value, six.integer_types):
28+
# If the enum value is an int, assume the name is more representative
29+
return instance.name.lower()
30+
return instance.value
31+
except ValueError:
32+
raise ValueError('Invalid value [%r] of enum %s' % (instance, self.enum.__name__))
33+
34+
def to_internal_value(self, data):
35+
if isinstance(data, self.enum):
36+
return data
37+
try:
38+
# Convert the value using the same mechanism DRF uses
39+
converted_value = self.choice_strings_to_values[six.text_type(data)]
40+
return self.enum(converted_value)
41+
except (ValueError, KeyError):
42+
pass
43+
44+
if self.lenient:
45+
# Normal logic:
46+
for choice in self.enum:
47+
if choice.name == data or choice.value == data:
48+
return choice
49+
50+
# Case-insensitive logic:
51+
l_data = force_text(data).lower()
52+
for choice in self.enum:
53+
if choice.name.lower() == l_data or force_text(choice.value).lower() == l_data:
54+
return choice
55+
56+
# Fallback (will likely just raise):
57+
return super(EnumField, self).to_internal_value(data)

enumfields/drf/serializers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from rest_framework.fields import ChoiceField
2+
3+
from enumfields.drf.fields import EnumField as EnumSerializerField
4+
from enumfields.fields import EnumFieldMixin
5+
6+
7+
class EnumSupportSerializerMixin(object):
8+
enumfield_options = {}
9+
10+
def build_standard_field(self, field_name, model_field):
11+
field_class, field_kwargs = (
12+
super(EnumSupportSerializerMixin, self).build_standard_field(field_name, model_field)
13+
)
14+
if field_class == ChoiceField and isinstance(model_field, EnumFieldMixin):
15+
field_class = EnumSerializerField
16+
field_kwargs['enum'] = model_field.enum
17+
field_kwargs.update(self.enumfield_options)
18+
return field_class, field_kwargs

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def run_tests(self):
5555
tests_require=[
5656
'pytest-django',
5757
'Django',
58+
'djangorestframework'
5859
],
5960
extras_require={
6061
":python_version<'3.4'": ['enum34'],

tests/test_serializers.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# -- encoding: UTF-8 --
2+
3+
import uuid
4+
5+
import pytest
6+
from rest_framework import serializers
7+
8+
from enumfields.drf.serializers import EnumSupportSerializerMixin
9+
10+
from .enums import Color, IntegerEnum, Taste
11+
from .models import MyModel
12+
13+
14+
class MySerializer(EnumSupportSerializerMixin, serializers.ModelSerializer):
15+
class Meta:
16+
model = MyModel
17+
fields = '__all__'
18+
19+
20+
class LenientIntNameSerializer(MySerializer):
21+
enumfield_options = {
22+
'lenient': True,
23+
'ints_as_names': True,
24+
}
25+
26+
27+
@pytest.mark.parametrize('int_names', (False, True))
28+
def test_serialize(int_names):
29+
inst = MyModel(color=Color.BLUE, taste=Taste.UMAMI, int_enum=IntegerEnum.B)
30+
data = (LenientIntNameSerializer if int_names else MySerializer)(inst).data
31+
assert data['color'] == Color.BLUE.value
32+
if int_names:
33+
assert data['taste'] == 'umami'
34+
assert data['int_enum'] == 'b'
35+
else:
36+
assert data['taste'] == Taste.UMAMI.value
37+
assert data['int_enum'] == IntegerEnum.B.value
38+
39+
40+
@pytest.mark.django_db
41+
@pytest.mark.parametrize('lenient_serializer', (False, True))
42+
@pytest.mark.parametrize('lenient_data', (False, True))
43+
def test_deserialize(lenient_data, lenient_serializer):
44+
secret_uuid = str(uuid.uuid4())
45+
data = {
46+
'color': Color.BLUE.value,
47+
'taste': Taste.UMAMI.value,
48+
'int_enum': IntegerEnum.B.value,
49+
'random_code': secret_uuid,
50+
}
51+
if lenient_data:
52+
data.update({
53+
'color': 'b',
54+
'taste': 'Umami',
55+
'int_enum': 'B',
56+
})
57+
serializer_cls = (LenientIntNameSerializer if lenient_serializer else MySerializer)
58+
serializer = serializer_cls(data=data)
59+
if lenient_data and not lenient_serializer:
60+
assert not serializer.is_valid()
61+
return
62+
assert serializer.is_valid(), serializer.errors
63+
64+
validated_data = serializer.validated_data
65+
assert validated_data['color'] == Color.BLUE
66+
assert validated_data['taste'] == Taste.UMAMI
67+
assert validated_data['int_enum'] == IntegerEnum.B
68+
69+
inst = serializer.save()
70+
assert inst.color == Color.BLUE
71+
assert inst.taste == Taste.UMAMI
72+
assert inst.int_enum == IntegerEnum.B
73+
74+
inst = MyModel.objects.get(random_code=secret_uuid) # will raise if fails
75+
assert inst.color == Color.BLUE
76+
assert inst.taste == Taste.UMAMI
77+
assert inst.int_enum == IntegerEnum.B

0 commit comments

Comments
 (0)