Skip to content
This repository was archived by the owner on Aug 19, 2023. It is now read-only.

Commit ac32af8

Browse files
authored
Support using discriminator to determine the Union variant (#171)
Support overlapping Union types with `anyOf`
1 parent b2c9d7f commit ac32af8

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

dataclasses_jsonschema/__init__.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,13 @@ def _get_fields_uncached():
394394
cls.__mapped_fields = _get_fields_uncached()
395395
return cls.__mapped_fields # type: ignore
396396

397-
def to_dict(self, omit_none: bool = True, validate: bool = False, validate_enums: bool = True) -> JsonDict:
397+
def to_dict(
398+
self,
399+
omit_none: bool = True,
400+
validate: bool = False,
401+
validate_enums: bool = True,
402+
schema_type: SchemaType = DEFAULT_SCHEMA_TYPE,
403+
) -> JsonDict:
398404
"""Converts the dataclass instance to a JSON encodable dict, with optional JSON schema validation.
399405
400406
If omit_none (default True) is specified, any items with value None are removed
@@ -417,7 +423,7 @@ def to_dict(self, omit_none: bool = True, validate: bool = False, validate_enums
417423
data[self.__discriminator_name] = self.__class__.__name__
418424

419425
if validate:
420-
self._validate(data, validate_enums)
426+
self._validate(data, validate_enums, schema_type)
421427
return data
422428

423429
@classmethod
@@ -487,11 +493,14 @@ def decoder(_, __, val): return val
487493
return decoder(field, field_type, value)
488494

489495
@classmethod
490-
def _validate(cls, data: JsonDict, validate_enums: bool = True):
496+
def _validate(cls, data: JsonDict, validate_enums: bool = True, schema_type: SchemaType = DEFAULT_SCHEMA_TYPE):
497+
if schema_type == SchemaType.OPENAPI_3 or schema_type == SchemaType.SWAGGER_V2:
498+
warnings.warn("Only draft-04, draft-06 and draft-07 schema types are supported for validation")
499+
schema_type = DEFAULT_SCHEMA_TYPE
500+
491501
try:
492502
if fast_validation:
493-
# TODO: Support validating with other schema types
494-
schema_validator = cls.__compiled_schema.get(SchemaOptions(DEFAULT_SCHEMA_TYPE, validate_enums))
503+
schema_validator = cls.__compiled_schema.get(SchemaOptions(schema_type, validate_enums))
495504
if schema_validator is None:
496505
formats = {}
497506
for encoder in cls._field_encoders.values():
@@ -500,17 +509,23 @@ def _validate(cls, data: JsonDict, validate_enums: bool = True):
500509
formats[schema['format']] = schema['pattern']
501510

502511
schema_validator = fastjsonschema.compile(
503-
cls.json_schema(validate_enums=validate_enums), formats=formats
512+
cls.json_schema(schema_type=schema_type, validate_enums=validate_enums), formats=formats
504513
)
505-
cls.__compiled_schema[SchemaOptions(DEFAULT_SCHEMA_TYPE, validate_enums)] = schema_validator
514+
cls.__compiled_schema[SchemaOptions(schema_type, validate_enums)] = schema_validator
506515
schema_validator(data)
507516
else:
508-
validate_func(data, cls.json_schema(validate_enums=validate_enums))
517+
validate_func(data, cls.json_schema(schema_type=schema_type, validate_enums=validate_enums))
509518
except JsonSchemaValidationError as e:
510519
raise ValidationError(str(e)) from e
511520

512521
@classmethod
513-
def from_dict(cls: Type[T], data: JsonDict, validate=True, validate_enums: bool = True) -> T:
522+
def from_dict(
523+
cls: Type[T],
524+
data: JsonDict,
525+
validate: bool = True,
526+
validate_enums: bool = True,
527+
schema_type: SchemaType = DEFAULT_SCHEMA_TYPE,
528+
) -> T:
514529
"""Returns a dataclass instance with all nested classes converted from the dict given"""
515530
if cls is JsonSchemaMixin:
516531
raise NotImplementedError
@@ -520,11 +535,14 @@ def from_dict(cls: Type[T], data: JsonDict, validate=True, validate_enums: bool
520535
for subclass in cls.__subclasses__():
521536
if subclass.__name__ == data[cls.__discriminator_name]:
522537
return subclass.from_dict(data, validate)
538+
raise TypeError(
539+
f"Class '{cls.__name__}' does not match discriminator '{data[cls.__discriminator_name]}'"
540+
)
523541

524542
init_values: Dict[str, Any] = {}
525543
non_init_values: Dict[str, Any] = {}
526544
if validate:
527-
cls._validate(data, validate_enums)
545+
cls._validate(data, validate_enums, schema_type)
528546

529547
for f in cls._get_fields():
530548
values = init_values if f.field.init else non_init_values
@@ -683,8 +701,8 @@ def _get_field_schema(cls, field: Union[Field, Type], schema_options: SchemaOpti
683701
elif field_type_name == 'Union':
684702
if schema_options.schema_type == SchemaType.SWAGGER_V2:
685703
raise TypeError('Type unions unsupported in Swagger 2.0')
686-
field_schema = {'oneOf': [cls._get_field_schema(variant, schema_options)[0] for variant in field_args]}
687-
field_schema['oneOf'].sort(key=lambda item: item.get('type', ''))
704+
field_schema = {'anyOf': [cls._get_field_schema(variant, schema_options)[0] for variant in field_args]}
705+
field_schema['anyOf'].sort(key=lambda item: item.get('type', ''))
688706
elif field_type_name in MAPPING_TYPES:
689707
field_schema = {'type': 'object'}
690708
if field_args[1] is not Any:

tests/test_core.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
'description': "Type with union field",
133133
'properties': {
134134
'a': {
135-
'oneOf': [
135+
'anyOf': [
136136
{'$ref': '#/definitions/Point'},
137137
{'type': 'string', 'enum': ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']},
138138
]
@@ -621,7 +621,7 @@ class Baz(JsonSchemaMixin):
621621
"description": "Class with optional union",
622622
"type": "object",
623623
"properties": {
624-
"a": {"oneOf": [{"type": "integer"}, {"type": "string"}]}
624+
"a": {"anyOf": [{"type": "integer"}, {"type": "string"}]}
625625
}
626626
})
627627
assert Baz.json_schema() == expected_schema
@@ -1028,3 +1028,27 @@ class Config(JsonSchemaMixin):
10281028
c = Config({1: 'foo', 2: 'bar'})
10291029
data = c.to_json()
10301030
assert c == Config.from_json(data)
1031+
1032+
1033+
def test_union_with_discriminator():
1034+
@dataclass
1035+
class Pet(JsonSchemaMixin, discriminator=True):
1036+
pass
1037+
1038+
@dataclass
1039+
class Cat(Pet):
1040+
breed: str
1041+
1042+
@dataclass
1043+
class Dog(Pet):
1044+
breed: str
1045+
walk_distance: Optional[float] = None
1046+
1047+
@dataclass
1048+
class Person(JsonSchemaMixin):
1049+
name: str
1050+
pet: Union[Cat, Dog]
1051+
1052+
data = Person(name="Joe", pet=Dog(breed="Pug")).to_dict()
1053+
p = Person.from_dict(data)
1054+
assert p.pet == Dog(breed="Pug")

tests/test_peps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ class Post(JsonSchemaMixin):
1616

1717
schema = Post.json_schema()
1818
assert schema['properties']['tags'] == {
19-
'oneOf': [{'type': 'array', 'items': {'type': 'string'}}, {'type': 'string'}]
19+
'anyOf': [{'type': 'array', 'items': {'type': 'string'}}, {'type': 'string'}]
2020
}
2121
assert schema['properties']['metadata'] == {
22-
'type': 'array', 'items': {'oneOf': [{'type': 'integer'}, {'type': 'string'}]}
22+
'type': 'array', 'items': {'anyOf': [{'type': 'integer'}, {'type': 'string'}]}
2323
}
2424
assert schema['required'] == ['body', 'metadata']
2525

0 commit comments

Comments
 (0)