Skip to content

Commit 123a7ab

Browse files
Support customizing tagged union tag values
1 parent 45efa03 commit 123a7ab

6 files changed

Lines changed: 89 additions & 15 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optio
4545
plum-dispatch = ">=2.3"
4646
beartype = ">=0.18.4"
4747
sqlalchemy = { version = ">2", markers = "extra == 'sqlalchemy' or extra == 'all'", optional = true }
48+
frozendict = ">2"
4849

4950
[tool.poetry.dev-dependencies]
5051
pyyaml = "*"

serde/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
coerce,
3636
DefaultTagging,
3737
ExternalTagging,
38+
ExternalTagging_,
3839
InternalTagging,
3940
disabled,
4041
strict,
@@ -81,6 +82,7 @@
8182
"SerdeSkip",
8283
"AdjacentTagging",
8384
"ExternalTagging",
85+
"ExternalTagging_",
8486
"InternalTagging",
8587
"Untagged",
8688
"disabled",

serde/core.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import re
1212
import casefy
1313
from dataclasses import dataclass
14+
from frozendict import frozendict
1415

1516
from beartype.door import is_bearable
1617
from collections.abc import Mapping, Sequence, Callable
@@ -776,6 +777,13 @@ class Kind(enum.Enum):
776777
tag: Optional[str] = None
777778
content: Optional[str] = None
778779
kind: Kind = Kind.External
780+
tags: Optional[frozendict[type[Any], str]] = None
781+
782+
def __init__(self, tag: Optional[str] = None, content: Optional[str] = None, kind: Kind = Kind.External, tags: Optional[dict[type[Any], str]] = None) -> None:
783+
self.tag = tag
784+
self.content = content
785+
self.kind = kind
786+
self.tags = frozendict(tags) if tags else None
779787

780788
def is_external(self) -> bool:
781789
return self.kind == self.Kind.External
@@ -793,6 +801,11 @@ def is_untagged(self) -> bool:
793801
def is_taggable(cls, typ: type[Any]) -> bool:
794802
return dataclasses.is_dataclass(typ)
795803

804+
def tag_for(self, typ: type[Any]) -> str:
805+
if self.tags and typ in self.tags:
806+
return self.tags[typ]
807+
return typename(typ)
808+
796809
def check(self) -> None:
797810
if self.is_internal() and self.tag is None:
798811
raise SerdeError('"tag" must be specified in InternalTagging')
@@ -825,39 +838,58 @@ def __call__(self, cls: T) -> _WithTagging[T]:
825838

826839

827840
@overload
828-
def InternalTagging(tag: str) -> Tagging: ...
841+
def InternalTagging(tag: str, *, tags: Optional[dict[type[Any], str]] = None) -> Tagging: ...
829842

830843

831844
@overload
832-
def InternalTagging(tag: str, cls: T) -> _WithTagging[T]: ...
845+
def InternalTagging(tag: str, cls: T, *, tags: Optional[dict[type[Any], str]] = None) -> _WithTagging[T]: ...
833846

834847

835-
def InternalTagging(tag: str, cls: Optional[T] = None) -> Union[Tagging, _WithTagging[T]]:
836-
tagging = Tagging(tag, kind=Tagging.Kind.Internal)
848+
def InternalTagging(tag: str, cls: Optional[T] = None, *, tags: Optional[dict[type[Any], str]] = None) -> Union[Tagging, _WithTagging[T]]:
849+
tagging = Tagging(tag, kind=Tagging.Kind.Internal, tags=tags)
837850
if cls:
838851
return tagging(cls)
839852
else:
840853
return tagging
841854

842855

843856
@overload
844-
def AdjacentTagging(tag: str, content: str) -> Tagging: ...
857+
def AdjacentTagging(tag: str, content: str, *, tags: Optional[dict[type[Any], str]] = None) -> Tagging: ...
845858

846859

847860
@overload
848-
def AdjacentTagging(tag: str, content: str, cls: T) -> _WithTagging[T]: ...
861+
def AdjacentTagging(tag: str, content: str, cls: T, *, tags: Optional[dict[type[Any], str]] = None) -> _WithTagging[T]: ...
849862

850863

851864
def AdjacentTagging(
852-
tag: str, content: str, cls: Optional[T] = None
865+
tag: str, content: str, cls: Optional[T] = None, *, tags: Optional[dict[type[Any], str]] = None
853866
) -> Union[Tagging, _WithTagging[T]]:
854-
tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent)
867+
tagging = Tagging(tag, content, kind=Tagging.Kind.Adjacent, tags=tags)
868+
if cls:
869+
return tagging(cls)
870+
else:
871+
return tagging
872+
873+
874+
@overload
875+
def ExternalTagging_(*, tags: dict[type[Any], str]) -> Tagging: ...
876+
877+
878+
@overload
879+
def ExternalTagging_(cls: T, *, tags: dict[type[Any], str]) -> _WithTagging[T]: ...
880+
881+
882+
def ExternalTagging_(cls: Optional[T] = None, *, tags: dict[type[Any], str]) -> Union[Tagging, _WithTagging[T]]:
883+
tagging = Tagging(kind=Tagging.Kind.External, tags=tags)
855884
if cls:
856885
return tagging(cls)
857886
else:
858887
return tagging
859888

