Skip to content

Commit d8e6476

Browse files
committed
Add Support for marshmallow.fields.Enum in marshmallow ≥ v3.18
This fixes #169. Detailed changes: * Introduce distinction between Enums imports from `marshmallow_enum` and `marshmallow.fields` (the latter are refered to as "marshmallow native" Enums) * Add function to find out if the version of marshmallow used supports the native Enum type * Add test cases that reproduce the issue * Adapt the code to also support the native enums
1 parent 45374be commit d8e6476

File tree

3 files changed

+96
-16
lines changed

3 files changed

+96
-16
lines changed

marshmallow_jsonschema/base.py

+42-10
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,21 @@
99
from marshmallow.class_registry import get_class
1010
from marshmallow.decorators import post_dump
1111
from marshmallow.utils import _Missing
12-
1312
from marshmallow import INCLUDE, EXCLUDE, RAISE
13+
# marshmallow.fields.Enum support has been added in marshmallow v3.18
14+
# see https://github.com/marshmallow-code/marshmallow/blob/dev/CHANGELOG.rst#3180-2022-09-15
15+
from marshmallow import __version__ as _MarshmallowVersion
16+
# the package "packaging" is a requirement of marshmallow itself => we don't need to install it separately
17+
# see https://github.com/marshmallow-code/marshmallow/blob/ddbe06f923befe754e213e03fb95be54e996403d/setup.py#L61
18+
from packaging.version import Version
19+
20+
21+
def marshmallow_version_supports_native_enums() -> bool:
22+
"""
23+
returns true if and only if the version of marshmallow installed supports enums natively
24+
"""
25+
return Version(_MarshmallowVersion) >= Version("3.18")
26+
1427

1528
try:
1629
from marshmallow_union import Union
@@ -20,11 +33,15 @@
2033
ALLOW_UNIONS = False
2134

2235
try:
23-
from marshmallow_enum import EnumField, LoadDumpOptions
36+
from marshmallow_enum import EnumField as MarshmallowEnumEnumField, LoadDumpOptions
2437

25-
ALLOW_ENUMS = True
38+
ALLOW_MARSHMALLOW_ENUM_ENUMS = True
2639
except ImportError:
27-
ALLOW_ENUMS = False
40+
ALLOW_MARSHMALLOW_ENUM_ENUMS = False
41+
42+
ALLOW_MARSHMALLOW_NATIVE_ENUMS = marshmallow_version_supports_native_enums()
43+
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
44+
from marshmallow.fields import Enum as MarshmallowNativeEnumField
2845

2946
from .exceptions import UnsupportedValueError
3047
from .validation import (
@@ -92,10 +109,12 @@
92109
(fields.Nested, dict),
93110
]
94111

95-
if ALLOW_ENUMS:
112+
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
113+
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowNativeEnumField, Enum))
114+
if ALLOW_MARSHMALLOW_ENUM_ENUMS:
96115
# We currently only support loading enum's from their names. So the possible
97116
# values will always map to string in the JSONSchema
98-
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((EnumField, Enum))
117+
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowEnumEnumField, Enum))
99118

100119

101120
FIELD_VALIDATORS = {
@@ -191,8 +210,10 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
191210
if field.default is not missing and not callable(field.default):
192211
json_schema["default"] = field.default
193212

194-
if ALLOW_ENUMS and isinstance(field, EnumField):
195-
json_schema["enum"] = self._get_enum_values(field)
213+
if ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField):
214+
json_schema["enum"] = self._get_marshmallow_native_enum_values(field)
215+
elif ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField):
216+
json_schema["enum"] = self._get_marshmallow_enum_enum_values(field)
196217

197218
if field.allow_none:
198219
previous_type = json_schema["type"]
@@ -218,8 +239,8 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
218239
)
219240
return json_schema
220241

221-
def _get_enum_values(self, field) -> typing.List[str]:
222-
assert ALLOW_ENUMS and isinstance(field, EnumField)
242+
def _get_marshmallow_enum_enum_values(self, field) -> typing.List[str]:
243+
assert ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField)
223244

224245
if field.load_by == LoadDumpOptions.value:
225246
# Python allows enum values to be almost anything, so it's easier to just load from the
@@ -229,6 +250,17 @@ def _get_enum_values(self, field) -> typing.List[str]:
229250
)
230251

