|
1 | 1 | import datetime
|
2 | 2 | import typing
|
| 3 | +from enum import Enum |
| 4 | +from unittest.mock import patch |
3 | 5 |
|
4 | 6 | import pydantic
|
5 | 7 | import pytest
|
6 | 8 |
|
7 | 9 | from rest_framework import serializers
|
| 10 | +from rest_framework.exceptions import ValidationError |
8 | 11 |
|
9 | 12 | from drf_pydantic import BaseModel
|
| 13 | +from drf_pydantic.fields import EnumField |
10 | 14 |
|
11 | 15 |
|
12 | 16 | def test_simple_model():
|
@@ -280,3 +284,100 @@ class Cart(BaseModel):
|
280 | 284 |
|
281 | 285 | name_field: serializers.Field = items_field.child.fields["name"]
|
282 | 286 | assert isinstance(name_field, serializers.CharField)
|
| 287 | + |
| 288 | + |
| 289 | +def test_enum_model(): |
| 290 | + class CountryEnum(Enum): |
| 291 | + US = 'US' |
| 292 | + GB = 'GB' |
| 293 | + FR = 'FR' |
| 294 | + |
| 295 | + class NotificationPreferenceEnum(Enum): |
| 296 | + NONE = 'no_notifications' |
| 297 | + SOME = 'some_notifications' |
| 298 | + ALL = 'all_notifications' |
| 299 | + |
| 300 | + class Person(BaseModel): |
| 301 | + name: str |
| 302 | + email: pydantic.EmailStr |
| 303 | + age: int |
| 304 | + height: float |
| 305 | + date_of_birth: datetime.date |
| 306 | + notification_preferences: NotificationPreferenceEnum |
| 307 | + original_nationality: typing.Optional[CountryEnum] |
| 308 | + nationality: CountryEnum = CountryEnum.GB |
| 309 | + |
| 310 | + serializer = Person.drf_serializer() |
| 311 | + |
| 312 | + assert serializer.__class__.__name__ == "PersonSerializer" |
| 313 | + assert len(serializer.fields) == 8 |
| 314 | + |
| 315 | + # Regular fields |
| 316 | + assert isinstance(serializer.fields["name"], serializers.CharField) |
| 317 | + assert isinstance(serializer.fields["email"], serializers.EmailField) |
| 318 | + assert isinstance(serializer.fields["age"], serializers.IntegerField) |
| 319 | + assert isinstance(serializer.fields["height"], serializers.FloatField) |
| 320 | + assert isinstance(serializer.fields["date_of_birth"], serializers.DateField) |
| 321 | + assert isinstance(serializer.fields["notification_preferences"], EnumField) |
| 322 | + for name in [ |
| 323 | + "name", |
| 324 | + "email", |
| 325 | + "age", |
| 326 | + "height", |
| 327 | + "date_of_birth", |
| 328 | + "notification_preferences" |
| 329 | + ]: |
| 330 | + field = serializer.fields[name] |
| 331 | + assert field.required is True, name |
| 332 | + assert field.default is serializers.empty, name |
| 333 | + assert field.allow_null is False, name |
| 334 | + if name == 'notification_preferences': |
| 335 | + assert field.choices == dict( |
| 336 | + [(x, x.name) for x in NotificationPreferenceEnum] |
| 337 | + ) |
| 338 | + |
| 339 | + # Optional |
| 340 | + field: serializers.Field = serializer.fields["original_nationality"] |
| 341 | + assert isinstance(field, EnumField) |
| 342 | + assert field.allow_null is True |
| 343 | + assert field.default is None |
| 344 | + assert field.required is False |
| 345 | + assert field.choices == dict([(x, x.name) for x in CountryEnum]) |
| 346 | + |
| 347 | + # With default |
| 348 | + field: serializers.Field = serializer.fields["nationality"] |
| 349 | + assert isinstance(field, EnumField) |
| 350 | + assert field.allow_null is False |
| 351 | + assert field.default == CountryEnum.GB |
| 352 | + assert field.required is False |
| 353 | + assert field.choices == dict([(x, x.name) for x in CountryEnum]) |
| 354 | + |
| 355 | + |
| 356 | +def test_enum_value(): |
| 357 | + |
| 358 | + class SexEnum(Enum): |
| 359 | + MALE = 'male' |
| 360 | + FEMALE = 'female' |
| 361 | + OTHER = 'other' |
| 362 | + |
| 363 | + class Human(BaseModel): |
| 364 | + sex: SexEnum |
| 365 | + age: int |
| 366 | + |
| 367 | + serializer = Human.drf_serializer |
| 368 | + |
| 369 | + normal_serializer = serializer(data={'sex': SexEnum.MALE, 'age': 25}) |
| 370 | + |
| 371 | + assert normal_serializer.is_valid() |
| 372 | + assert normal_serializer.validated_data['sex'] == SexEnum.MALE |
| 373 | + assert normal_serializer.validated_data['age'] == 25 |
| 374 | + |
| 375 | + value_serializer = serializer(data={'sex': 'male', 'age': 25}) |
| 376 | + |
| 377 | + assert value_serializer.is_valid() |
| 378 | + assert value_serializer.validated_data['sex'] == SexEnum.MALE |
| 379 | + assert value_serializer.validated_data['age'] == 25 |
| 380 | + |
| 381 | + bad_value_serializer = serializer(data={'sex': 'bad_value', 'age': 25}) |
| 382 | + |
| 383 | + assert bad_value_serializer.is_valid() is False |
0 commit comments