Skip to content

Commit c6806b5

Browse files
authored
Add Enum support (#10)
* Add ENUM capability * Add additional testing * Clean up EnumField and update tests * Add type hinting to EnumField, remove patch from test as it's not needed * Add return type hinting
1 parent 43f93ae commit c6806b5

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

src/drf_pydantic/fields.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from enum import Enum
2+
from typing import Type, Optional, Union
3+
4+
from rest_framework.fields import empty
5+
from rest_framework.serializers import ChoiceField
6+
7+
8+
class EnumField(ChoiceField):
9+
"""
10+
Custom DRF field that restricts accepted values to that of a defined enum
11+
"""
12+
13+
default_error_messages = {"invalid": "No matching enum type"}
14+
15+
def __init__(self, enum: Type[Enum], **kwargs):
16+
self.enum = enum
17+
kwargs.setdefault("choices", [(x, x.name) for x in self.enum])
18+
super().__init__(**kwargs)
19+
20+
def run_validation(
21+
self, data: Optional[Union[Enum, str, empty]] = empty
22+
) -> Optional[Enum]:
23+
if data and data != empty and not isinstance(data, self.enum):
24+
match_found = False
25+
for x in self.enum:
26+
if x.value == data:
27+
match_found = True
28+
break
29+
30+
if not match_found:
31+
self.fail("invalid")
32+
33+
return super().run_validation(data)
34+
35+
def to_internal_value(self, data: Optional[Union[Enum, str]]) -> Enum:
36+
for choice in self.enum:
37+
if choice == data or choice.name == data or choice.value == data:
38+
return choice
39+
self.fail("invalid")
40+
41+
def to_representation(self, value: Optional[Union[Enum, str]]) -> Optional[str]:
42+
if isinstance(value, self.enum):
43+
return value.value
44+
45+
return value

src/drf_pydantic/parse.py

+9
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import typing
66
import uuid
77
import warnings
8+
from enum import Enum
89

910
import pydantic
1011

1112
from rest_framework import serializers
13+
from drf_pydantic.fields import EnumField
1214

1315
# Cache serializer classes to ensure that there is a one-to-one relationship
1416
# between pydantic models and serializer classes
@@ -38,6 +40,8 @@
3840
# Constraint fields
3941
pydantic.ConstrainedStr: serializers.CharField,
4042
pydantic.ConstrainedInt: serializers.IntegerField,
43+
# Enum fields
44+
Enum: EnumField
4145
}
4246

4347

@@ -122,6 +126,11 @@ def _convert_field(field: pydantic.fields.ModelField) -> serializers.Field:
122126
extra_kwargs["min_length"] = field.type_.min_length
123127
extra_kwargs["max_length"] = field.type_.max_length
124128

129+
if inspect.isclass(field.type_) and issubclass(
130+
field.type_, Enum
131+
):
132+
extra_kwargs['enum'] = field.type_
133+
125134
# Scalar field
126135
if field.outer_type_ is field.type_:
127136
# Normal class

tests/test_models.py

+101
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import datetime
22
import typing
3+
from enum import Enum
4+
from unittest.mock import patch
35

46
import pydantic
57
import pytest
68

79
from rest_framework import serializers
10+
from rest_framework.exceptions import ValidationError
811

912
from drf_pydantic import BaseModel
13+
from drf_pydantic.fields import EnumField
1014

1115

1216
def test_simple_model():
@@ -280,3 +284,100 @@ class Cart(BaseModel):
280284

281285
name_field: serializers.Field = items_field.child.fields["name"]
282286
assert isinstance(name_field, serializers.CharField)
287+
288+
289+
def test_enum_model():
290+
class CountryEnum(Enum):
291+
US = 'US'
292+
GB = 'GB'
293+
FR = 'FR'
294+
295+
class NotificationPreferenceEnum(Enum):
296+
NONE = 'no_notifications'
297+
SOME = 'some_notifications'
298+
ALL = 'all_notifications'
299+
300+
class Person(BaseModel):
301+
name: str
302+
email: pydantic.EmailStr
303+
age: int
304+
height: float
305+
date_of_birth: datetime.date
306+
notification_preferences: NotificationPreferenceEnum
307+
original_nationality: typing.Optional[CountryEnum]
308+
nationality: CountryEnum = CountryEnum.GB
309+
310+
serializer = Person.drf_serializer()
311+
312+
assert serializer.__class__.__name__ == "PersonSerializer"
313+
assert len(serializer.fields) == 8
314+
315+
# Regular fields
316+
assert isinstance(serializer.fields["name"], serializers.CharField)
317+
assert isinstance(serializer.fields["email"], serializers.EmailField)
318+
assert isinstance(serializer.fields["age"], serializers.IntegerField)
319+
assert isinstance(serializer.fields["height"], serializers.FloatField)
320+
assert isinstance(serializer.fields["date_of_birth"], serializers.DateField)
321+
assert isinstance(serializer.fields["notification_preferences"], EnumField)
322+
for name in [
323+
"name",
324+
"email",
325+
"age",
326+
"height",
327+
"date_of_birth",
328+
"notification_preferences"
329+
]:
330+
field = serializer.fields[name]
331+
assert field.required is True, name
332+
assert field.default is serializers.empty, name
333+
assert field.allow_null is False, name
334+
if name == 'notification_preferences':
335+
assert field.choices == dict(
336+
[(x, x.name) for x in NotificationPreferenceEnum]
337+
)
338+
339+
# Optional
340+
field: serializers.Field = serializer.fields["original_nationality"]
341+
assert isinstance(field, EnumField)
342+
assert field.allow_null is True
343+
assert field.default is None
344+
assert field.required is False
345+
assert field.choices == dict([(x, x.name) for x in CountryEnum])
346+
347+
# With default
348+
field: serializers.Field = serializer.fields["nationality"]
349+
assert isinstance(field, EnumField)
350+
assert field.allow_null is False
351+
assert field.default == CountryEnum.GB
352+
assert field.required is False
353+
assert field.choices == dict([(x, x.name) for x in CountryEnum])
354+
355+
356+
def test_enum_value():
357+
358+
class SexEnum(Enum):
359+
MALE = 'male'
360+
FEMALE = 'female'
361+
OTHER = 'other'
362+
363+
class Human(BaseModel):
364+
sex: SexEnum
365+
age: int
366+
367+
serializer = Human.drf_serializer
368+
369+
normal_serializer = serializer(data={'sex': SexEnum.MALE, 'age': 25})
370+
371+
assert normal_serializer.is_valid()
372+
assert normal_serializer.validated_data['sex'] == SexEnum.MALE
373+
assert normal_serializer.validated_data['age'] == 25
374+
375+
value_serializer = serializer(data={'sex': 'male', 'age': 25})
376+
377+
assert value_serializer.is_valid()
378+
assert value_serializer.validated_data['sex'] == SexEnum.MALE
379+
assert value_serializer.validated_data['age'] == 25
380+
381+
bad_value_serializer = serializer(data={'sex': 'bad_value', 'age': 25})
382+
383+
assert bad_value_serializer.is_valid() is False

0 commit comments

Comments
 (0)