231252
return [value.name for value in field.enum]
253+
def _get_marshmallow_native_enum_values(self, field) -> typing.List[str]:
254+
assert ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField)
255+
256+
if field.by_value:
257+
# Python allows enum values to be almost anything, so it's easier to just load from the
258+
# names of the enum's which will have to be strings.
259+
raise NotImplementedError(
260+
"Currently do not support JSON schema for enums loaded by value"
261+
)
262+
263+
return [value.name for value in field.enum]
232264

233265
def _from_union_schema(
234266
self, obj, field

tests/test_dump.py

+53-5
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33

44
import pytest
55
from marshmallow import Schema, fields, validate
6-
from marshmallow_enum import EnumField
6+
from marshmallow_enum import EnumField as MarshmallowEnumEnumField
77
from marshmallow_union import Union
88

9+
import marshmallow_jsonschema
910
from marshmallow_jsonschema import JSONSchema, UnsupportedValueError
1011
from . import UserSchema, validate_and_dump
1112

13+
TEST_MARSHMALLOW_NATIVE_ENUM = marshmallow_jsonschema.base.marshmallow_version_supports_native_enums()
14+
try:
15+
from marshmallow.fields import Enum as MarshmallowNativeEnumField
16+
except ImportError:
17+
assert TEST_MARSHMALLOW_NATIVE_ENUM is False
1218

1319
def test_dump_schema():
1420
schema = UserSchema()
@@ -648,14 +654,14 @@ class Meta:
648654
assert properties_names == ["d", "c", "a"]
649655

650656

651-
def test_enum_based():
657+
def test_marshmallow_enum_enum_based():
652658
class TestEnum(Enum):
653659
value_1 = 0
654660
value_2 = 1
655661
value_3 = 2
656662

657663
class TestSchema(Schema):
658-
enum_prop = EnumField(TestEnum)
664+
enum_prop = MarshmallowEnumEnumField(TestEnum)
659665

660666
# Should be sorting of fields
661667
schema = TestSchema()
@@ -671,15 +677,39 @@ class TestSchema(Schema):
671677
)
672678
assert received_enum_values == ["value_1", "value_2", "value_3"]
673679

680+
def test_native_marshmallow_enum_based():
681+
if not TEST_MARSHMALLOW_NATIVE_ENUM:
682+
return
683+
class TestEnum(Enum):
684+
value_1 = 0
685+
value_2 = 1
686+
value_3 = 2
687+
688+
class TestSchema(Schema):
689+
enum_prop = MarshmallowNativeEnumField(TestEnum)
690+
691+
# Should be sorting of fields
692+
schema = TestSchema()
693+
694+
json_schema = JSONSchema()
695+
data = json_schema.dump(schema)
696+
697+
assert (
698+
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["type"] == "string"
699+
)
700+
received_enum_values = sorted(
701+
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["enum"]
702+
)
703+
assert received_enum_values == ["value_1", "value_2", "value_3"]
674704

675-
def test_enum_based_load_dump_value():
705+
def test_marshmallow_enum_enum_based_load_dump_value():
676706
class TestEnum(Enum):
677707
value_1 = 0
678708
value_2 = 1
679709
value_3 = 2
680710

681711
class TestSchema(Schema):
682-
enum_prop = EnumField(TestEnum, by_value=True)
712+
enum_prop = MarshmallowEnumEnumField(TestEnum, by_value=True)
683713

684714
# Should be sorting of fields
685715
schema = TestSchema()
@@ -689,6 +719,24 @@ class TestSchema(Schema):
689719
with pytest.raises(NotImplementedError):
690720
validate_and_dump(json_schema.dump(schema))
691721

722+
def test_native_marshmallow_enum_based_load_dump_value():
723+
if not TEST_MARSHMALLOW_NATIVE_ENUM:
724+
return
725+
class TestEnum(Enum):
726+
value_1 = 0
727+
value_2 = 1
728+
value_3 = 2
729+
730+
class TestSchema(Schema):
731+
enum_prop = MarshmallowNativeEnumField(TestEnum, by_value=True)
732+
733+
# Should be sorting of fields
734+
schema = TestSchema()
735+
736+
json_schema = JSONSchema()
737+
738+
with pytest.raises(NotImplementedError):
739+
validate_and_dump(json_schema.dump(schema))
692740

693741
def test_union_based():
694742
class TestNestedSchema(Schema):

tests/test_imports.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_import_marshmallow_enum(monkeypatch):
1919

2020
base = importlib.reload(marshmallow_jsonschema.base)
2121

22-
assert not base.ALLOW_ENUMS
22+
assert not base.ALLOW_MARSHMALLOW_ENUM_ENUMS
2323

2424
monkeypatch.undo()
2525

0 commit comments

Comments
 (0)