Commit a8e9ca6

Saeid Barati authored and committed
Add IOType ABC, Json, Struct, and the datastore bridge
On top of Netflix#3117, this ships the typed-artifact layer. Scope kept deliberately small:

- ``IOType`` ABC — the contract extension authors target.
- ``Json`` and ``Struct`` — two concrete types with clear standalone value in core: wire format for CLI/IPC, cross-language JSON bytes on storage, no pickle code-execution risk. ``Struct`` also walks directly-nested ``@dataclass`` fields so ``Outer(inner=Inner(...))`` round-trips back to its original type (generic containers like ``List[Inner]`` come back as raw JSON values — wrap those explicitly when you need richer reconstruction).
- ``IOTypeSerializer`` — the bridge that plugs any ``IOType`` instance into the ``ArtifactSerializer`` dispatch added by Netflix#3117 so save/load through the datastore just works.

What's intentionally *not* in this PR

- Primitive wrappers (Int32/Int64/Float32/Float64/Bool/Text). Standard Python numbers and strings flow through ``PickleSerializer`` unchanged. Wrapping is opt-in, for cases where you want constraints/metadata attached.
- ``Tensor``. Pulls in numpy + byte-order/dtype opinions; belongs in an extension that can own those choices.
- ``List`` / ``Map`` / ``Enum``. Thin wrappers whose value over plain JSON is mostly schema emission — not enough on their own for core.
- Rich schema emission from ``Struct.to_spec()``. Extensions that ship primitive wrappers can override to emit fully-typed schemas; core just returns ``{"type": "struct"}``.

Contract

``serialize(format=...)`` / ``deserialize(data, format=..., **kw)`` mirror the ``ArtifactSerializer`` signature from Netflix#3117 and use the same ``WIRE`` / ``STORAGE`` constants, so one subclass owns both representations:

- ``STORAGE`` → ``(List[SerializedBlob], metadata_dict)`` for persisting through the datastore.
- ``WIRE`` → ``str`` for CLI args, protobuf payloads, and cross-process IPC.

Subclasses implement four hooks (``_wire_serialize``, ``_wire_deserialize``, ``_storage_serialize``, ``_storage_deserialize``). Instantiating without the hooks raises ``TypeError``.

``IOTypeSerializer`` is registered via ``ARTIFACT_SERIALIZERS_DESC`` with ``PRIORITY=50`` — ahead of the default 100 so it catches ``IOType`` instances before a generic catch-all, and always ahead of the ``PickleSerializer`` fallback (9999). It implements only ``STORAGE``; wire encoding is produced by calling ``IOType.serialize(format=WIRE)`` directly.

Safety

- ``Struct._storage_deserialize`` and ``IOTypeSerializer.deserialize`` both require the class named in artifact metadata to be an actual class (``isinstance(..., type)``) before any further checks. This excludes module-level dataclass *instances* (``is_dataclass`` alone returns ``True`` for those) and other callables that could be invoked with attacker-controlled kwargs.
- Importing the metadata-named module can still run module-level side-effect code; the ``Struct`` docstring calls this out so callers don't load artifacts from untrusted sources.

Tests

- ``test_base.py`` — abstract instantiation, WIRE/STORAGE dispatch, invalid format, equality/hash, spec.
- ``test_json_type.py`` — wire and storage round-trips.
- ``test_struct_type.py`` — dataclass round-trip, dict round-trip, directly-nested dataclass round-trip, container-field pass-through, rejection of non-dataclass and dataclass-instance metadata.
- ``test_iotype_serializer.py`` — bridge ``can_serialize``/``can_deserialize``, round-trip through dataclass reconstruction, rejection of non-IOType classes in metadata, WIRE not supported on the bridge.
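The STORAGE/WIRE split above can be sketched with a toy, self-contained class. ``ToyJson`` is hypothetical (not a Metaflow class) and imports nothing from Metaflow; it only illustrates the shape of the two representations one subclass owns:

```python
import json

# Hypothetical ToyJson; not a Metaflow class. It only illustrates the
# shape of the two formats one IOType subclass owns.
class ToyJson:
    def __init__(self, value):
        self.value = value

    def serialize(self, format="storage"):
        s = json.dumps(self.value, separators=(",", ":"), sort_keys=True)
        if format == "wire":
            return s  # str: CLI args, protobuf payloads, IPC
        return [s.encode("utf-8")], {}  # (blob bytes list, metadata_dict)

assert ToyJson({"k": 1}).serialize(format="wire") == '{"k":1}'
blobs, meta = ToyJson({"k": 1}).serialize(format="storage")
assert blobs == [b'{"k":1}'] and meta == {}
```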
1 parent fed2c78 commit a8e9ca6

11 files changed

Lines changed: 821 additions & 0 deletions

File tree

metaflow/io_types/__init__.py

Lines changed: 3 additions & 0 deletions
```python
from .base import IOType
from .json_type import Json
from .struct_type import Struct
```

metaflow/io_types/base.py

Lines changed: 138 additions & 0 deletions
```python
"""Typed-artifact contract for Metaflow.

This module defines the minimal :class:`IOType` abstract base class. OSS
Metaflow ships the contract; concrete types (scalars, tensors, enums,
dataclass-backed structs, etc.) live in extensions — they embody
deployment-specific opinions about encoding, byte order, and dataclass
inference that do not belong in core.

:class:`IOType` mirrors the ``format`` argument introduced on
:class:`metaflow.datastore.artifacts.serializer.ArtifactSerializer` so a
single subclass can own both representations:

- ``STORAGE`` — blob-based, persisted through the datastore.
- ``WIRE`` — string-based, for CLI args, protobuf payloads, and
  cross-process IPC.

Subclasses implement four hooks (``_wire_serialize``, ``_wire_deserialize``,
``_storage_serialize``, ``_storage_deserialize``); callers use the public
``serialize(format=...)`` / ``deserialize(data, format=...)`` methods.
"""

from abc import ABCMeta, abstractmethod

from metaflow.datastore.artifacts.serializer import STORAGE, WIRE


_UNSET = object()


class IOType(object, metaclass=ABCMeta):
    """
    Base class for typed Metaflow artifacts.

    An :class:`IOType` instance plays two roles:

    - **Descriptor** (no value): ``Int64`` in a spec describes an int64
      field.
    - **Wrapper** (with value): ``Int64(42)`` wraps a value for typed
      serialization.

    Subclasses implement four internal operations, dispatched by the
    ``format`` argument of the public :meth:`serialize` / :meth:`deserialize`
    methods.
    """

    type_name = None  # e.g. "text", "json", "int64" — set by subclasses.

    def __init__(self, value=_UNSET):
        self._value = value

    @property
    def value(self):
        """The wrapped Python value, or ``_UNSET`` if this is a pure descriptor."""
        return self._value

    # -- Public API --------------------------------------------------------

    def serialize(self, format=STORAGE):
        """
        Serialize the wrapped value. Must be side-effect-free.

        Parameters
        ----------
        format : str
            ``STORAGE`` (default) returns ``(List[SerializedBlob], dict)``.
            ``WIRE`` returns a ``str``.
        """
        if format == WIRE:
            return self._wire_serialize()
        if format == STORAGE:
            return self._storage_serialize()
        raise ValueError("format must be %r or %r, got %r" % (STORAGE, WIRE, format))

    @classmethod
    def deserialize(cls, data, format=STORAGE, **kwargs):
        """
        Reconstruct an :class:`IOType` from serialized data.

        Parameters
        ----------
        data : Union[str, List[bytes]]
            ``str`` when ``format=WIRE``; ``List[bytes]`` when ``format=STORAGE``.
        format : str
            ``STORAGE`` (default) or ``WIRE``.
        **kwargs
            Forwarded to the underlying ``_storage_deserialize`` hook
            (e.g. metadata the datastore produced at save time).
        """
        if format == WIRE:
            return cls._wire_deserialize(data)
        if format == STORAGE:
            return cls._storage_deserialize(data, **kwargs)
        raise ValueError("format must be %r or %r, got %r" % (STORAGE, WIRE, format))

    # -- Subclass hooks ----------------------------------------------------

    @abstractmethod
    def _wire_serialize(self):
        """Value -> string (for CLI args, protobuf, external APIs)."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def _wire_deserialize(cls, s):
        """String -> :class:`IOType` instance."""
        raise NotImplementedError

    @abstractmethod
    def _storage_serialize(self):
        """Value -> ``(List[SerializedBlob], metadata_dict)``. Side-effect-free."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def _storage_deserialize(cls, blobs, **kwargs):
        """``(List[bytes], metadata)`` -> :class:`IOType` instance."""
        raise NotImplementedError

    # -- Spec generation ---------------------------------------------------

    def to_spec(self):
        """JSON type spec. Works with or without a wrapped value."""
        return {"type": self.type_name}

    # -- Dunder ------------------------------------------------------------

    def __repr__(self):
        if self._value is _UNSET:
            return "%s()" % self.__class__.__name__
        return "%s(%r)" % (self.__class__.__name__, self._value)

    def __eq__(self, other):
        if type(self) is not type(other):
            return NotImplemented
        return self._value == other._value

    def __hash__(self):
        return hash((type(self), self._value))
```
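The abstract-hook enforcement (instantiating without the hooks raises ``TypeError``) can be demonstrated with a trimmed-down, self-contained stand-in. ``MiniIOType``, ``Text``, and ``Broken`` are hypothetical classes, not Metaflow code, and show only two of the four hooks:

```python
from abc import ABCMeta, abstractmethod

STORAGE, WIRE = "storage", "wire"  # stand-ins for the real constants

class MiniIOType(metaclass=ABCMeta):
    """Trimmed stand-in for IOType: two hooks instead of four."""

    def __init__(self, value=None):
        self._value = value

    def serialize(self, format=STORAGE):
        # Public dispatch, mirroring IOType.serialize.
        if format == WIRE:
            return self._wire_serialize()
        if format == STORAGE:
            return self._storage_serialize()
        raise ValueError("format must be %r or %r, got %r" % (STORAGE, WIRE, format))

    @abstractmethod
    def _wire_serialize(self):
        raise NotImplementedError

    @abstractmethod
    def _storage_serialize(self):
        raise NotImplementedError

class Text(MiniIOType):
    # Complete subclass: both hooks implemented, so it instantiates.
    def _wire_serialize(self):
        return self._value

    def _storage_serialize(self):
        return [self._value.encode("utf-8")], {}

class Broken(MiniIOType):
    # Missing _storage_serialize: the class statement itself succeeds,
    # but ABCMeta refuses to instantiate the class.
    def _wire_serialize(self):
        return ""

assert Text("hi").serialize(format=WIRE) == "hi"
try:
    Broken()
except TypeError:
    pass  # "Can't instantiate abstract class Broken ..."
```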

metaflow/io_types/json_type.py

Lines changed: 27 additions & 0 deletions
```python
import json

from ..datastore.artifacts.serializer import SerializedBlob
from .base import IOType


class Json(IOType):
    """JSON type (dict or list). Wire: JSON string. Storage: UTF-8 JSON bytes."""

    type_name = "json"

    def _wire_serialize(self):
        return json.dumps(self._value, separators=(",", ":"), sort_keys=True)

    @classmethod
    def _wire_deserialize(cls, s):
        return cls(json.loads(s))

    def _storage_serialize(self):
        blob = json.dumps(self._value, separators=(",", ":"), sort_keys=True).encode(
            "utf-8"
        )
        return [SerializedBlob(blob)], {}

    @classmethod
    def _storage_deserialize(cls, blobs, **kwargs):
        return cls(json.loads(blobs[0].decode("utf-8")))
```
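The ``separators=(",", ":"), sort_keys=True`` settings used in both serialize paths give compact, deterministic output regardless of dict insertion order. A quick standalone check (``canonical`` is an illustrative helper, not Metaflow code):

```python
import json

def canonical(value):
    # Same settings Json uses: compact separators plus sorted keys yield
    # byte-stable output no matter how the dict was built.
    return json.dumps(value, separators=(",", ":"), sort_keys=True)

a = canonical({"b": 1, "a": [1, 2]})
b = canonical({"a": [1, 2], "b": 1})
assert a == b == '{"a":[1,2],"b":1}'
```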

metaflow/io_types/struct_type.py

Lines changed: 127 additions & 0 deletions
```python
import dataclasses
import importlib
import json
import typing

from ..datastore.artifacts.serializer import SerializedBlob
from .base import IOType, _UNSET


def _reconstruct(dc_type, data):
    """
    Rebuild a dataclass instance from JSON-decoded ``data``, recursing into
    fields whose annotation is itself a dataclass. Containerized annotations
    (``List[Foo]``, ``Dict[str, Foo]``, ``Optional[Foo]``, ...) are left as
    raw JSON-decoded values; callers that need rich container reconstruction
    should wrap the field explicitly (e.g. in a ``List`` IOType shipped by
    an extension).
    """
    try:
        hints = typing.get_type_hints(dc_type)
    except Exception:
        hints = {}
    kwargs = {}
    for f in dataclasses.fields(dc_type):
        raw = data.get(f.name)
        annotation = hints.get(f.name, f.type)
        if (
            isinstance(annotation, type)
            and dataclasses.is_dataclass(annotation)
            and isinstance(raw, dict)
        ):
            kwargs[f.name] = _reconstruct(annotation, raw)
        else:
            kwargs[f.name] = raw
    return dc_type(**kwargs)


class Struct(IOType):
    """
    Structured type mapping to a Python ``@dataclass``.

    Wire: JSON string. Storage: JSON UTF-8 bytes.

    Wraps a ``@dataclass`` instance. On save, ``dataclasses.asdict`` flattens
    the whole tree to plain dicts; on load, fields typed as dataclasses are
    recursively rebuilt into their original types. Generic container
    annotations (``List[Foo]``, ``Dict[str, Foo]``, ``Optional[Foo]``) are
    not walked — those fields come back as raw JSON-decoded values. Wrap
    those explicitly (e.g. via ``List[Struct]`` support shipped by an
    extension) when you need typed containers.

    .. warning::
        ``Struct._storage_deserialize`` imports the dataclass module named in
        the artifact metadata. Metadata written by this class is safe, but
        metadata supplied from an untrusted source can trigger arbitrary
        imports (and any import-time side effects those modules carry).
        Only load artifacts from sources you trust.

    Parameters
    ----------
    value : dataclass instance or dict, optional
        The wrapped value. Dataclass instances are serialized via
        ``dataclasses.asdict``; plain dicts are serialized directly.
    dataclass_type : type, optional
        The ``@dataclass`` class, for type-descriptor use (no value).
    """

    type_name = "struct"

    def __init__(self, value=_UNSET, dataclass_type=None):
        if value is not _UNSET and dataclasses.is_dataclass(value):
            self._dataclass_type = type(value)
        elif dataclass_type is not None:
            self._dataclass_type = dataclass_type
        else:
            self._dataclass_type = None
        super().__init__(value)

    def _to_dict(self):
        """Convert value to dict, handling both dataclass and plain dict."""
        if dataclasses.is_dataclass(self._value):
            return dataclasses.asdict(self._value)
        if isinstance(self._value, dict):
            return self._value
        raise TypeError(
            "Struct value must be a dataclass instance or dict, got %s"
            % type(self._value).__name__
        )

    def _wire_serialize(self):
        return json.dumps(self._to_dict(), separators=(",", ":"), sort_keys=True)

    @classmethod
    def _wire_deserialize(cls, s):
        return cls(json.loads(s))

    def _storage_serialize(self):
        blob = json.dumps(
            self._to_dict(), separators=(",", ":"), sort_keys=True
        ).encode("utf-8")
        meta = {}
        if self._dataclass_type is not None:
            meta["dataclass_module"] = self._dataclass_type.__module__
            meta["dataclass_class"] = self._dataclass_type.__name__
        return [SerializedBlob(blob)], meta

    @classmethod
    def _storage_deserialize(cls, blobs, **kwargs):
        data = json.loads(blobs[0].decode("utf-8"))
        metadata = kwargs.get("metadata", {})
        dc_module = metadata.get("dataclass_module")
        dc_class = metadata.get("dataclass_class")
        if dc_module and dc_class:
            mod = importlib.import_module(dc_module)
            dc_type = getattr(mod, dc_class)
            # Guard against crafted metadata — require a class that's
            # actually a dataclass. ``dataclasses.is_dataclass`` alone
            # returns True for dataclass *instances*; the ``isinstance(..., type)``
            # check excludes that (and anything else callable).
            if not (isinstance(dc_type, type) and dataclasses.is_dataclass(dc_type)):
                raise ValueError(
                    "Struct metadata references '%s.%s' which is not a dataclass"
                    % (dc_module, dc_class)
                )
            return cls(_reconstruct(dc_type, data), dataclass_type=dc_type)
        # Fallback: return as plain dict wrapped in Struct
        return cls(data)
```
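The nested-reconstruction behavior can be exercised standalone with only the standard library. ``reconstruct``, ``Inner``, and ``Outer`` below are illustrative copies for the example, not Metaflow code; note the dataclass field is rebuilt while the container field passes through as raw JSON:

```python
import dataclasses
import json
import typing

def reconstruct(dc_type, data):
    # Mirrors struct_type._reconstruct: recurse only into fields whose
    # annotation is itself a dataclass; leave containers as raw values.
    try:
        hints = typing.get_type_hints(dc_type)
    except Exception:
        hints = {}
    kwargs = {}
    for f in dataclasses.fields(dc_type):
        raw = data.get(f.name)
        annotation = hints.get(f.name, f.type)
        if (
            isinstance(annotation, type)
            and dataclasses.is_dataclass(annotation)
            and isinstance(raw, dict)
        ):
            kwargs[f.name] = reconstruct(annotation, raw)
        else:
            kwargs[f.name] = raw
    return dc_type(**kwargs)

@dataclasses.dataclass
class Inner:
    x: int

@dataclasses.dataclass
class Outer:
    inner: Inner
    tags: typing.List[str]

# Round-trip: asdict flattens, JSON encodes, reconstruct rebuilds.
blob = json.dumps(dataclasses.asdict(Outer(inner=Inner(x=1), tags=["a"])))
out = reconstruct(Outer, json.loads(blob))
assert out == Outer(inner=Inner(x=1), tags=["a"])
assert isinstance(out.inner, Inner)       # nested dataclass rebuilt
assert out.tags == ["a"]                  # container passed through raw
```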

metaflow/plugins/__init__.py

Lines changed: 1 addition & 0 deletions
```python
# Add artifact serializers here. Ordering is by PRIORITY (lower = tried first).
# PickleSerializer is the universal fallback (PRIORITY=9999).
ARTIFACT_SERIALIZERS_DESC = [
    ("iotype", ".datastores.serializers.iotype_serializer.IOTypeSerializer"),  # added
    ("pickle", ".datastores.serializers.pickle_serializer.PickleSerializer"),
]
```
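How a priority-ordered registry resolves to a serializer can be sketched as follows. The loader here is a hypothetical simplification (``pick``, ``IOTypeLike``, ``PickleLike``, and ``Fake`` are made up), illustrating only the rule stated above: lower ``PRIORITY`` is tried first, and the 9999 pickle fallback catches everything else:

```python
class IOTypeLike:
    PRIORITY = 50
    @staticmethod
    def can_serialize(obj):
        return hasattr(obj, "type_name")

class PickleLike:
    PRIORITY = 9999
    @staticmethod
    def can_serialize(obj):
        return True  # universal fallback

def pick(obj, serializers):
    # Try serializers in ascending PRIORITY order; first match wins.
    for s in sorted(serializers, key=lambda s: s.PRIORITY):
        if s.can_serialize(obj):
            return s
    raise TypeError("no serializer for %r" % (obj,))

class Fake:
    type_name = "json"

assert pick(Fake(), [PickleLike, IOTypeLike]) is IOTypeLike
assert pick(42, [PickleLike, IOTypeLike]) is PickleLike
```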

metaflow/plugins/datastores/serializers/iotype_serializer.py

Lines changed: 86 additions & 0 deletions

```python
import importlib

from metaflow.datastore.artifacts.serializer import (
    ArtifactSerializer,
    SerializationMetadata,
    STORAGE,
    WIRE,
)
from metaflow.io_types.base import IOType


class IOTypeSerializer(ArtifactSerializer):
    """
    Bridge between :class:`IOType` and the pluggable serializer framework.

    Claims any :class:`IOType` instance on save. On load, reconstructs the
    original subclass from the ``iotype_module`` / ``iotype_class`` hints
    that were written into ``serializer_info``.

    ``PRIORITY`` is 50 — ahead of the default (100) so this bridge catches
    :class:`IOType` artifacts before any generic catch-all, and always ahead
    of the :class:`PickleSerializer` fallback (9999).

    Only the ``STORAGE`` format is implemented on this bridge; ``WIRE`` is
    handled by callers that talk to :class:`IOType` directly (CLI parsing,
    protobuf payload construction), not through the datastore.
    """

    TYPE = "iotype"
    PRIORITY = 50

    _ENCODING_PREFIX = "iotype:"

    @classmethod
    def can_serialize(cls, obj):
        return isinstance(obj, IOType)

    @classmethod
    def can_deserialize(cls, metadata):
        return metadata.encoding.startswith(cls._ENCODING_PREFIX)

    @classmethod
    def serialize(cls, obj, format=STORAGE):
        if format == WIRE:
            raise NotImplementedError(
                "IOTypeSerializer only handles the STORAGE format; wire "
                "encoding is produced by calling IOType.serialize(format=WIRE) "
                "directly."
            )
        blobs, meta_dict = obj.serialize(format=STORAGE)
        size = sum(len(b.value) for b in blobs if isinstance(b.value, bytes))
        # Subclass metadata goes first so the routing keys below always win.
        # An IOType subclass whose ``_storage_serialize`` happens to return
        # ``iotype_module`` or ``iotype_class`` in its own meta dict must not
        # be able to overwrite the routing info the deserialize path needs.
        serializer_info = {
            **meta_dict,
            "iotype_module": obj.__class__.__module__,
            "iotype_class": obj.__class__.__name__,
        }
        return (
            blobs,
            SerializationMetadata(
                obj_type=obj.type_name,
                size=size,
                encoding=cls._ENCODING_PREFIX + obj.type_name,
                serializer_info=serializer_info,
            ),
        )

    @classmethod
    def deserialize(cls, data, metadata=None, format=STORAGE):
        if format == WIRE:
            raise NotImplementedError(
                "IOTypeSerializer only handles the STORAGE format."
            )
        info = metadata.serializer_info
        mod = importlib.import_module(info["iotype_module"])
        iotype_cls = getattr(mod, info["iotype_class"])
        # Only allow actual IOType subclasses — metadata is untrusted input.
        if not (isinstance(iotype_cls, type) and issubclass(iotype_cls, IOType)):
            raise ValueError(
                "IOTypeSerializer metadata references '%s.%s' which is not an "
                "IOType subclass" % (info["iotype_module"], info["iotype_class"])
            )
        return iotype_cls.deserialize(data, format=STORAGE, metadata=info)
```
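The class-vs-instance guard in ``deserialize`` (and in ``Struct._storage_deserialize``) rests on a stdlib subtlety worth seeing in isolation: ``dataclasses.is_dataclass`` returns ``True`` for instances as well as classes. A standalone illustration, where ``Config``, ``CONFIG_SINGLETON``, and ``is_safe_dataclass`` are hypothetical names invented for the example:

```python
import dataclasses

@dataclasses.dataclass
class Config:
    retries: int = 0

CONFIG_SINGLETON = Config()  # a module-level dataclass *instance*

def is_safe_dataclass(obj):
    # is_dataclass alone is True for instances too; the isinstance(obj, type)
    # check is what rejects them (and any other non-class attribute an
    # attacker-controlled module/name pair could point at).
    return isinstance(obj, type) and dataclasses.is_dataclass(obj)

assert dataclasses.is_dataclass(CONFIG_SINGLETON)  # True — the trap
assert not is_safe_dataclass(CONFIG_SINGLETON)     # instance rejected
assert is_safe_dataclass(Config)                   # class accepted
```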

test/unit/io_types/__init__.py

Whitespace-only changes.
