|
| 1 | +import dataclasses |
| 2 | +import importlib |
| 3 | +import json |
| 4 | +import typing |
| 5 | + |
| 6 | +from ..datastore.artifacts.serializer import SerializedBlob |
| 7 | +from .base import IOType, _UNSET |
| 8 | + |
| 9 | + |
def _reconstruct(dc_type, data):
    """
    Rebuild a dataclass instance from JSON-decoded ``data``.

    Recurses into fields whose annotation is itself a dataclass class.
    Containerized annotations (``List[Foo]``, ``Dict[str, Foo]``,
    ``Optional[Foo]``, ...) are left as raw JSON-decoded values; callers
    that need rich container reconstruction should wrap the field
    explicitly (e.g. in a ``List`` IOType shipped by an extension).

    Parameters
    ----------
    dc_type : type
        The ``@dataclass`` class to instantiate.
    data : dict
        JSON-decoded mapping of field name to value.

    Returns
    -------
    object
        An instance of ``dc_type``.

    Notes
    -----
    - Fields declared with ``init=False`` are not passed to the constructor
      (it would reject them); they take whatever value the dataclass's own
      default / ``__post_init__`` gives them.
    - A key missing from ``data`` falls back to the field's declared
      default or default factory. If the field has neither, ``None`` is
      passed so that partially populated payloads still load.
    """
    try:
        hints = typing.get_type_hints(dc_type)
    except Exception:
        # Best effort: unresolved forward references (or similar) must not
        # prevent loading. Fall back to the raw ``Field.type`` values, which
        # may be strings and are then simply not recursed into.
        hints = {}

    missing = dataclasses.MISSING
    kwargs = {}
    for f in dataclasses.fields(dc_type):
        if not f.init:
            # ``asdict`` emits these fields, but ``__init__`` does not
            # accept them as keyword arguments.
            continue

        raw = data.get(f.name, missing)
        if raw is missing:
            if f.default is not missing or f.default_factory is not missing:
                # Let the dataclass apply its own default instead of
                # overriding it with an explicit None.
                continue
            raw = None

        annotation = hints.get(f.name, f.type)
        if (
            isinstance(annotation, type)
            and dataclasses.is_dataclass(annotation)
            and isinstance(raw, dict)
        ):
            kwargs[f.name] = _reconstruct(annotation, raw)
        else:
            kwargs[f.name] = raw
    return dc_type(**kwargs)