860889

890+
# TODO: This is an instance rather than a function for backwards-compatibility
891+
# reasons. In the next major version increase this should be replaced with a
892+
# function.
861893
ExternalTagging = Tagging()
862894

863895
Untagged = Tagging(kind=Tagging.Kind.Untagged)

serde/de.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,18 +1129,18 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non
11291129
try:
11301130
# create fake dict so we can reuse the normal render function
11311131
{% if tagging.is_external() and is_taggable(t) %}
1132-
ensure("{{typename(t)}}" in data , "'{{typename(t)}}' key is not present")
1133-
fake_dict = {"fake_key": data["{{typename(t)}}"]}
1132+
ensure("{{tagging.tag_for(t)}}" in data , "'{{tagging.tag_for(t)}}' key is not present")
1133+
fake_dict = {"fake_key": data["{{tagging.tag_for(t)}}"]}
11341134
11351135
{% elif tagging.is_internal() and is_taggable(t) %}
11361136
ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present")
1137-
ensure("{{typename(t)}}" == data["{{tagging.tag}}"], "tag '{{typename(t)}}' isn't found")
1137+
ensure("{{tagging.tag_for(t)}}" == data["{{tagging.tag}}"], "tag '{{tagging.tag_for(t)}}' isn't found")
11381138
fake_dict = {"fake_key": data}
11391139
11401140
{% elif tagging.is_adjacent() and is_taggable(t) %}
11411141
ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present")
11421142
ensure("{{tagging.content}}" in data , "'{{tagging.content}}' key is not present")
1143-
ensure("{{typename(t)}}" == data["{{tagging.tag}}"], "tag '{{typename(t)}}' isn't found")
1143+
ensure("{{tagging.tag_for(t)}}" == data["{{tagging.tag}}"], "tag '{{tagging.tag_for(t)}}' isn't found")
11441144
fake_dict = {"fake_key": data["{{tagging.content}}"]}
11451145
11461146
{% else %}

serde/se.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,16 +619,16 @@ def {{func}}(obj, reuse_instances, convert_sets, skip_none=False):
619619
{% for t in union_args %}
620620
if is_instance(obj, union_args[{{loop.index0}}]):
621621
{% if tagging.is_external() and is_taggable(t) %}
622-
return {"{{typename(t)}}": {{rvalue(arg(t))}}}
622+
return {"{{tagging.tag_for(t)}}": {{rvalue(arg(t))}}}
623623
624624
{% elif tagging.is_internal() and is_taggable(t) %}
625625
res = {{rvalue(arg(t))}}
626-
res["{{tagging.tag}}"] = "{{typename(t)}}"
626+
res["{{tagging.tag}}"] = "{{tagging.tag_for(t)}}"
627627
return res
628628
629629
{% elif tagging.is_adjacent() and is_taggable(t) %}
630630
res = {"{{tagging.content}}": {{rvalue(arg(t))}}}
631-
res["{{tagging.tag}}"] = "{{typename(t)}}"
631+
res["{{tagging.tag}}"] = "{{tagging.tag_for(t)}}"
632632
return res
633633
634634
{% else %}

tests/test_union.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
to_tuple,
2828
InternalTagging,
2929
AdjacentTagging,
30+
ExternalTagging_,
3031
Untagged,
3132
)
3233
from serde.json import from_json, to_json
@@ -842,3 +843,41 @@ class Foo:
842843

843844
f = Foo([10])
844845
assert f == from_json(Foo, to_json(f))
846+
847+
848+
T = TypeVar("T")
849+
850+
851+
def _test_union_with_custom_tags_arguments():
852+
@dataclass
853+
class Foo:
854+
v: int
855+
856+
@dataclass
857+
class Bar:
858+
w: str
859+
860+
return [
861+
(
862+
ExternalTagging_(tags={Foo: "f", Bar: "b"})(Foo | Bar),
863+
Foo(1),
864+
'{"f":{"v":1}}',
865+
),
866+
(
867+
InternalTagging("type", tags={Foo: "f", Bar: "b"})(Foo | Bar),
868+
Foo(1),
869+
'{"v":1,"type":"f"}',
870+
),
871+
(
872+
AdjacentTagging("type", "content", tags={Foo: "f", Bar: "b"})(Foo | Bar),
873+
Foo(1),
874+
'{"content":{"v":1},"type":"f"}',
875+
),
876+
]
877+
878+
879+
@pytest.mark.parametrize("cls,deserialized,serialized", _test_union_with_custom_tags_arguments())
880+
def test_union_with_custom_tags(cls: type[T], deserialized: T, serialized: str) -> None:
881+
assert to_json(deserialized, cls=cls) == serialized
882+
assert from_json(cls, serialized) == deserialized
883+

0 commit comments

Comments
 (0)