Skip to content

Polymorphic (de)serializastion support (#7104) #7127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 187 additions & 42 deletions src/azul/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Any,
Callable,
Iterator,
Literal,
Optional,
Self,
Tuple,
Expand All @@ -37,10 +36,12 @@

from azul import (
R,
cached_property,
config,
require,
)
from azul.json import (
PolymorphicSerializable,
Serializable,
)
from azul.types import (
Expand Down Expand Up @@ -177,6 +178,9 @@ def validator(_instance, field, value):

type Source = list[str | tuple[str, ...] | Source]

type FromJSON = Callable[[AnyJSON], Any]
type ToJSON = Callable[[Any], AnyJSON]


class SerializableAttrs(Serializable, attrs.AttrsInstance):
"""
Expand Down Expand Up @@ -293,10 +297,6 @@ def _assert_concrete(cls):
assert not cls._deferred_fields, R(
'Class has fields of unknown type', cls._deferred_fields)

class Metadata(TypedDict):
from_json: Callable[[AnyJSON], Any] | None
to_json: Callable[[Any], AnyJSON] | None

def __init_subclass__(cls):
super().__init_subclass__()
try:
Expand Down Expand Up @@ -365,31 +365,22 @@ def _make(cls, fields: list[attrs.Attribute]) -> frozenset[str]:
cls._define(to_json)
return deferred_fields

@classmethod
def _serializable(cls,
field: attrs.Attribute,
key: Literal['from_json', 'to_json']
) -> bool:
try:
return field.metadata['azul'][key] is not None
except KeyError:
return True

@classmethod
def _make_from_json(cls, fields: list[attrs.Attribute]) -> Callable:
globals = {cls.__name__: cls}
deserializers = (cls.Deserializer(cls, field, globals) for field in fields)
source = cls._indent([
'@classmethod',
'def _from_json(cls, json):', [
f'kwargs = super({cls.__name__}, cls)._from_json(json)',
*flatten(
[
f'x = json["{field.name}"]',
*(cls.Deserializer(cls, field, globals).handle('x')),
f'kwargs["{field.name}"] = x'
f'x = json["{deserializer.field.name}"]',
*(deserializer.handle('x')),
f'kwargs["{deserializer.field.name}"] = x'
]
for field in fields
if cls._serializable(field, 'from_json')
for deserializer in deserializers
if deserializer.enabled
),
'return kwargs'
]
Expand All @@ -399,6 +390,7 @@ def _make_from_json(cls, fields: list[attrs.Attribute]) -> Callable:
@classmethod
def _make_to_json(cls, fields: list[attrs.Attribute]) -> Callable:
globals = {cls.__name__: cls}
serializers = (cls.Serializer(cls, field, globals) for field in fields)
to_json = cls._indent([
'def to_json(self):', [
# Using the super() shortcut would require messing with the
Expand All @@ -407,11 +399,11 @@ def _make_to_json(cls, fields: list[attrs.Attribute]) -> Callable:
f'json = super({cls.__name__}, self).to_json()',
*flatten(
[
f'x = self.{field.name}',
f'json["{field.name}"] = ' + cls.Serializer(cls, field, globals).handle('x')
f'x = self.{serializer.field.name}',
f'json["{serializer.field.name}"] = ' + serializer.handle('x')
]
for field in fields
if cls._serializable(field, 'to_json')
for serializer in serializers
if serializer.enabled
),
'return json'
]
Expand Down Expand Up @@ -478,13 +470,29 @@ class Strategy[T](metaclass=ABCMeta):
class MustDefer(Exception):
pass

def handle(self, x: str) -> T:
class Custom(TypedDict):
from_json: FromJSON | None
to_json: ToJSON | None

@cached_property
def custom(self) -> Custom | None:
return self._metadata('custom', None)

def _metadata[V](self, key: str, default: V) -> V:
try:
metadata = self.field.metadata['azul']
return self.field.metadata['azul'][key]
except KeyError:
return default

@cached_property
def discriminator(self) -> str | None:
return self._metadata('discriminator', None)

def handle(self, x: str) -> T:
if self.custom is None:
return self._handle(x, self._reify(self.field.type))
else:
return self._custom(x, metadata)
return self._custom(x)

def _owner(self) -> type:
"""
Expand Down Expand Up @@ -533,7 +541,12 @@ def _handle(self, x: str, field_type: Any):
elif issubclass(field_type, Serializable):
inner_cls_name = field_type.__name__
self.globals[inner_cls_name] = field_type
return self._serializable(x, inner_cls_name)
is_polymorphic = issubclass(field_type, PolymorphicSerializable)
has_discriminator = self.discriminator is not None
if is_polymorphic and has_discriminator:
return self._polymorphic(x, inner_cls_name)
else:
return self._serializable(x, inner_cls_name)
else:
origin = get_origin(field_type)
if origin in (Union, UnionType):
Expand All @@ -551,6 +564,11 @@ def _handle(self, x: str, field_type: Any):
return self._dict(x, key_type, value_type)
raise TypeError('Unserializable field', field_type, self.field)

@property
@abstractmethod
def enabled(self) -> bool:
raise NotImplementedError

@abstractmethod
def _primitive(self, x: str, field_type: type) -> T:
raise NotImplementedError
Expand All @@ -567,6 +585,10 @@ def _optional(self, x: str, field_type: type) -> T:
def _serializable(self, x: str, inner_cls_name: str) -> T:
raise NotImplementedError

@abstractmethod
def _polymorphic(self, x: str, inner_cls_name: str) -> T:
raise NotImplementedError

@abstractmethod
def _list(self, x: str, item_type: type) -> T:
raise NotImplementedError
Expand All @@ -576,11 +598,15 @@ def _dict(self, x: str, key_type: type, value_type: type) -> T:
raise NotImplementedError

@abstractmethod
def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> T:
def _custom(self, x: str) -> T:
raise NotImplementedError

class Deserializer(Strategy[Source]):

@property
def enabled(self) -> bool:
return self.custom is None or self.custom['from_json'] is not None

def _optional(self, x: str, field_type: type) -> Source:
return [
f'if {x} is not None:', self._handle(x, field_type)
Expand All @@ -591,6 +617,15 @@ def _serializable(self, x: str, inner_cls_name: str) -> Source:
f'{x} = {inner_cls_name}.from_json({x})'
]

def _polymorphic(self, x: str, inner_cls_name: str) -> Source:
depth = next(self.depth)
cls = f'cls{depth}'
return [
f'{cls} = {x}["{self.discriminator}"]',
f'{cls} = {inner_cls_name}.cls_from_json({cls})',
f'{x} = {cls}.from_json({x})'
]

def _primitive(self, x: str, field_type: type) -> Source:
return [
f'if not isinstance({x}, {field_type.__name__}):', [
Expand Down Expand Up @@ -632,15 +667,20 @@ def _dict(self, x: str, key_type: type, value_type: type) -> Source:
f'{x} = {d}'
]

def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> Source:
def _custom(self, x: str) -> Source:
var_name = self.field.name + '_from_json'
self.globals[var_name] = not_none(metadata['from_json'])
from_json = not_none(not_none(self.custom)['from_json'])
self.globals[var_name] = from_json
return [
f'{x} = {var_name}({x})'
]

class Serializer(Strategy[str]):

@property
def enabled(self) -> bool:
return self.custom is None or self.custom['to_json'] is not None

def _primitive(self, x: str, field_type: type) -> str:
return x

Expand All @@ -653,6 +693,9 @@ def _optional(self, x: str, field_type: type) -> str:
def _serializable(self, x: str, inner_cls_name: str) -> str:
return f'{x}.to_json()'

def _polymorphic(self, x: str, inner_cls_name: str) -> str:
return f'dict({x}.to_json(), {self.discriminator}={x}.cls_to_json())'

def _list(self, x: str, item_type: type) -> str:
depth = next(self.depth)
v = f'v{depth}'
Expand All @@ -665,32 +708,34 @@ def _dict(self, x: str, key_type: type, value_type: type) -> str:
k_, v_ = self._handle(k, key_type), self._handle(v, value_type)
return f'{{{k_}: {v_} for {k}, {v} in x.items()}}'

def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> str:
def _custom(self, x: str) -> str:
to_json = not_none(not_none(self.custom)['to_json'])
var_name = self.field.name + '_to_json'
self.globals[var_name] = not_none(metadata['to_json'])
self.globals[var_name] = to_json
return f'{var_name}({x})'


def serializable[T: attrs.Attribute](field: T,
from_json: Callable[[AnyJSON], Any],
to_json: Callable[[Any], AnyJSON]) -> T:
def serializable[T: attrs.Attribute](field: T | None = None,
*,
from_json: FromJSON,
to_json: ToJSON) -> T:
"""
Use the provided callables to (de)serialize values of the given field,
instead of generating them.

>>> @attrs.frozen
... class Foo(SerializableAttrs):
... x: set[str] = serializable(attrs.field(), to_json=sorted, from_json=set)
... x: set[str] = serializable(to_json=sorted, from_json=set)

>>> Foo(x={'b','a'}).to_json()
{'x': ['a', 'b']}

>>> Foo.from_json({'x': ['a']})
Foo(x={'a'})
"""
field.metadata['azul'] = SerializableAttrs.Metadata(from_json=from_json,
to_json=to_json)
return field
custom = SerializableAttrs.Strategy.Custom(from_json=from_json,
to_json=to_json)
return _set_field_metadata(field, 'custom', custom)


def not_serializable[T: attrs.Attribute](field: T) -> T:
Expand All @@ -710,6 +755,106 @@ def not_serializable[T: attrs.Attribute](field: T) -> T:
>>> Foo.from_json({})
Foo(x=42)
"""
field.metadata['azul'] = SerializableAttrs.Metadata(from_json=None,
to_json=None)
custom = SerializableAttrs.Strategy.Custom(from_json=None,
to_json=None)
return _set_field_metadata(field, 'custom', custom)


def _set_field_metadata[T: attrs.Attribute](field: T | None, key, value):
if field is None:
field = attrs.field()
metadata = field.metadata.setdefault('azul', {})
metadata[key] = value
return field


def polymorphic[T: attrs.Attribute](field: T | None = None,
*,
discriminator: str
) -> T:
"""
Mark an attrs field to use the given name for the discriminator property in
serialized instances of PolymorphicSerializable that occur in the value of
that field. The given discriminator property of a serialized instance
represents the type to use when deserializing that instance again.

>>> from azul.json import RegisteredPolymorphicSerializable

>>> class Inner(SerializableAttrs, RegisteredPolymorphicSerializable):
... pass

>>> @attrs.frozen
... class InnerWithInt(Inner):
... x: int

>>> @attrs.frozen
... class InnerWithStr(Inner):
... y: str

>>> @attrs.frozen(kw_only=True)
... class Outer(SerializableAttrs):
... inner: Inner = polymorphic(discriminator='type')
... inners: list[Inner] = polymorphic(discriminator='_cls')

>>> from azul.doctests import assert_json

>>> outer = Outer(inner=InnerWithInt(42),
... inners=[InnerWithStr('foo'), InnerWithInt(7)])
>>> assert_json(outer.to_json())
{
"inner": {
"x": 42,
"type": "InnerWithInt"
},
"inners": [
{
"y": "foo",
"_cls": "InnerWithStr"
},
{
"x": 7,
"_cls": "InnerWithInt"
}
]
}
>>> Outer.from_json(outer.to_json()) == outer
True

In order to enable polymorphic serialization of the value of a given field,
the discriminator property needs to be specified explicitly, otherwise the
serialization framework will resort to the static type of the field.

>>> @attrs.frozen
... class GenericOuter[T: Inner](SerializableAttrs):
... inner: T

>>> class StaticOuter(GenericOuter[InnerWithInt]):
... pass

>>> outer = StaticOuter(InnerWithInt(42))
>>> outer.to_json()
{'inner': {'x': 42}}

Despite the fact that ``{'x': 42}`` does not encode any type information,
``from_json`` can tell from the static type of the field that {'x': 42}
should be deserialized as an ``InnerWithInt``.

>>> StaticOuter.from_json(outer.to_json()).inner
InnerWithInt(x=42)

>>> StaticOuter.from_json(outer.to_json()) == outer
True

However, when the static type of the field is not concrete, deserialization
may fail or, like in this case, lose information by creating an instance of
the parent class instead of the class that was serialized.

>>> @attrs.frozen
... class AbstractOuter(SerializableAttrs):
... inner: Inner

>>> outer = AbstractOuter(InnerWithInt(42))
>>> AbstractOuter.from_json(outer.to_json()).inner # doctest: +ELLIPSIS
<azul.attrs.Inner object at ...>
"""
return _set_field_metadata(field, 'discriminator', discriminator)
Loading