| 36 | + |
| 37 | + |
class Struct(IOType):
    """
    Structured type mapping to a Python ``@dataclass``.

    Wire: JSON string. Storage: JSON UTF-8 bytes.

    Wraps a ``@dataclass`` instance. On save, ``dataclasses.asdict`` flattens
    the whole tree to plain dicts; on load, fields typed as dataclasses are
    recursively rebuilt into their original types. Generic container
    annotations (``List[Foo]``, ``Dict[str, Foo]``, ``Optional[Foo]``) are
    not walked — those fields come back as raw JSON-decoded values. Wrap
    those explicitly (e.g. via ``List[Struct]`` support shipped by an
    extension) when you need typed containers.

    .. warning::
        ``Struct._storage_deserialize`` imports the dataclass module named in
        the artifact metadata. Metadata written by this class is safe, but
        metadata supplied from an untrusted source can trigger arbitrary
        imports (and any import-time side effects those modules carry).
        Only load artifacts from sources you trust.

    Parameters
    ----------
    value : dataclass instance or dict, optional
        The wrapped value. Dataclass instances are serialized via
        ``dataclasses.asdict``; plain dicts are serialized directly.
    dataclass_type : type, optional
        The ``@dataclass`` class, for type-descriptor use (no value).
        Ignored when ``value`` is itself a dataclass instance, in which
        case the instance's own class is used.
    """

    type_name = "struct"

    def __init__(self, value=_UNSET, dataclass_type=None):
        # ``dataclasses.is_dataclass`` is also True for dataclass *classes*;
        # only an actual instance tells us the type of the wrapped value.
        if (
            value is not _UNSET
            and dataclasses.is_dataclass(value)
            and not isinstance(value, type)
        ):
            self._dataclass_type = type(value)
        elif dataclass_type is not None:
            self._dataclass_type = dataclass_type
        else:
            self._dataclass_type = None
        super().__init__(value)

    def _to_dict(self):
        """
        Convert the wrapped value to a plain dict.

        Returns
        -------
        dict
            ``dataclasses.asdict(value)`` for a dataclass instance, or the
            value itself when it already is a dict.

        Raises
        ------
        TypeError
            If the value is neither a dataclass instance nor a dict
            (including a dataclass *class* passed by mistake).
        """
        value = self._value
        # Exclude dataclass classes: ``asdict`` only accepts instances.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            return dataclasses.asdict(value)
        if isinstance(value, dict):
            return value
        raise TypeError(
            "Struct value must be a dataclass instance or dict, got %s"
            % type(value).__name__
        )

    def _wire_serialize(self):
        """Return the value as a compact JSON string with sorted keys."""
        return json.dumps(self._to_dict(), separators=(",", ":"), sort_keys=True)

    @classmethod
    def _wire_deserialize(cls, s):
        """
        Build a ``Struct`` from a JSON string.

        The wire form carries no dataclass type information, so the result
        wraps the plain JSON-decoded value.
        """
        return cls(json.loads(s))

    def _storage_serialize(self):
        """
        Serialize the value for storage.

        Returns
        -------
        tuple
            ``([SerializedBlob(json_bytes)], metadata)``. ``metadata``
            carries ``dataclass_module`` and ``dataclass_class`` when the
            dataclass type is known, and is empty otherwise.
        """
        blob = json.dumps(
            self._to_dict(), separators=(",", ":"), sort_keys=True
        ).encode("utf-8")
        meta = {}
        if self._dataclass_type is not None:
            meta["dataclass_module"] = self._dataclass_type.__module__
            # ``__qualname__`` equals ``__name__`` for top-level classes and
            # keeps nested classes (``Outer.Inner``) resolvable on load.
            meta["dataclass_class"] = self._dataclass_type.__qualname__
        return [SerializedBlob(blob)], meta

    @classmethod
    def _storage_deserialize(cls, blobs, **kwargs):
        """
        Rebuild a ``Struct`` from stored blobs.

        Parameters
        ----------
        blobs : sequence of bytes
            Only ``blobs[0]`` is read; it must hold UTF-8 encoded JSON.
        **kwargs
            ``metadata`` (dict, optional): when it names both
            ``dataclass_module`` and ``dataclass_class``, that class is
            imported and the value is rebuilt as an instance of it.

        Returns
        -------
        Struct
            Wrapping a dataclass instance when the metadata identifies one,
            otherwise wrapping the plain JSON-decoded value.

        Raises
        ------
        ValueError
            If the metadata resolves to something that is not a dataclass
            class.
        """
        data = json.loads(blobs[0].decode("utf-8"))
        # Tolerate an explicit ``metadata=None`` as well as a missing key.
        metadata = kwargs.get("metadata") or {}
        dc_module = metadata.get("dataclass_module")
        dc_class = metadata.get("dataclass_class")
        if dc_module and dc_class:
            # SECURITY: imports a module named by the metadata; see the
            # class-level warning about untrusted artifacts.
            dc_type = importlib.import_module(dc_module)
            # Walk dotted names so nested classes resolve; a plain name is
            # a single ``getattr`` on the module, as before.
            for part in dc_class.split("."):
                dc_type = getattr(dc_type, part)
            # Guard against crafted metadata — require a class that's
            # actually a dataclass. ``dataclasses.is_dataclass`` alone
            # returns True for dataclass *instances*; the
            # ``isinstance(..., type)`` check excludes that (and anything
            # else callable).
            if not (isinstance(dc_type, type) and dataclasses.is_dataclass(dc_type)):
                raise ValueError(
                    "Struct metadata references '%s.%s' which is not a dataclass"
                    % (dc_module, dc_class)
                )
            return cls(_reconstruct(dc_type, data), dataclass_type=dc_type)
        # Fallback: return as plain dict wrapped in Struct
        return cls(data)
0 commit comments