1111import re
1212import casefy
1313from dataclasses import dataclass
14+ from frozendict import frozendict
1415
1516from beartype .door import is_bearable
1617from collections .abc import Mapping , Sequence , Callable
@@ -776,6 +777,13 @@ class Kind(enum.Enum):
776777 tag : Optional [str ] = None
777778 content : Optional [str ] = None
778779 kind : Kind = Kind .External
780+ tags : Optional [frozendict [type [Any ], str ]] = None
781+
782+ def __init__ (self , tag : Optional [str ] = None , content : Optional [str ] = None , kind : Kind = Kind .External , tags : Optional [dict [type [Any ], str ]] = None ) -> None :
783+ self .tag = tag
784+ self .content = content
785+ self .kind = kind
786+ self .tags = frozendict (tags ) if tags else None
779787
780788 def is_external (self ) -> bool :
781789 return self .kind == self .Kind .External
@@ -793,6 +801,11 @@ def is_untagged(self) -> bool:
793801 def is_taggable (cls , typ : type [Any ]) -> bool :
794802 return dataclasses .is_dataclass (typ )
795803
804+ def tag_for (self , typ : type [Any ]) -> str :
805+ if self .tags and typ in self .tags :
806+ return self .tags [typ ]
807+ return typename (typ )
808+
796809 def check (self ) -> None :
797810 if self .is_internal () and self .tag is None :
798811 raise SerdeError ('"tag" must be specified in InternalTagging' )
@@ -825,39 +838,58 @@ def __call__(self, cls: T) -> _WithTagging[T]:
825838
826839
827840@overload
828- def InternalTagging (tag : str ) -> Tagging : ...
841+ def InternalTagging (tag : str , * , tags : Optional [ dict [ type [ Any ], str ]] = None ) -> Tagging : ...
829842
830843
831844@overload
832- def InternalTagging (tag : str , cls : T ) -> _WithTagging [T ]: ...
845+ def InternalTagging (tag : str , cls : T , * , tags : Optional [ dict [ type [ Any ], str ]] = None ) -> _WithTagging [T ]: ...
833846
834847
835- def InternalTagging (tag : str , cls : Optional [T ] = None ) -> Union [Tagging , _WithTagging [T ]]:
836- tagging = Tagging (tag , kind = Tagging .Kind .Internal )
848+ def InternalTagging (tag : str , cls : Optional [T ] = None , * , tags : Optional [ dict [ type [ Any ], str ]] = None ) -> Union [Tagging , _WithTagging [T ]]:
849+ tagging = Tagging (tag , kind = Tagging .Kind .Internal , tags = tags )
837850 if cls :
838851 return tagging (cls )
839852 else :
840853 return tagging
841854
842855
843856@overload
844- def AdjacentTagging (tag : str , content : str ) -> Tagging : ...
857+ def AdjacentTagging (tag : str , content : str , * , tags : Optional [ dict [ type [ Any ], str ]] = None ) -> Tagging : ...
845858
846859
847860@overload
848- def AdjacentTagging (tag : str , content : str , cls : T ) -> _WithTagging [T ]: ...
861+ def AdjacentTagging (tag : str , content : str , cls : T , * , tags : Optional [ dict [ type [ Any ], str ]] = None ) -> _WithTagging [T ]: ...
849862
850863
851864def AdjacentTagging (
852- tag : str , content : str , cls : Optional [T ] = None
865+ tag : str , content : str , cls : Optional [T ] = None , * , tags : Optional [ dict [ type [ Any ], str ]] = None
853866) -> Union [Tagging , _WithTagging [T ]]:
854- tagging = Tagging (tag , content , kind = Tagging .Kind .Adjacent )
867+ tagging = Tagging (tag , content , kind = Tagging .Kind .Adjacent , tags = tags )
868+ if cls :
869+ return tagging (cls )
870+ else :
871+ return tagging
872+
873+
874+ @overload
875+ def ExternalTagging_ (* , tags : dict [type [Any ], str ]) -> Tagging : ...
876+
877+
878+ @overload
879+ def ExternalTagging_ (cls : T , * , tags : dict [type [Any ], str ]) -> _WithTagging [T ]: ...
880+
881+
882+ def ExternalTagging_ (cls : Optional [T ] = None , * , tags : dict [type [Any ], str ]) -> Union [Tagging , _WithTagging [T ]]:
883+ tagging = Tagging (kind = Tagging .Kind .External , tags = tags )
855884 if cls :
856885 return tagging (cls )
857886 else :
858887 return tagging
859888
860889
890+ # TODO: This is an instance rather than a function for backwards-compatibility
891+ # reasons. In the next major version increase this should be replaced with a
892+ # function.
861893ExternalTagging = Tagging ()
862894
863895Untagged = Tagging (kind = Tagging .Kind .Untagged )
0 commit comments