diff --git a/metaflow/datastore/artifacts/__init__.py b/metaflow/datastore/artifacts/__init__.py new file mode 100644 index 00000000000..9a0154df26f --- /dev/null +++ b/metaflow/datastore/artifacts/__init__.py @@ -0,0 +1,14 @@ +from .serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, + STORAGE, + WIRE, +) +from .lazy_registry import ( + SerializerConfig, + load_serializer_class, + register_serializer_config, + register_serializer_for_type, +) diff --git a/metaflow/datastore/artifacts/lazy_registry.py b/metaflow/datastore/artifacts/lazy_registry.py new file mode 100644 index 00000000000..c3b5d837ecc --- /dev/null +++ b/metaflow/datastore/artifacts/lazy_registry.py @@ -0,0 +1,216 @@ +""" +Lazy serializer registry driven by an import hook. + +Extensions ship serializers whose implementation modules may import optional +heavy dependencies (``torch``, ``pyarrow``, ``fastavro``, ``protobuf``, ...). +Loading those serializer modules unconditionally at ``metaflow`` import time +would force every user to pay for dependencies they may not have installed. + +This module defers both the serializer class import and its instantiation +until one of two things happens: + +1. The target type's module is already present in :data:`sys.modules` when + registration is called — registration then happens immediately. +2. The target type's module is imported later by the user's code. An + ``importlib`` hook watches for that event and performs registration the + first time the module is loaded. + +The hook is installed on :data:`sys.meta_path` and removes itself from the +path during its own ``find_spec`` call to avoid recursion. +""" + +import importlib +import importlib.abc +import importlib.machinery +import importlib.util +import sys + +from dataclasses import dataclass, field + + +@dataclass +class SerializerConfig: + """ + Declarative entry recording *which* serializer handles *which* type, + without actually importing the serializer class. The class is imported on + first use by :func:`load_serializer_class`. + + Parameters + ---------- + canonical_type : str + ``"module.ClassName"`` — e.g. ``"builtins.dict"``, ``"torch.Tensor"``. + serializer : str + Dotted import path to the serializer class, e.g. + ``"my_extension.serializers.TorchSerializer"``. + priority : int + Dispatch priority. Lower is tried first. Matches the existing + ``ArtifactSerializer.PRIORITY`` convention. + extra_kwargs : dict + Optional kwargs passed to the serializer class ``__init__``. + """ + + canonical_type: str + serializer: str + priority: int = 100 + extra_kwargs: dict = field(default_factory=dict) + + def __post_init__(self): + if not self.canonical_type: + raise ValueError("canonical_type cannot be empty") + if not self.serializer or "." not in self.serializer: + raise ValueError("serializer must be in 'module.ClassName' format") + + @property + def serializer_module(self): + return ".".join(self.serializer.split(".")[:-1]) + + @property + def serializer_class(self): + return self.serializer.split(".")[-1] + + +# Module-level registry. Keyed by canonical_type -> SerializerConfig. +# A separate dict caches instantiated classes so repeated lookups are O(1). +_registered_configs = {} +_loaded_serializers = {} + + +def register_serializer_config(config): + """Store a config immediately. The serializer class is not imported yet.""" + _registered_configs[config.canonical_type] = config + # Any previously cached class for this type is now stale. 
+ _loaded_serializers.pop(config.canonical_type, None) + + +def load_serializer_class(canonical_type): + """ + Resolve and cache the serializer class for ``canonical_type``. Returns + ``None`` if no config is registered for that type. + """ + cached = _loaded_serializers.get(canonical_type) + if cached is not None: + return cached + config = _registered_configs.get(canonical_type) + if config is None: + return None + module = importlib.import_module(config.serializer_module) + cls = getattr(module, config.serializer_class) + _loaded_serializers[canonical_type] = cls + return cls + + +def iter_registered_configs(): + """Iterate all registered configs. Deterministic order (insertion).""" + return list(_registered_configs.values()) + + +class _WrappedLoader(importlib.abc.Loader): + """Delegating loader that fires a callback after ``exec_module``. + + Only ``create_module`` and ``exec_module`` are overridden. Other loader + attributes (``get_resource_reader``, ``get_filename``, ``get_data``, + ``is_package``, ``get_source``, ...) are forwarded to the wrapped loader + via ``__getattr__`` so importers that poke at those interfaces continue + to work transparently. + """ + + def __init__(self, original_loader, interceptor): + self._original = original_loader + self._interceptor = interceptor + + def create_module(self, spec): + return self._original.create_module(spec) + + def exec_module(self, module): + self._original.exec_module(module) + self._interceptor._on_module_imported(module) + + def __getattr__(self, name): + return getattr(self._original, name) + + +class _SerializerImportInterceptor(importlib.abc.MetaPathFinder): + """ + :class:`importlib.abc.MetaPathFinder` that watches for a fixed set of + module names and fires :func:`_on_module_imported` once each has been + fully executed. + """ + + def __init__(self): + # module_name -> list[SerializerConfig] + self._pending = {} + self._processed = set() + + def watch(self, module_name, config): + self._pending.setdefault(module_name, []).append(config) + + def find_spec(self, fullname, path, target=None): + if fullname not in self._pending: + return None + # Remove ourselves from the path during the lookup below so Python's + # normal finders (not us) can resolve the real spec. Reinstall after. + was_installed = self in sys.meta_path + if was_installed: + sys.meta_path.remove(self) + try: + spec = importlib.util.find_spec(fullname) + finally: + if was_installed: + sys.meta_path.insert(0, self) + if spec is None or spec.loader is None: + return None + spec.loader = _WrappedLoader(spec.loader, self) + return spec + + def _on_module_imported(self, module): + module_name = module.__name__ + if module_name in self._processed: + return + self._processed.add(module_name) + for config in self._pending.get(module_name, ()): + class_name = config.canonical_type.rsplit(".", 1)[-1] + if hasattr(module, class_name): + register_serializer_config(config) + + +_interceptor = _SerializerImportInterceptor() + + +def _ensure_interceptor_installed(): + if _interceptor in sys.meta_path: + sys.meta_path.remove(_interceptor) + sys.meta_path.insert(0, _interceptor) + + +def register_serializer_for_type(canonical_type, serializer, **kwargs): + """ + Public entry point for extensions. + + If the target type's module is already loaded, the config is stored + immediately. Otherwise, an import hook is installed and registration is + deferred to the first ``import`` of the target module. + + ``canonical_type`` is ``"module.ClassName"``. 
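A minimal sketch of the extension-side call this entry point is designed for (the torch target type and the serializer path below are hypothetical, not shipped by this change):

from metaflow.datastore.artifacts import register_serializer_for_type

# Nothing heavy is imported here. If ``torch`` is already in sys.modules the
# config is stored immediately; otherwise the meta_path hook defers it until
# user code imports torch for the first time.
register_serializer_for_type(
    "torch.Tensor",                                     # hypothetical target type
    "my_extension.serializers.TorchTensorSerializer",   # hypothetical serializer path
    priority=20,                                        # tried before the default 100
)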
``serializer`` is a dotted + path to the serializer class. Additional ``priority`` / ``extra_kwargs`` + forwarded into :class:`SerializerConfig`. + """ + config = SerializerConfig( + canonical_type=canonical_type, serializer=serializer, **kwargs + ) + module_name, class_name = canonical_type.rsplit(".", 1) + mod = sys.modules.get(module_name) + if mod is not None and hasattr(mod, class_name): + register_serializer_config(config) + return + _ensure_interceptor_installed() + _interceptor.watch(module_name, config) + + +def _reset_for_tests(): + """Clear all module-level state. Intended for unit tests only.""" + _registered_configs.clear() + _loaded_serializers.clear() + _interceptor._pending.clear() + _interceptor._processed.clear() + if _interceptor in sys.meta_path: + sys.meta_path.remove(_interceptor) diff --git a/metaflow/datastore/artifacts/serializer.py b/metaflow/datastore/artifacts/serializer.py new file mode 100644 index 00000000000..ab015795074 --- /dev/null +++ b/metaflow/datastore/artifacts/serializer.py @@ -0,0 +1,248 @@ +import inspect +from abc import ABCMeta, abstractmethod +from collections import namedtuple +from enum import Enum +from typing import Any, List, Optional, Tuple, Type, Union + + +class SerializationFormat(str, Enum): + """ + Representation a serializer produces or consumes. + + - ``STORAGE`` yields ``(List[SerializedBlob], SerializationMetadata)`` for + the datastore save path. + - ``WIRE`` yields a ``str`` for CLI args, protobuf payloads, and + cross-process IPC. + + Subclassing ``str`` keeps ``SerializationFormat.STORAGE == "storage"`` + True, so existing call sites that compare against the string literal + keep working without a migration. + """ + + STORAGE = "storage" + WIRE = "wire" + + +# Module-level aliases kept so call sites can write ``format=STORAGE`` without +# importing the enum itself. +STORAGE = SerializationFormat.STORAGE +WIRE = SerializationFormat.WIRE + + +SerializationMetadata = namedtuple( + "SerializationMetadata", ["obj_type", "size", "encoding", "serializer_info"] +) + + +class SerializedBlob(object): + """ + Represents a single blob produced by a serializer. + + A serializer may produce multiple blobs per artifact. Each blob is either: + - New bytes to be stored (is_reference=False, value is bytes) + - A reference to already-stored data (is_reference=True, value is a string key) + + Parameters + ---------- + value : Union[str, bytes] + The blob data (bytes) or a reference key (str). + is_reference : bool, optional + If None, auto-detected from value type: str -> reference, bytes -> new data. + """ + + def __init__( + self, + value: Union[str, bytes], + is_reference: Optional[bool] = None, + ): + if not isinstance(value, (str, bytes)): + raise TypeError( + "SerializedBlob value must be str or bytes, got %s" % type(value).__name__ + ) + self.value = value + if is_reference is None: + self.is_reference = isinstance(value, str) + else: + self.is_reference = is_reference + + @property + def needs_save(self) -> bool: + """True if this blob contains new bytes that need to be stored.""" + return not self.is_reference + + +class SerializerStore(ABCMeta): + """ + Metaclass for ArtifactSerializer that auto-registers subclasses by TYPE. + + Provides deterministic ordering: serializers are sorted by (PRIORITY, registration_order). + Lower PRIORITY values are tried first. Registration order breaks ties. 
+ """ + + _all_serializers = {} + _ordered_cache = None + + def __init__(cls, name, bases, namespace): + super().__init__(name, bases, namespace) + # Skip the abstract base and any subclass that didn't implement all + # abstract methods — registering a partially-abstract class would + # blow up only at dispatch time. + if cls.TYPE is None or inspect.isabstract(cls): + return + SerializerStore._all_serializers[cls.TYPE] = cls + SerializerStore._ordered_cache = None + + @staticmethod + def get_ordered_serializers() -> List[Type["ArtifactSerializer"]]: + """ + Return serializer classes sorted by (PRIORITY, registration_order). + + Python 3.7+ dicts preserve insertion order, so enumerating + ``_all_serializers.values()`` yields registration order. A stable sort + on PRIORITY preserves that tiebreaker. + + Serializers registered via the lazy registry are materialized here + too: each registered class is imported on demand and folded into the + dispatch order. Without this step, a lazy + ``register_serializer_for_type`` call would be silently ignored + at dispatch time. + """ + # Imported locally to avoid a circular import between this module and + # ``lazy_registry`` (which depends on the ArtifactSerializer ABC). + from .lazy_registry import iter_registered_configs, load_serializer_class + + lazy_classes = [] + for cfg in iter_registered_configs(): + cls = load_serializer_class(cfg.canonical_type) + if cls is not None: + lazy_classes.append(cls) + + if SerializerStore._ordered_cache is None or lazy_classes: + # De-duplicate: lazy classes typically also self-register via the + # metaclass, but when loaded outside normal import flow they may + # not. ``dict.fromkeys`` preserves first-seen order while dropping + # duplicates. + combined = list( + dict.fromkeys( + list(SerializerStore._all_serializers.values()) + lazy_classes + ) + ) + SerializerStore._ordered_cache = sorted( + combined, key=lambda s: s.PRIORITY + ) + return SerializerStore._ordered_cache + + +class ArtifactSerializer(object, metaclass=SerializerStore): + """ + Abstract base class for artifact serializers. + + Subclasses must set TYPE to a unique string identifier and implement + all four class methods. Subclasses are auto-registered by the SerializerStore + metaclass on class definition. + + Attributes + ---------- + TYPE : str or None + Unique identifier for this serializer (e.g., "pickle", "iotype"). + Set to None in the base class to prevent registration. + PRIORITY : int + Dispatch priority. Lower values are tried first. Default 100. + PickleSerializer uses 9999 as the universal fallback. + """ + + TYPE: Optional[str] = None + PRIORITY: int = 100 + + @classmethod + @abstractmethod + def can_serialize(cls, obj: Any) -> bool: + """ + Return True if this serializer can handle the given object. + + Parameters + ---------- + obj : Any + The Python object to serialize. + + Returns + ------- + bool + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def can_deserialize(cls, metadata: SerializationMetadata) -> bool: + """ + Return True if this serializer can deserialize given the metadata. + + Parameters + ---------- + metadata : SerializationMetadata + Metadata stored alongside the artifact. + + Returns + ------- + bool + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def serialize( + cls, + obj: Any, + format: SerializationFormat = STORAGE, + ) -> Union[Tuple[List[SerializedBlob], SerializationMetadata], str]: + """ + Serialize obj. 
Must be side-effect-free: this method may be invoked + multiple times (caching, retries, parallel dispatch) and must not + perform I/O, mutate global state, or register the object elsewhere. + Side effects that need to happen at persist time belong in hooks, + not in the serializer. + + Parameters + ---------- + obj : Any + The Python object to serialize. + format : SerializationFormat + Either ``STORAGE`` (default) or ``WIRE``. + - ``STORAGE`` returns a tuple ``(List[SerializedBlob], SerializationMetadata)`` + for persisting through the datastore. + - ``WIRE`` returns a ``str`` representation for CLI args, protobuf + payloads, and cross-process IPC. Serializers that cannot provide + a wire encoding should raise ``NotImplementedError``. + + Returns + ------- + tuple or str + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def deserialize( + cls, + data: Union[List[bytes], str], + metadata: Optional[SerializationMetadata] = None, + format: SerializationFormat = STORAGE, + ) -> Any: + """ + Deserialize back to a Python object. + + Parameters + ---------- + data : Union[List[bytes], str] + ``List[bytes]`` when ``format=STORAGE``; ``str`` when ``format=WIRE``. + metadata : SerializationMetadata, optional + Metadata stored alongside the artifact. Required for STORAGE, + ignored for WIRE. + format : SerializationFormat + Either ``STORAGE`` (default) or ``WIRE``. + + Returns + ------- + Any + """ + raise NotImplementedError diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index abeeb8ea5fb..abe100aec78 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -1,6 +1,5 @@ from collections import defaultdict import json -import pickle import sys import time @@ -15,6 +14,7 @@ from ..parameters import Parameter from ..util import Path, is_stringish, to_fileobj +from .artifacts.serializer import SerializationMetadata, SerializerStore from .exceptions import DataException, UnpicklableArtifactException _included_file_type = "" @@ -117,12 +117,10 @@ def __init__( self._parent = flow_datastore self._persist = persist - # The GZIP encodings are for backward compatibility - self._encodings = {"pickle-v2", "gzip+pickle-v2"} - ver = sys.version_info[0] * 10 + sys.version_info[1] - if ver >= 36: - self._encodings.add("pickle-v4") - self._encodings.add("gzip+pickle-v4") + # Tests assign ``self._serializers = [...]`` to pin the dispatch list + # for isolation. When set, the ``_serializers`` property returns this + # override instead of consulting the global registry. + self._serializers_override = None self._is_done_set = False @@ -200,6 +198,25 @@ def __init__( else: raise DataException("Unknown datastore mode: '%s'" % self._mode) + @property + def _serializers(self): + # Dispatch through ``SerializerStore.get_ordered_serializers()`` on + # each access. The lookup is cheap (cached inside the store) and + # picks up serializers registered via the lazy import hook after + # this instance was constructed — otherwise long-lived datastores + # (notebooks, client sessions) would silently miss any extension + # registered after init. + if self._serializers_override is not None: + return self._serializers_override + return SerializerStore.get_ordered_serializers() + + @_serializers.setter + def _serializers(self, value): + # Tests override the dispatch list directly for isolation; preserve + # that escape hatch without losing the dynamic registry behavior + # in production. 
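Sketch of the two access modes the property supports (``task_ds`` is just a placeholder name for a TaskDataStore instance):

from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer

# Production path: every access re-consults the global registry, so serializers
# registered lazily after this datastore was constructed are still picked up.
assert task_ds._serializers == SerializerStore.get_ordered_serializers()

# Test path: pin the dispatch list for isolation via the setter.
task_ds._serializers = [PickleSerializer]
assert task_ds._serializers == [PickleSerializer]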
+ self._serializers_override = value + @property def pathspec(self): return "/".join([self.run_id, self.step_name, self.task_id]) @@ -347,38 +364,57 @@ def save_artifacts(self, artifacts_iter, len_hint=0): """ artifact_names = [] - def pickle_iter(): + def serialize_iter(): for name, obj in artifacts_iter: - encode_type = "gzip+pickle-v4" - if encode_type in self._encodings: - try: - blob = pickle.dumps(obj, protocol=4) - except TypeError as e: - raise UnpicklableArtifactException(name) from e - else: - try: - blob = pickle.dumps(obj, protocol=2) - encode_type = "gzip+pickle-v2" - except (SystemError, OverflowError) as e: - raise DataException( - "Artifact *%s* is very large (over 2GB). " - "You need to use Python 3.6 or newer if you want to " - "serialize large objects." % name - ) from e - except TypeError as e: - raise UnpicklableArtifactException(name) from e + # Find the first serializer that can handle this object + serializer = None + for s in self._serializers: + if s.can_serialize(obj): + serializer = s + break + if serializer is None: + raise DataException( + "No serializer claimed artifact '%s' (type: %s). " + "The PickleSerializer fallback normally handles all " + "objects — check that it is installed and enabled." + % (name, type(obj).__name__) + ) + + try: + blobs, metadata = serializer.serialize(obj) + except TypeError as e: + raise UnpicklableArtifactException(name) from e self._info[name] = { - "size": len(blob), - "type": str(type(obj)), - "encoding": encode_type, + "size": metadata.size, + "type": metadata.obj_type, + "encoding": metadata.encoding, } + if metadata.serializer_info: + self._info[name]["serializer_info"] = metadata.serializer_info + if not blobs: + raise DataException( + "Serializer %s returned no blobs for artifact '%s'" + % (serializer.__name__, name) + ) + if len(blobs) > 1: + # The datastore currently stores a single blob per + # artifact. Silently dropping blobs[1:] would corrupt + # multi-blob IOType extensions (e.g. chunked tensors) on + # load. Fail loudly until multi-blob support lands. + raise DataException( + "Serializer %s returned %d blobs for artifact '%s'; " + "only single-blob serializers are supported at this " + "time. If you have a need for multi blob " + "serializers, please reach out to the Metaflow team." + % (serializer.__name__, len(blobs), name) + ) artifact_names.append(name) - yield blob + yield blobs[0].value # Use the content-addressed store to store all artifacts - save_result = self._ca_store.save_blobs(pickle_iter(), len_hint=len_hint) + save_result = self._ca_store.save_blobs(serialize_iter(), len_hint=len_hint) for name, result in zip(artifact_names, save_result): self._objects[name] = result.key @@ -414,32 +450,54 @@ def load_artifacts(self, names): "load artifacts" % self._path ) to_load = defaultdict(list) + deserializers = {} # name -> serializer class for name in names: - info = self._info.get(name) - # We use gzip+pickle-v2 as this is the oldest/most compatible. 
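For reference, with only the built-in serializers registered, dispatch tries IOTypeSerializer (PRIORITY=50) before the PickleSerializer fallback (PRIORITY=9999), and the ``self._info[name]`` records produced above look roughly like this (sizes are placeholders):

# Plain Python object, claimed by the pickle fallback:
#   {"size": 123, "type": "<class 'dict'>", "encoding": "pickle-v4"}
# Json(...) artifact, claimed by the IOType bridge:
#   {"size": 45, "type": "json", "encoding": "iotype:json",
#    "serializer_info": {"iotype_module": "metaflow.io_types.json_type",
#                        "iotype_class": "Json"}}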
- # This datastore will always include the proper encoding version so - # this is just to be able to read very old artifacts - if info: - encode_type = info.get("encoding", "gzip+pickle-v2") - else: - encode_type = "gzip+pickle-v2" - if encode_type not in self._encodings: + info = self._info.get(name, {}) + metadata = SerializationMetadata( + obj_type=info.get("type", "object"), + size=info.get("size", 0), + # Default to gzip+pickle-v2 for very old artifacts without encoding + encoding=info.get("encoding", "gzip+pickle-v2"), + serializer_info=info.get("serializer_info", {}), + ) + + # Find deserializer via metadata + deserializer = None + for s in self._serializers: + if s.can_deserialize(metadata): + deserializer = s + break + if deserializer is None: + source_hint = "" + serializer_source = metadata.serializer_info.get("source") + if serializer_source: + source_hint = ( + " The artifact was written by '%s' — the " + "corresponding extension may not be installed." + % serializer_source + ) raise DataException( - "Python 3.6 or later is required to load artifact '%s'" % name + "No deserializer claimed artifact '%s' (encoding: %s, " + "serializer_info: %r).%s" + % ( + name, + metadata.encoding, + metadata.serializer_info, + source_hint, + ) ) - else: - to_load[self._objects[name]].append(name) - # At this point, we load what we don't have from the CAS - # We assume that if we have one "old" style artifact, all of them are - # like that which is an easy assumption to make since artifacts are all - # stored by the same implementation of the datastore for a given task. + deserializers[name] = (deserializer, metadata) + to_load[self._objects[name]].append(name) + + # Load blobs from CAS and deserialize for key, blob in self._ca_store.load_blobs(to_load.keys()): - names = to_load[key] - for name in names: - # We unpickle everytime to have fully distinct objects (the user + loaded_names = to_load[key] + for name in loaded_names: + deserializer, metadata = deserializers[name] + # Deserialize each time to have fully distinct objects (the user # would not expect two artifacts with different names to actually # be aliases of one another) - yield name, pickle.loads(blob) + yield name, deserializer.deserialize([blob], metadata) @require_mode("r") def get_artifact_sizes(self, names): diff --git a/metaflow/extension_support/plugins.py b/metaflow/extension_support/plugins.py index 62ed434fd63..d93b37a6c79 100644 --- a/metaflow/extension_support/plugins.py +++ b/metaflow/extension_support/plugins.py @@ -218,6 +218,7 @@ def resolve_plugins(category, path_only=False): ), "runner_cli": lambda x: x.name, "tl_plugin": None, + "artifact_serializer": lambda x: x.TYPE, } diff --git a/metaflow/io_types/__init__.py b/metaflow/io_types/__init__.py new file mode 100644 index 00000000000..50a27e3b894 --- /dev/null +++ b/metaflow/io_types/__init__.py @@ -0,0 +1,3 @@ +from .base import IOType +from .json_type import Json +from .struct_type import Struct diff --git a/metaflow/io_types/base.py b/metaflow/io_types/base.py new file mode 100644 index 00000000000..27f98a63579 --- /dev/null +++ b/metaflow/io_types/base.py @@ -0,0 +1,182 @@ +"""Typed-artifact contract for Metaflow. + +This module defines the minimal :class:`IOType` abstract base class. OSS +Metaflow ships the contract; concrete types (scalars, tensors, enums, +dataclass-backed structs, etc.) live in extensions — they embody +deployment-specific opinions about encoding, byte order, and dataclass +inference that do not belong in core. 
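A sketch of what such an extension-shipped concrete type could look like (the ``Int64`` name and its encoding choices are hypothetical, not part of this change):

from metaflow.datastore.artifacts.serializer import SerializedBlob
from metaflow.io_types import IOType

class Int64(IOType):
    type_name = "int64"

    def _wire_serialize(self):
        return str(self._value)

    @classmethod
    def _wire_deserialize(cls, s):
        return cls(int(s))

    def _storage_serialize(self):
        blob = str(self._value).encode("utf-8")
        return [SerializedBlob(blob)], {}

    @classmethod
    def _storage_deserialize(cls, blobs, **kwargs):
        return cls(int(blobs[0].decode("utf-8")))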
+ +:class:`IOType` mirrors the ``format`` argument introduced on +:class:`metaflow.datastore.artifacts.serializer.ArtifactSerializer` so a +single subclass can own both representations: + +- ``STORAGE`` — blob-based, persisted through the datastore. +- ``WIRE`` — string-based, for CLI args, protobuf payloads, and + cross-process IPC. + +Subclasses implement four hooks (``_wire_serialize``, ``_wire_deserialize``, +``_storage_serialize``, ``_storage_deserialize``); callers use the public +``serialize(format=...)`` / ``deserialize(data, format=...)`` methods. +""" + +from abc import ABCMeta, abstractmethod + +from metaflow.datastore.artifacts.serializer import STORAGE, WIRE + + +_UNSET = object() + +# Registry of concrete IOType subclasses keyed by their ``type_name``. Populated +# by ``IOType.__init_subclass__``. The datastore encodes each artifact's +# ``type_name`` in its ``SerializationMetadata.encoding`` (``iotype:``), +# so ``IOTypeSerializer.deserialize`` can recover the class without the +# metadata service having to persist the Python module+class path. +_TYPE_REGISTRY = {} + + +def get_iotype_by_name(type_name): + """Return the IOType subclass registered under ``type_name``, or None.""" + return _TYPE_REGISTRY.get(type_name) + + +def _make_hashable(value): + """ + Recursively convert a JSON-like value to a hashable form. + + dicts -> frozenset of (key, hashable(value)) pairs. + lists -> tuple of hashable elements. + Everything else assumed hashable (int, float, bool, str, None, ...). + + Using ``frozenset``/``tuple`` preserves Python's numeric equivalence + (``1 == 1.0 == True`` -> equal hashes) that ``json.dumps`` would + otherwise break by rendering each as a distinct string. This keeps the + ``__eq__`` / ``__hash__`` contract intact when IOType subclasses delegate + value equality to the wrapped Python object. + """ + if isinstance(value, dict): + return frozenset((k, _make_hashable(v)) for k, v in value.items()) + if isinstance(value, list): + return tuple(_make_hashable(v) for v in value) + return value + + +class IOType(object, metaclass=ABCMeta): + """ + Base class for typed Metaflow artifacts. + + An :class:`IOType` instance plays two roles: + + - **Descriptor** (no value): ``Int64`` in a spec describes an int64 + field. + - **Wrapper** (with value): ``Int64(42)`` wraps a value for typed + serialization. + + Subclasses implement four internal operations, dispatched by the + ``format`` argument of the public :meth:`serialize` / :meth:`deserialize` + methods. + """ + + type_name = None # e.g. "text", "json", "int64" — set by subclasses. + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Register concrete subclasses so IOTypeSerializer can recover them + # from just the ``type_name`` suffix of a stored artifact's encoding. + # Abstract intermediates (``type_name`` is still None) don't register. + # Last-write-wins: production code is expected to declare each + # ``type_name`` on exactly one class; test-local subclasses that reuse + # a name harmlessly overwrite each other. + if cls.type_name: + _TYPE_REGISTRY[cls.type_name] = cls + + def __init__(self, value=_UNSET): + self._value = value + + @property + def value(self): + """The wrapped Python value, or ``_UNSET`` if this is a pure descriptor.""" + return self._value + + # -- Public API -------------------------------------------------------- + + def serialize(self, format=STORAGE): + """ + Serialize the wrapped value. Must be side-effect-free. 
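A short usage sketch of the public dispatch, using the concrete ``Json`` type added later in this change:

from metaflow.io_types import Json

wire = Json({"a": 1}).serialize(format="wire")           # '{"a":1}'
assert Json.deserialize(wire, format="wire").value == {"a": 1}

blobs, meta = Json({"a": 1}).serialize(format="storage")
raw = [b.value for b in blobs]                           # list of bytes
assert Json.deserialize(raw, format="storage").value == {"a": 1}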
+ + Parameters + ---------- + format : str + ``STORAGE`` (default) returns ``(List[SerializedBlob], dict)``. + ``WIRE`` returns a ``str``. + """ + if format == WIRE: + return self._wire_serialize() + if format == STORAGE: + return self._storage_serialize() + raise ValueError("format must be %r or %r, got %r" % (STORAGE, WIRE, format)) + + @classmethod + def deserialize(cls, data, format=STORAGE, **kwargs): + """ + Reconstruct an :class:`IOType` from serialized data. + + Parameters + ---------- + data : Union[str, List[bytes]] + ``str`` when ``format=WIRE``; ``List[bytes]`` when ``format=STORAGE``. + format : str + ``STORAGE`` (default) or ``WIRE``. + **kwargs + Forwarded to the underlying ``_storage_deserialize`` hook + (e.g. metadata the datastore produced at save time). + """ + if format == WIRE: + return cls._wire_deserialize(data) + if format == STORAGE: + return cls._storage_deserialize(data, **kwargs) + raise ValueError("format must be %r or %r, got %r" % (STORAGE, WIRE, format)) + + # -- Subclass hooks ---------------------------------------------------- + + @abstractmethod + def _wire_serialize(self): + """Value -> string (for CLI args, protobuf, external APIs).""" + raise NotImplementedError + + @classmethod + @abstractmethod + def _wire_deserialize(cls, s): + """String -> :class:`IOType` instance.""" + raise NotImplementedError + + @abstractmethod + def _storage_serialize(self): + """Value -> ``(List[SerializedBlob], metadata_dict)``. Side-effect-free.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def _storage_deserialize(cls, blobs, **kwargs): + """``(List[bytes], metadata)`` -> :class:`IOType` instance.""" + raise NotImplementedError + + # -- Spec generation --------------------------------------------------- + + def to_spec(self): + """JSON type spec. Works with or without a wrapped value.""" + return {"type": self.type_name} + + # -- Dunder ------------------------------------------------------------ + + def __repr__(self): + if self._value is _UNSET: + return "%s()" % self.__class__.__name__ + return "%s(%r)" % (self.__class__.__name__, self._value) + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self._value == other._value + + def __hash__(self): + return hash((type(self), self._value)) diff --git a/metaflow/io_types/json_type.py b/metaflow/io_types/json_type.py new file mode 100644 index 00000000000..e62d4fe9c01 --- /dev/null +++ b/metaflow/io_types/json_type.py @@ -0,0 +1,37 @@ +import json + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType, _UNSET, _make_hashable + + +class Json(IOType): + """JSON type (dict or list). Wire: JSON string. Storage: UTF-8 JSON bytes.""" + + type_name = "json" + + def _wire_serialize(self): + return json.dumps(self._value, separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps(self._value, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + return [SerializedBlob(blob)], {} + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(json.loads(blobs[0].decode("utf-8"))) + + def __hash__(self): + # ``_value`` is typically a dict or list (unhashable), so the base + # class ``hash((type, _value))`` raises TypeError. 
Convert to a + # frozenset/tuple form that preserves Python's numeric equivalence + # (``1 == 1.0 == True`` hash identically), so ``__eq__`` and + # ``__hash__`` stay consistent even when users mix int/float/bool. + if self._value is _UNSET: + return hash((type(self), _UNSET)) + return hash((type(self), _make_hashable(self._value))) diff --git a/metaflow/io_types/struct_type.py b/metaflow/io_types/struct_type.py new file mode 100644 index 00000000000..1742d617fd9 --- /dev/null +++ b/metaflow/io_types/struct_type.py @@ -0,0 +1,157 @@ +import dataclasses +import importlib +import json +import typing + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType, _UNSET, _make_hashable + + +def _reconstruct(dc_type, data): + """ + Rebuild a dataclass instance from JSON-decoded ``data``, recursing into + fields whose annotation is itself a dataclass. Containerized annotations + (``List[Foo]``, ``Dict[str, Foo]``, ``Optional[Foo]``, ...) are left as + raw JSON-decoded values; callers that need rich container reconstruction + should wrap the field explicitly (e.g. in a ``List`` IOType shipped by + an extension). + + Fields declared with ``field(init=False, ...)`` are not accepted by the + generated ``__init__``, but ``dataclasses.asdict`` emits them. Pass only + init-eligible fields as kwargs, then assign the remainder via + ``object.__setattr__`` (which works on frozen dataclasses too) so the + serialized values are not lost to defaults or ``__post_init__``. + """ + try: + hints = typing.get_type_hints(dc_type) + except Exception: + hints = {} + kwargs = {} + post_init_fields = [] + for f in dataclasses.fields(dc_type): + if f.name not in data: + continue + raw = data[f.name] + annotation = hints.get(f.name, f.type) + if ( + isinstance(annotation, type) + and dataclasses.is_dataclass(annotation) + and isinstance(raw, dict) + ): + value = _reconstruct(annotation, raw) + else: + value = raw + if f.init: + kwargs[f.name] = value + else: + post_init_fields.append((f.name, value)) + instance = dc_type(**kwargs) + # ``object.__setattr__`` bypasses ``@dataclass(frozen=True)``'s lock, + # matching the pattern frozen dataclasses themselves use to populate + # ``init=False`` fields inside ``__post_init__``. + for name, value in post_init_fields: + object.__setattr__(instance, name, value) + return instance + + +class Struct(IOType): + """ + Structured type mapping to a Python ``@dataclass``. + + Wire: JSON string. Storage: JSON UTF-8 bytes. + + Wraps a ``@dataclass`` instance. On save, ``dataclasses.asdict`` flattens + the whole tree to plain dicts; on load, fields typed as dataclasses are + recursively rebuilt into their original types. Generic container + annotations (``List[Foo]``, ``Dict[str, Foo]``, ``Optional[Foo]``) are + not walked — those fields come back as raw JSON-decoded values. Wrap + those explicitly (e.g. via ``List[Struct]`` support shipped by an + extension) when you need typed containers. + + .. warning:: + ``Struct._storage_deserialize`` imports the dataclass module named in + the artifact metadata. Metadata written by this class is safe, but + metadata supplied from an untrusted source can trigger arbitrary + imports (and any import-time side effects those modules carry). + Only load artifacts from sources you trust. + + Parameters + ---------- + value : dataclass instance or dict, optional + The wrapped value. Dataclass instances are serialized via + ``dataclasses.asdict``; plain dicts are serialized directly. 
+ dataclass_type : type, optional + The ``@dataclass`` class, for type-descriptor use (no value). + """ + + type_name = "struct" + + def __init__(self, value=_UNSET, dataclass_type=None): + if value is not _UNSET and dataclasses.is_dataclass(value): + self._dataclass_type = type(value) + elif dataclass_type is not None: + self._dataclass_type = dataclass_type + else: + self._dataclass_type = None + super().__init__(value) + + def _to_dict(self): + """Convert value to dict, handling both dataclass and plain dict.""" + if dataclasses.is_dataclass(self._value): + return dataclasses.asdict(self._value) + if isinstance(self._value, dict): + return self._value + raise TypeError( + "Struct value must be a dataclass instance or dict, got %s" + % type(self._value).__name__ + ) + + def _wire_serialize(self): + return json.dumps(self._to_dict(), separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps( + self._to_dict(), separators=(",", ":"), sort_keys=True + ).encode("utf-8") + meta = {} + if self._dataclass_type is not None: + meta["dataclass_module"] = self._dataclass_type.__module__ + meta["dataclass_class"] = self._dataclass_type.__name__ + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + data = json.loads(blobs[0].decode("utf-8")) + metadata = kwargs.get("metadata", {}) + dc_module = metadata.get("dataclass_module") + dc_class = metadata.get("dataclass_class") + if dc_module and dc_class: + mod = importlib.import_module(dc_module) + dc_type = getattr(mod, dc_class) + # Guard against crafted metadata — require a class that's + # actually a dataclass. ``dataclasses.is_dataclass`` alone + # returns True for dataclass *instances*; the ``isinstance(..., type)`` + # check excludes that (and anything else callable). + if not (isinstance(dc_type, type) and dataclasses.is_dataclass(dc_type)): + raise ValueError( + "Struct metadata references '%s.%s' which is not a dataclass" + % (dc_module, dc_class) + ) + return cls(_reconstruct(dc_type, data), dataclass_type=dc_type) + # Fallback: return as plain dict wrapped in Struct + return cls(data) + + def __hash__(self): + # ``_value`` is typically a dataclass instance or dict (often + # unhashable), so the base class ``hash((type, _value))`` raises + # TypeError. Flatten to a dict via ``_to_dict`` then convert to a + # frozenset/tuple form that preserves Python's numeric equivalence + # (``1 == 1.0 == True`` hash identically), so ``__eq__`` and + # ``__hash__`` stay consistent for mixed-type fields. + if self._value is _UNSET: + return hash((type(self), _UNSET)) + return hash((type(self), _make_hashable(self._to_dict()))) diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 3fc1d3f8db6..d5846fb80c1 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -187,6 +187,13 @@ ("conda_environment_yml_parser", ".pypi.parsers.conda_environment_yml_parser"), ] +# Add artifact serializers here. Ordering is by PRIORITY (lower = tried first). +# PickleSerializer is the universal fallback (PRIORITY=9999). 
+ARTIFACT_SERIALIZERS_DESC = [ + ("iotype", ".datastores.serializers.iotype_serializer.IOTypeSerializer"), + ("pickle", ".datastores.serializers.pickle_serializer.PickleSerializer"), +] + process_plugins(globals()) @@ -228,6 +235,7 @@ def get_runner_cli_path(): DEPLOYER_IMPL_PROVIDERS = resolve_plugins("deployer_impl_provider") TL_PLUGINS = resolve_plugins("tl_plugin") +ARTIFACT_SERIALIZERS = resolve_plugins("artifact_serializer") from .cards.card_modules import MF_EXTERNAL_CARDS diff --git a/metaflow/plugins/datastores/serializers/__init__.py b/metaflow/plugins/datastores/serializers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/plugins/datastores/serializers/iotype_serializer.py b/metaflow/plugins/datastores/serializers/iotype_serializer.py new file mode 100644 index 00000000000..521289ac42c --- /dev/null +++ b/metaflow/plugins/datastores/serializers/iotype_serializer.py @@ -0,0 +1,105 @@ +import importlib + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + STORAGE, + WIRE, +) +from metaflow.io_types.base import IOType, get_iotype_by_name + + +class IOTypeSerializer(ArtifactSerializer): + """ + Bridge between :class:`IOType` and the pluggable serializer framework. + + Claims any :class:`IOType` instance on save. On load, reconstructs the + original subclass by looking up its ``type_name`` in the global IOType + registry (populated via :meth:`IOType.__init_subclass__`). The + ``iotype_module`` / ``iotype_class`` hints stored in ``serializer_info`` + are kept as a secondary lookup path — useful when a subclass isn't yet + registered in the reader process, or when inspecting artifacts produced + by extensions whose code isn't installed locally. + + ``PRIORITY`` is 50 — ahead of the default (100) so this bridge catches + :class:`IOType` artifacts before any generic catch-all, and always ahead + of the :class:`PickleSerializer` fallback (9999). + + Only the ``STORAGE`` format is implemented on this bridge; ``WIRE`` is + handled by callers that talk to :class:`IOType` directly (CLI parsing, + protobuf payload construction), not through the datastore. + """ + + TYPE = "iotype" + PRIORITY = 50 + + _ENCODING_PREFIX = "iotype:" + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, IOType) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding.startswith(cls._ENCODING_PREFIX) + + @classmethod + def serialize(cls, obj, format=STORAGE): + if format == WIRE: + raise NotImplementedError( + "IOTypeSerializer only handles the STORAGE format; wire " + "encoding is produced by calling IOType.serialize(format=WIRE) " + "directly." + ) + blobs, meta_dict = obj.serialize(format=STORAGE) + size = sum(len(b.value) for b in blobs if isinstance(b.value, bytes)) + # Subclass metadata goes first so the routing keys below always win. + # An IOType subclass whose ``_storage_serialize`` happens to return + # ``iotype_module`` or ``iotype_class`` in its own meta dict must not + # be able to overwrite the routing info the deserialize path needs. 
+ serializer_info = { + **meta_dict, + "iotype_module": obj.__class__.__module__, + "iotype_class": obj.__class__.__name__, + } + return ( + blobs, + SerializationMetadata( + obj_type=obj.type_name, + size=size, + encoding=cls._ENCODING_PREFIX + obj.type_name, + serializer_info=serializer_info, + ), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=STORAGE): + if format == WIRE: + raise NotImplementedError( + "IOTypeSerializer only handles the STORAGE format." + ) + info = metadata.serializer_info or {} + # Primary path: registry lookup by the type_name encoded in the + # artifact's encoding. Works whether or not the metadata service + # propagates ``serializer_info`` to the reader. + type_name = metadata.encoding[len(cls._ENCODING_PREFIX):] + iotype_cls = get_iotype_by_name(type_name) + # Fallback: explicit module/class hints in serializer_info. Useful for + # inspecting artifacts produced by extensions not loaded locally. + if iotype_cls is None and "iotype_module" in info and "iotype_class" in info: + mod = importlib.import_module(info["iotype_module"]) + iotype_cls = getattr(mod, info["iotype_class"]) + if iotype_cls is None: + raise ValueError( + "IOTypeSerializer could not resolve a class for encoding %r; " + "no IOType subclass is registered under type_name %r and " + "serializer_info lacks iotype_module/iotype_class hints." + % (metadata.encoding, type_name) + ) + # Only allow actual IOType subclasses — metadata is untrusted input. + if not (isinstance(iotype_cls, type) and issubclass(iotype_cls, IOType)): + raise ValueError( + "IOTypeSerializer resolved %r for encoding %r, which is not " + "an IOType subclass" % (iotype_cls, metadata.encoding) + ) + return iotype_cls.deserialize(data, format=STORAGE, metadata=info) diff --git a/metaflow/plugins/datastores/serializers/pickle_serializer.py b/metaflow/plugins/datastores/serializers/pickle_serializer.py new file mode 100644 index 00000000000..dc957d10b6d --- /dev/null +++ b/metaflow/plugins/datastores/serializers/pickle_serializer.py @@ -0,0 +1,63 @@ +import pickle + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + STORAGE, + WIRE, +) + + +class PickleSerializer(ArtifactSerializer): + """ + Default serializer using Python's pickle module. + + This is the universal fallback — can_serialize always returns True. + PRIORITY is set to 9999 so custom serializers are always tried first. + Pickle produces binary bytes, so only the STORAGE format is supported; + callers that need a wire representation should pick a serializer that + implements it (e.g. an IOType- or JSON-based one in an extension). + """ + + TYPE = "pickle" + PRIORITY = 9999 + + _ENCODINGS = frozenset( + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"] + ) + + @classmethod + def can_serialize(cls, obj): + return True + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding in cls._ENCODINGS + + @classmethod + def serialize(cls, obj, format=STORAGE): + if format == WIRE: + raise NotImplementedError( + "PickleSerializer does not support the WIRE format; pickle " + "produces opaque binary bytes that are not safe to pass as " + "CLI args or inline IPC payloads." 
+ ) + blob = pickle.dumps(obj, protocol=4) + return ( + [SerializedBlob(blob, is_reference=False)], + SerializationMetadata( + obj_type=str(type(obj)), + size=len(blob), + encoding="pickle-v4", + serializer_info={}, + ), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=STORAGE): + if format == WIRE: + raise NotImplementedError( + "PickleSerializer does not support the WIRE format." + ) + return pickle.loads(data[0]) diff --git a/test/unit/io_types/__init__.py b/test/unit/io_types/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit/io_types/test_base.py b/test/unit/io_types/test_base.py new file mode 100644 index 00000000000..56a903c4a44 --- /dev/null +++ b/test/unit/io_types/test_base.py @@ -0,0 +1,67 @@ +"""Contract tests for the IOType ABC.""" + +import pytest + +from metaflow.datastore.artifacts.serializer import STORAGE, WIRE, SerializedBlob +from metaflow.io_types import IOType + + +class _TextIOType(IOType): + """Minimal concrete IOType used only to exercise the base-class dispatch.""" + + type_name = "test_text" + + def _wire_serialize(self): + return self._value + + @classmethod + def _wire_deserialize(cls, s): + return cls(s) + + def _storage_serialize(self): + blob = self._value.encode("utf-8") + return [SerializedBlob(blob)], {"length": len(blob)} + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(blobs[0].decode("utf-8")) + + +def test_cannot_instantiate_abstract_base(): + with pytest.raises(TypeError): + IOType() # missing hook implementations + + +def test_wire_roundtrip(): + wire = _TextIOType("hi").serialize(format=WIRE) + assert wire == "hi" + assert _TextIOType.deserialize(wire, format=WIRE) == _TextIOType("hi") + + +def test_storage_roundtrip(): + blobs, meta = _TextIOType("hi").serialize(format=STORAGE) + assert meta["length"] == 2 + raw = [b.value for b in blobs] + assert _TextIOType.deserialize(raw, format=STORAGE) == _TextIOType("hi") + + +def test_default_format_is_storage(): + out = _TextIOType("hi").serialize() + assert isinstance(out, tuple) # (blobs, metadata) + + +def test_invalid_format_raises(): + with pytest.raises(ValueError): + _TextIOType("hi").serialize(format="bogus") + with pytest.raises(ValueError): + _TextIOType.deserialize("hi", format="bogus") + + +def test_descriptor_has_no_value(): + assert _TextIOType().to_spec() == {"type": "test_text"} + + +def test_eq_and_hash(): + assert _TextIOType("x") == _TextIOType("x") + assert _TextIOType("x") != _TextIOType("y") + assert hash(_TextIOType("x")) == hash(_TextIOType("x")) diff --git a/test/unit/io_types/test_iotype_serializer.py b/test/unit/io_types/test_iotype_serializer.py new file mode 100644 index 00000000000..ed5ca22bd56 --- /dev/null +++ b/test/unit/io_types/test_iotype_serializer.py @@ -0,0 +1,153 @@ +"""Tests for the IOTypeSerializer bridge between IOType and the datastore.""" + +import dataclasses + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + SerializationMetadata, + WIRE, +) +from metaflow.io_types import Json, Struct +from metaflow.plugins.datastores.serializers.iotype_serializer import IOTypeSerializer + + +@dataclasses.dataclass +class _Config: + threshold: float + name: str + + +# --------------------------------------------------------------------------- +# can_serialize / can_deserialize +# --------------------------------------------------------------------------- + + +def test_can_serialize_iotype_instances(): + assert IOTypeSerializer.can_serialize(Json({"a": 1})) is 
True + assert IOTypeSerializer.can_serialize(Struct(_Config(0.5, "x"))) is True + + +def test_cannot_serialize_plain_python(): + assert IOTypeSerializer.can_serialize({"a": 1}) is False + assert IOTypeSerializer.can_serialize("hello") is False + assert IOTypeSerializer.can_serialize(42) is False + + +def test_can_deserialize_iotype_prefix(): + meta = SerializationMetadata( + obj_type="json", size=0, encoding="iotype:json", serializer_info={} + ) + assert IOTypeSerializer.can_deserialize(meta) is True + + +def test_cannot_deserialize_non_iotype_encoding(): + meta = SerializationMetadata( + obj_type="dict", size=0, encoding="pickle-v4", serializer_info={} + ) + assert IOTypeSerializer.can_deserialize(meta) is False + + +# --------------------------------------------------------------------------- +# serialize — metadata shape +# --------------------------------------------------------------------------- + + +def test_serialize_produces_iotype_encoding(): + blobs, meta = IOTypeSerializer.serialize(Json({"x": 1})) + assert meta.encoding == "iotype:json" + assert meta.obj_type == "json" + assert meta.serializer_info["iotype_module"] == "metaflow.io_types.json_type" + assert meta.serializer_info["iotype_class"] == "Json" + assert len(blobs) == 1 + + +def test_serialize_wire_format_not_supported(): + with pytest.raises(NotImplementedError): + IOTypeSerializer.serialize(Json({"x": 1}), format=WIRE) + with pytest.raises(NotImplementedError): + IOTypeSerializer.deserialize(b"{}", format=WIRE) + + +# --------------------------------------------------------------------------- +# Round-trip: serialize -> deserialize +# --------------------------------------------------------------------------- + + +def test_json_roundtrip(): + original = Json({"threshold": 0.5, "name": "x", "n": 3}) + blobs, meta = IOTypeSerializer.serialize(original) + raw = [b.value for b in blobs] + result = IOTypeSerializer.deserialize(raw, metadata=meta) + assert isinstance(result, Json) + assert result.value == original.value + + +def test_struct_roundtrip_preserves_dataclass(): + original = Struct(_Config(threshold=0.75, name="model")) + blobs, meta = IOTypeSerializer.serialize(original) + raw = [b.value for b in blobs] + result = IOTypeSerializer.deserialize(raw, metadata=meta) + assert isinstance(result, Struct) + assert isinstance(result.value, _Config) + assert result.value.threshold == 0.75 + assert result.value.name == "model" + + +# --------------------------------------------------------------------------- +# Security: deserialize refuses non-IOType classes +# --------------------------------------------------------------------------- + + +def test_deserialize_rejects_non_iotype_class(): + # Craft metadata pointing at a non-IOType class (e.g. json.JSONDecoder). 
+ meta = SerializationMetadata( + obj_type="bogus", + size=0, + encoding="iotype:bogus", + serializer_info={ + "iotype_module": "json", + "iotype_class": "JSONDecoder", + }, + ) + with pytest.raises(ValueError, match="not an IOType subclass"): + IOTypeSerializer.deserialize([b"{}"], metadata=meta) + + +# --------------------------------------------------------------------------- +# serialize — routing-key precedence +# --------------------------------------------------------------------------- + + +def test_subclass_metadata_cannot_overwrite_routing_keys(): + """An IOType subclass whose _storage_serialize returns ``iotype_module`` or + ``iotype_class`` in its own meta dict must not be able to overwrite the + routing keys the bridge writes — otherwise deserialize dispatch is + corrupted. + """ + from metaflow.datastore.artifacts.serializer import SerializedBlob + from metaflow.io_types.base import IOType + + class _Poisoner(IOType): + type_name = "poisoner" + + def _wire_serialize(self): + return "" + + @classmethod + def _wire_deserialize(cls, s): + return cls() + + def _storage_serialize(self): + return [SerializedBlob(b"{}")], { + "iotype_module": "attacker.module", + "iotype_class": "AttackerClass", + } + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls() + + _, meta = IOTypeSerializer.serialize(_Poisoner()) + assert meta.serializer_info["iotype_module"] == _Poisoner.__module__ + assert meta.serializer_info["iotype_class"] == "_Poisoner" diff --git a/test/unit/io_types/test_json_type.py b/test/unit/io_types/test_json_type.py new file mode 100644 index 00000000000..72b7aa500ba --- /dev/null +++ b/test/unit/io_types/test_json_type.py @@ -0,0 +1,96 @@ +from metaflow.io_types import Json + + +def test_json_wire_round_trip_dict(): + j = Json({"key": "value", "nested": [1, 2, 3]}) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == {"key": "value", "nested": [1, 2, 3]} + + +def test_json_wire_round_trip_list(): + j = Json([1, "two", None, True]) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == [1, "two", None, True] + + +def test_json_storage_round_trip(): + j = Json({"a": {"b": [1, 2]}, "c": None}) + blobs, meta = j.serialize(format="storage") + assert len(blobs) == 1 + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == j.value + + +def test_json_to_spec(): + assert Json().to_spec() == {"type": "json"} + + +def test_json_empty_dict(): + j = Json({}) + blobs, _ = j.serialize(format="storage") + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == {} + + +def test_json_empty_list(): + j = Json([]) + blobs, _ = j.serialize(format="storage") + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == [] + + +def test_json_deeply_nested(): + data = {"a": {"b": {"c": {"d": [1, 2, {"e": True}]}}}} + j = Json(data) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == data + + +def test_json_hashable_with_dict_value(): + """Json wrapping a dict must be hashable. The base class default + ``hash((type(self), self._value))`` raises TypeError for unhashable + values, so Json overrides it with a wire-format-based hash. + """ + j = Json({"a": 1, "b": [2, 3]}) + # No TypeError. + h = hash(j) + assert isinstance(h, int) + + # Equal values must hash equal — even when insertion order differs. 
+ j2 = Json({"b": [2, 3], "a": 1}) + assert j == j2 + assert hash(j) == hash(j2) + + # Works inside a set. + assert {Json({"x": 1}), Json({"x": 1})} == {Json({"x": 1})} + + +def test_json_hashable_with_list_value(): + j = Json([1, 2, 3]) + h = hash(j) + assert isinstance(h, int) + assert hash(Json([1, 2, 3])) == hash(j) + + +def test_json_hash_preserves_numeric_equivalence(): + """``1 == 1.0 == True`` in Python — equal dicts must hash equal even + when they mix numeric types. A naive JSON-string hash renders ``1`` + and ``1.0`` as distinct strings and violates the hash/eq contract. + """ + a = Json({"x": 1}) + b = Json({"x": 1.0}) + assert a == b + assert hash(a) == hash(b) + + c = Json({"x": True}) + d = Json({"x": 1}) + assert c == d + assert hash(c) == hash(d) + + e = Json([1, 2, 3]) + f = Json([1.0, 2.0, 3.0]) + assert e == f + assert hash(e) == hash(f) diff --git a/test/unit/io_types/test_struct_type.py b/test/unit/io_types/test_struct_type.py new file mode 100644 index 00000000000..b3fdfcf21be --- /dev/null +++ b/test/unit/io_types/test_struct_type.py @@ -0,0 +1,272 @@ +import dataclasses +from dataclasses import dataclass + +import pytest + +from metaflow.io_types import Struct + + +@dataclass +class SimpleData: + name: str + count: int + score: float + active: bool + + +@dataclass +class Inner: + x: int + y: str + + +@dataclass +class Outer: + label: str + inner: Inner + + +@dataclass +class NestedData: + label: str + sub: dict # container — stays a dict on reconstruction + + +@dataclass +class WithInitFalse: + a: int + computed: int = dataclasses.field(init=False, default=0) + + def __post_init__(self): + self.computed = self.a * 10 + + +@dataclass(frozen=True) +class FrozenWithInitFalse: + a: int + computed: int = dataclasses.field(init=False, default=0) + + def __post_init__(self): + # Frozen dataclasses must use object.__setattr__ to touch fields. 
+ object.__setattr__(self, "computed", self.a * 10) + + +def test_struct_wire_round_trip(): + s = Struct(SimpleData(name="test", count=5, score=3.14, active=True)) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + # Wire deserializes to dict (no dataclass type info in wire format) + assert s2.value == {"name": "test", "count": 5, "score": 3.14, "active": True} + + +def test_struct_storage_round_trip(): + original = SimpleData(name="test", count=5, score=3.14, active=True) + s = Struct(original) + blobs, meta = s.serialize(format="storage") + assert "dataclass_module" in meta + assert "dataclass_class" in meta + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert s2.value == original + assert type(s2.value) is SimpleData + + +def test_struct_storage_without_dataclass_type(): + """When metadata lacks dataclass info, falls back to dict.""" + s = Struct(SimpleData(name="x", count=1, score=0.0, active=False)) + blobs, _ = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata={}) + assert isinstance(s2.value, dict) + assert s2.value["name"] == "x" + + +def test_struct_to_spec_default(): + """to_spec returns just the type name; richer schemas live in extensions.""" + assert Struct().to_spec() == {"type": "struct"} + assert Struct(dataclass_type=SimpleData).to_spec() == {"type": "struct"} + + +def test_struct_nested_dataclass_roundtrip(): + """Directly nested @dataclass fields reconstruct to their original type.""" + original = Outer(label="root", inner=Inner(x=7, y="hi")) + s = Struct(original) + blobs, meta = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert s2.value == original + assert type(s2.value) is Outer + assert type(s2.value.inner) is Inner + + +def test_struct_dict_field_stays_dict(): + """Container annotations like ``sub: dict`` aren't walked — dicts stay as-is.""" + nd = NestedData(label="test", sub={"key": [1, 2, 3]}) + s = Struct(nd) + blobs, meta = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert s2.value == nd + assert type(s2.value) is NestedData + + +def test_struct_wire_deserialize_then_reserialize(): + """Wire round-trip: deserialize returns dict, re-serialize should work on dict.""" + s = Struct(SimpleData(name="test", count=5, score=3.14, active=True)) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + # s2 wraps a dict — re-serializing should work + wire2 = s2.serialize(format="wire") + s3 = Struct.deserialize(wire2, format="wire") + assert s3.value == s2.value + + +def test_struct_security_rejects_non_dataclass(): + """Metadata pointing to a non-dataclass class should be rejected.""" + blobs = [b'{"cmd": "echo pwned"}'] + meta = {"dataclass_module": "subprocess", "dataclass_class": "Popen"} + with pytest.raises(ValueError, match="not a dataclass"): + Struct.deserialize(blobs, format="storage", metadata=meta) + + +def test_struct_security_rejects_dataclass_instance(): + """A callable module-level *instance* that happens to be a dataclass must + also be rejected — ``dataclasses.is_dataclass`` returns True for instances, + so the guard has to require an actual class. + """ + # Build a fake module in sys.modules with a callable dataclass instance. 
+ import sys + import types + + @dataclass + class _Sink: + called_with: dict = dataclasses.field(default_factory=dict) + + def __call__(self, **kwargs): + self.called_with.update(kwargs) + return "attacker controlled" + + fake = types.ModuleType("_struct_security_probe") + fake.sink_instance = _Sink() + sys.modules["_struct_security_probe"] = fake + try: + meta = { + "dataclass_module": "_struct_security_probe", + "dataclass_class": "sink_instance", + } + blobs = [b'{"foo": "bar"}'] + with pytest.raises(ValueError, match="not a dataclass"): + Struct.deserialize(blobs, format="storage", metadata=meta) + # The callable instance must not have been invoked. + assert fake.sink_instance.called_with == {} + finally: + sys.modules.pop("_struct_security_probe", None) + + +def test_struct_init_false_field_round_trip(): + """Dataclasses with ``field(init=False)`` must round-trip without + TypeError. ``dataclasses.asdict`` includes init=False fields in its + output, but passing them to the generated ``__init__`` raises + ``TypeError: __init__() got an unexpected keyword argument``. The + reconstructor must skip them in kwargs and restore their values via + ``setattr`` after construction. + """ + original = WithInitFalse(a=5) + assert original.computed == 50 + # Mutate the computed field after construction to prove the + # serialized value (not __post_init__'s recomputation) survives. + original.computed = 99 + + s = Struct(original) + blobs, meta = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert type(s2.value) is WithInitFalse + assert s2.value.a == 5 + assert s2.value.computed == 99 + + +def test_struct_plain_dict_value(): + """Struct wrapping a plain dict works for serde.""" + s = Struct({"x": 1, "y": "hello"}) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + assert s2.value == {"x": 1, "y": "hello"} + + +def test_struct_frozen_dataclass_with_init_false_field_round_trip(): + """Frozen dataclasses reject plain ``setattr``. Reconstructing one with + an ``init=False`` field must use ``object.__setattr__`` so the + serialized value survives, matching how frozen dataclasses initialize + such fields in their own ``__post_init__``. + """ + original = FrozenWithInitFalse(a=7) + assert original.computed == 70 + + s = Struct(original) + blobs, meta = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert type(s2.value) is FrozenWithInitFalse + assert s2.value.a == 7 + assert s2.value.computed == 70 + + +def test_struct_hashable_with_unhashable_dataclass(): + """Struct wrapping a dataclass with mutable fields (list, dict) must be + hashable. The base ``hash((type, _value))`` raises TypeError because + dataclasses with mutable fields are unhashable; Struct overrides + ``__hash__`` via ``_make_hashable``. + """ + + @dataclass + class WithList: + x: int + items: list = dataclasses.field(default_factory=list) + + s = Struct(WithList(x=1, items=[1, 2, 3])) + h = hash(s) + assert isinstance(h, int) + + # Hash/eq contract: equal values hash equal. + s2 = Struct(WithList(x=1, items=[1, 2, 3])) + assert s == s2 + assert hash(s) == hash(s2) + + # Wrapping a plain dict also hashes without TypeError. + s3 = Struct({"a": [1, 2], "b": {"nested": True}}) + assert isinstance(hash(s3), int) + + +def test_struct_hash_preserves_numeric_equivalence(): + """Same contract as Json: equal dicts with int/float/bool members must + hash equal. 
Confirms the frozenset/tuple based ``_make_hashable`` is + used (not ``json.dumps``, which would render ``1`` and ``1.0`` as + distinct strings). + """ + s_int = Struct({"x": 1}) + s_float = Struct({"x": 1.0}) + assert s_int == s_float + assert hash(s_int) == hash(s_float) + + +def test_struct_descriptor_is_hashable(): + """Struct() (no value) must still be hashable via the _UNSET sentinel.""" + s = Struct() + assert isinstance(hash(s), int) + + +def test_struct_descriptor_uses_unset_sentinel(): + """Struct() (no value) must behave as a pure descriptor via the same + _UNSET sentinel the base IOType uses. Two empty descriptors should be + equal, hashable, and ``repr`` cleanly — and must be distinguishable + from ``Struct(None)`` which wraps an actual ``None`` value. + """ + from metaflow.io_types.base import _UNSET + + a = Struct() + b = Struct() + assert a.value is _UNSET + assert a == b + assert hash(a) == hash(b) + assert repr(a) == "Struct()" + + # Struct(None) is a wrapper around None and is NOT equivalent to a + # descriptor — they compare by their wrapped value. + c = Struct(None) + assert c.value is None + assert a != c diff --git a/test/unit/test_artifact_serializer.py b/test/unit/test_artifact_serializer.py new file mode 100644 index 00000000000..28c97db69ed --- /dev/null +++ b/test/unit/test_artifact_serializer.py @@ -0,0 +1,326 @@ +import pytest + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, + STORAGE, + WIRE, +) + + +# Snapshot the registry before this module's classes are defined. Module-level +# test serializers (_HighPrioritySerializer, ...) self-register at class +# definition time; the module-scoped fixture below removes them at teardown so +# other test modules see an unpolluted registry. 
+_PRE_IMPORT_SNAPSHOT = dict(SerializerStore._all_serializers) + + +@pytest.fixture(scope="module", autouse=True) +def _restore_serializer_registry(): + yield + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(_PRE_IMPORT_SNAPSHOT) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# Helpers — test serializer subclasses defined inside the test module +# --------------------------------------------------------------------------- + + +class _HighPrioritySerializer(ArtifactSerializer): + TYPE = "test_high" + PRIORITY = 10 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_high" + + @classmethod + def serialize(cls, obj): + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_high", {}), + ) + + @classmethod + def deserialize(cls, blobs, metadata): + return blobs[0].decode("utf-8") + + +class _LowPrioritySerializer(ArtifactSerializer): + TYPE = "test_low" + PRIORITY = 200 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, int) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_low" + + @classmethod + def serialize(cls, obj): + blob = str(obj).encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("int", len(blob), "test_low", {}), + ) + + @classmethod + def deserialize(cls, blobs, metadata): + return int(blobs[0].decode("utf-8")) + + +class _SamePrioritySerializer(ArtifactSerializer): + """Same PRIORITY as default (100), registered after _HighPriority and _LowPriority.""" + + TYPE = "test_default_priority" + PRIORITY = 100 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj): + raise NotImplementedError + + @classmethod + def deserialize(cls, blobs, metadata): + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# SerializerStore tests +# --------------------------------------------------------------------------- + + +def test_auto_registration(): + """Subclasses with non-None TYPE are auto-registered.""" + assert "test_high" in SerializerStore._all_serializers + assert "test_low" in SerializerStore._all_serializers + assert SerializerStore._all_serializers["test_high"] is _HighPrioritySerializer + assert SerializerStore._all_serializers["test_low"] is _LowPrioritySerializer + + +def test_base_class_not_registered(): + """ArtifactSerializer itself (TYPE=None) is not registered.""" + assert None not in SerializerStore._all_serializers + + +def test_re_registration_overwrites(): + """A second class with the same TYPE overwrites the first (notebook-friendly).""" + original = SerializerStore._all_serializers["test_high"] + try: + + class _ReplacementSerializer(ArtifactSerializer): + TYPE = "test_high" # same as _HighPrioritySerializer + PRIORITY = 1 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj): + raise NotImplementedError + + @classmethod + def deserialize(cls, blobs, metadata): + raise NotImplementedError + + assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer + finally: + 
SerializerStore._all_serializers["test_high"] = original + SerializerStore._ordered_cache = None + + +def test_priority_ordering(): + """get_ordered_serializers returns lower PRIORITY first.""" + ordered = SerializerStore.get_ordered_serializers() + priorities = [s.PRIORITY for s in ordered] + assert priorities == sorted(priorities) + + +def test_registration_order_tiebreaker(): + """When PRIORITY is equal, registration order breaks the tie.""" + ordered = SerializerStore.get_ordered_serializers() + priority_100 = [s for s in ordered if s.PRIORITY == 100] + if len(priority_100) > 1: + registration_order = list(SerializerStore._all_serializers) + indices = [registration_order.index(s.TYPE) for s in priority_100] + assert indices == sorted(indices) + + +def test_deterministic_ordering(): + """Calling get_ordered_serializers twice returns the same order.""" + first = SerializerStore.get_ordered_serializers() + second = SerializerStore.get_ordered_serializers() + assert [s.TYPE for s in first] == [s.TYPE for s in second] + + +def test_high_priority_before_low(): + """_HighPrioritySerializer (PRIORITY=10) comes before _LowPrioritySerializer (PRIORITY=200).""" + ordered = SerializerStore.get_ordered_serializers() + types = [s.TYPE for s in ordered] + assert types.index("test_high") < types.index("test_low") + + +# --------------------------------------------------------------------------- +# SerializationMetadata tests +# --------------------------------------------------------------------------- + + +def test_metadata_fields(): + meta = SerializationMetadata( + obj_type="dict", + size=1024, + encoding="pickle-v4", + serializer_info={"key": "value"}, + ) + assert meta.obj_type == "dict" + assert meta.size == 1024 + assert meta.encoding == "pickle-v4" + assert meta.serializer_info == {"key": "value"} + + +def test_metadata_is_namedtuple(): + meta = SerializationMetadata("str", 10, "utf-8", {}) + assert isinstance(meta, tuple) + assert len(meta) == 4 + + +# --------------------------------------------------------------------------- +# SerializedBlob tests +# --------------------------------------------------------------------------- + + +def test_blob_bytes_auto_detect(): + """bytes value auto-detects as not a reference.""" + blob = SerializedBlob(b"hello") + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_str_auto_detect(): + """str value auto-detects as a reference.""" + blob = SerializedBlob("sha1_key_abc123") + assert blob.is_reference is True + assert blob.needs_save is False + + +def test_blob_explicit_is_reference_override(): + """Explicit is_reference overrides auto-detection.""" + # bytes but marked as reference (edge case) + blob = SerializedBlob(b"data", is_reference=True) + assert blob.is_reference is True + assert blob.needs_save is False + + # str but marked as not a reference (edge case) + blob = SerializedBlob("inline_data", is_reference=False) + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_value_preserved(): + data = b"\x00\x01\x02\x03" + blob = SerializedBlob(data) + assert blob.value is data + + key = "abc123def456" + blob = SerializedBlob(key) + assert blob.value is key + + +def test_blob_rejects_invalid_types(): + """SerializedBlob must be str or bytes — reject everything else.""" + for bad_value in [123, 3.14, None, [], {}]: + with pytest.raises(TypeError, match="must be str or bytes"): + SerializedBlob(bad_value) + + +# --------------------------------------------------------------------------- +# 
Wire vs storage format dispatch +# --------------------------------------------------------------------------- + + +class _DualFormatSerializer(ArtifactSerializer): + """Toy serializer that implements both formats for str objects.""" + + TYPE = "test_dual_format" + PRIORITY = 40 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_dual_format" + + @classmethod + def serialize(cls, obj, format=STORAGE): + if format == WIRE: + return obj + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_dual_format", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=STORAGE): + if format == WIRE: + return data + return data[0].decode("utf-8") + + +def test_format_constants(): + assert STORAGE == "storage" + assert WIRE == "wire" + # Enum members compare both by identity and against the underlying string. + assert STORAGE is STORAGE + assert STORAGE.value == "storage" + assert WIRE.value == "wire" + + +def test_dual_format_storage_roundtrip(): + blobs, meta = _DualFormatSerializer.serialize("hello") + assert meta.encoding == "test_dual_format" + assert _DualFormatSerializer.deserialize( + [b.value for b in blobs], metadata=meta + ) == "hello" + + +def test_dual_format_wire_roundtrip(): + wire = _DualFormatSerializer.serialize("hello", format=WIRE) + assert isinstance(wire, str) + assert _DualFormatSerializer.deserialize(wire, format=WIRE) == "hello" + + +def test_pickle_serializer_rejects_wire(): + from metaflow.plugins.datastores.serializers.pickle_serializer import ( + PickleSerializer, + ) + + with pytest.raises(NotImplementedError): + PickleSerializer.serialize(42, format=WIRE) + with pytest.raises(NotImplementedError): + PickleSerializer.deserialize("42", format=WIRE) diff --git a/test/unit/test_lazy_serializer_registry.py b/test/unit/test_lazy_serializer_registry.py new file mode 100644 index 00000000000..05ac499f97b --- /dev/null +++ b/test/unit/test_lazy_serializer_registry.py @@ -0,0 +1,226 @@ +"""Tests for the lazy serializer registry and its import interceptor.""" + +import sys +import textwrap +import types + +import pytest + +from metaflow.datastore.artifacts import lazy_registry +from metaflow.datastore.artifacts.lazy_registry import ( + SerializerConfig, + _interceptor, + _reset_for_tests, + iter_registered_configs, + load_serializer_class, + register_serializer_config, + register_serializer_for_type, +) + + +@pytest.fixture(autouse=True) +def reset(): + _reset_for_tests() + yield + _reset_for_tests() + + +def test_config_requires_canonical_type_and_dotted_serializer(): + with pytest.raises(ValueError): + SerializerConfig(canonical_type="", serializer="pkg.Cls") + with pytest.raises(ValueError): + SerializerConfig(canonical_type="builtins.dict", serializer="nodot") + + +def test_config_splits_serializer_path(): + cfg = SerializerConfig( + canonical_type="builtins.dict", serializer="pkg.mod.Cls" + ) + assert cfg.serializer_module == "pkg.mod" + assert cfg.serializer_class == "Cls" + + +def test_register_config_is_immediate(): + cfg = SerializerConfig( + canonical_type="builtins.dict", + serializer="test_lazy_serializer_registry.DictSerializer", + ) + register_serializer_config(cfg) + assert cfg in iter_registered_configs() + + +def test_already_imported_type_registers_eagerly(): + """If the type's module is already in sys.modules, no hook install.""" + register_serializer_for_type( + 
canonical_type="builtins.dict", + serializer="test_lazy_serializer_registry.DictSerializer", + ) + assert any( + c.canonical_type == "builtins.dict" for c in iter_registered_configs() + ) + # Hook should not be installed since dict was already imported. + assert _interceptor not in sys.meta_path + + +def test_deferred_registration_fires_on_import(tmp_path, monkeypatch): + """A not-yet-imported module triggers registration on first import.""" + pkg_dir = tmp_path / "_lazy_probe" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text( + textwrap.dedent( + """ + class ProbeClass: + pass + """ + ) + ) + monkeypatch.syspath_prepend(str(tmp_path)) + + register_serializer_for_type( + canonical_type="_lazy_probe.ProbeClass", + serializer="test_lazy_serializer_registry.DictSerializer", + ) + # Hook must be installed since the probe module isn't loaded yet. + assert _interceptor in sys.meta_path + # No config registered yet. + assert not any( + c.canonical_type == "_lazy_probe.ProbeClass" + for c in iter_registered_configs() + ) + + import _lazy_probe # noqa: F401 + + assert any( + c.canonical_type == "_lazy_probe.ProbeClass" + for c in iter_registered_configs() + ) + + +def test_load_serializer_class_resolves_dotted_path(monkeypatch): + # Inject a fake serializer class into a throwaway module. + fake = types.ModuleType("_fake_serializer_mod") + + class FakeSerializer: + pass + + fake.FakeSerializer = FakeSerializer + monkeypatch.setitem(sys.modules, "_fake_serializer_mod", fake) + + register_serializer_config( + SerializerConfig( + canonical_type="builtins.int", + serializer="_fake_serializer_mod.FakeSerializer", + ) + ) + cls = load_serializer_class("builtins.int") + assert cls is FakeSerializer + # Cached on second call. + assert load_serializer_class("builtins.int") is FakeSerializer + + +def test_load_serializer_class_returns_none_for_unregistered(): + assert load_serializer_class("builtins.nonexistent") is None + + +def test_interceptor_find_spec_returns_none_for_unwatched(): + # Unwatched module — find_spec should decline to intercept. + assert _interceptor.find_spec("json", None) is None + + +def test_wrapped_loader_forwards_unknown_attrs(): + """Loaders expose additional attributes (get_filename, is_package, ...). + The wrapper must forward those so importers that poke at them keep + working. + """ + from metaflow.datastore.artifacts.lazy_registry import _WrappedLoader + + class _FakeLoader: + def create_module(self, spec): + return None + + def exec_module(self, module): + return None + + def get_filename(self, fullname): + return "/tmp/" + fullname + + custom_attr = "hello" + + wrapped = _WrappedLoader(_FakeLoader(), _interceptor) + assert wrapped.get_filename("pkg.mod") == "/tmp/pkg.mod" + assert wrapped.custom_attr == "hello" + + +def test_interceptor_recursion_guard(tmp_path, monkeypatch): + """find_spec must temporarily remove itself from meta_path.""" + pkg_dir = tmp_path / "_lazy_recur" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("X = 1\n") + monkeypatch.syspath_prepend(str(tmp_path)) + + register_serializer_for_type( + canonical_type="_lazy_recur.X", + serializer="test_lazy_serializer_registry.DictSerializer", + ) + # Just import — if the recursion guard is broken, this stack-overflows. + import _lazy_recur # noqa: F401 + + +def test_lazy_registered_serializer_reaches_dispatch(monkeypatch): + """A lazy-registered ArtifactSerializer must surface through + SerializerStore.get_ordered_serializers() — otherwise the extension + registration API is silently inert. 
+ """ + import types as _types + + from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, + ) + + snapshot = dict(SerializerStore._all_serializers) + + class _LazyProbeSerializer(ArtifactSerializer): + TYPE = "test_lazy_probe" + PRIORITY = 33 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format="storage"): + blob = b"" + return [SerializedBlob(blob)], SerializationMetadata("x", 0, "x", {}) + + @classmethod + def deserialize(cls, data, metadata=None, format="storage"): + return None + + # Remove it so lazy-registry has to pull it in. + SerializerStore._all_serializers.pop("test_lazy_probe", None) + SerializerStore._ordered_cache = None + + fake = _types.ModuleType("_lazy_probe_mod") + fake._LazyProbeSerializer = _LazyProbeSerializer + monkeypatch.setitem(sys.modules, "_lazy_probe_mod", fake) + + try: + register_serializer_config( + SerializerConfig( + canonical_type="builtins.object", + serializer="_lazy_probe_mod._LazyProbeSerializer", + ) + ) + ordered = SerializerStore.get_ordered_serializers() + assert _LazyProbeSerializer in ordered + finally: + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(snapshot) + SerializerStore._ordered_cache = None diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py new file mode 100644 index 00000000000..78615939787 --- /dev/null +++ b/test/unit/test_pickle_serializer.py @@ -0,0 +1,186 @@ +import pickle + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + SerializationMetadata, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Registration and identity +# --------------------------------------------------------------------------- + + +def test_type_is_pickle(): + assert PickleSerializer.TYPE == "pickle" + + +def test_priority_is_fallback(): + assert PickleSerializer.PRIORITY == 9999 + + +def test_registered_in_store(): + assert "pickle" in SerializerStore._all_serializers + assert SerializerStore._all_serializers["pickle"] is PickleSerializer + + +def test_last_in_ordering(): + """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" + ordered = SerializerStore.get_ordered_serializers() + assert ordered[-1] is PickleSerializer + + +# --------------------------------------------------------------------------- +# can_serialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello", + 3.14, + None, + True, + [1, 2, 3], + {"key": "value"}, + (1, "a"), + set([1, 2]), + b"bytes", + object(), + ], + ids=[ + "int", + "str", + "float", + "None", + "bool", + "list", + "dict", + "tuple", + "set", + "bytes", + "object", + ], +) +def test_can_serialize_any_object(obj): + assert PickleSerializer.can_serialize(obj) is True + + +# --------------------------------------------------------------------------- +# can_deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "encoding", + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"], +) +def test_can_deserialize_valid_encodings(encoding): + meta = SerializationMetadata("object", 100, 
encoding, {}) + assert PickleSerializer.can_deserialize(meta) is True + + +@pytest.mark.parametrize( + "encoding", + ["json", "iotype:text", "msgpack", "unknown", ""], +) +def test_cannot_deserialize_unknown_encodings(encoding): + meta = SerializationMetadata("object", 100, encoding, {}) + assert PickleSerializer.can_deserialize(meta) is False + + +# --------------------------------------------------------------------------- +# serialize +# --------------------------------------------------------------------------- + + +def test_serialize_returns_single_blob(): + blobs, meta = PickleSerializer.serialize({"key": "value"}) + assert len(blobs) == 1 + assert blobs[0].needs_save is True + assert blobs[0].is_reference is False + + +def test_serialize_metadata_encoding(): + _, meta = PickleSerializer.serialize(42) + assert meta.encoding == "pickle-v4" + + +def test_serialize_metadata_type(): + _, meta = PickleSerializer.serialize([1, 2, 3]) + assert "list" in meta.obj_type + + +def test_serialize_metadata_size(): + obj = {"a": 1, "b": 2} + blobs, meta = PickleSerializer.serialize(obj) + assert meta.size == len(blobs[0].value) + assert meta.size > 0 + + +def test_serialize_metadata_serializer_info_empty(): + _, meta = PickleSerializer.serialize("hello") + assert meta.serializer_info == {} + + +# --------------------------------------------------------------------------- +# Round-trip: serialize -> deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello world", + 3.14, + None, + True, + False, + [1, "two", 3.0], + {"nested": {"key": [1, 2, 3]}}, + (1, 2, 3), + set([1, 2, 3]), + b"raw bytes", + ], + ids=[ + "int", + "str", + "float", + "None", + "True", + "False", + "list", + "nested_dict", + "tuple", + "set", + "bytes", + ], +) +def test_round_trip(obj): + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta) + assert result == obj + + +class _CustomObj: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return isinstance(other, _CustomObj) and self.x == other.x + + +def test_round_trip_custom_class(): + obj = _CustomObj(42) + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta) + assert result == obj + assert result.x == 42 diff --git a/test/unit/test_serializer_integration.py b/test/unit/test_serializer_integration.py new file mode 100644 index 00000000000..7b2fe051370 --- /dev/null +++ b/test/unit/test_serializer_integration.py @@ -0,0 +1,250 @@ +""" +Integration tests for the pluggable serializer framework wired into TaskDataStore. 
+ +Tests that: +- PickleSerializer handles standard Python objects through save/load_artifacts +- Custom serializers take priority over PickleSerializer +- Backward compat: old artifacts (without serializer_info) still load +- Metadata includes serializer_info when present +""" + +import json +import os +import shutil +import tempfile + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Test PickleSerializer round-trip through save/load artifacts +# --------------------------------------------------------------------------- + + +@pytest.fixture +def task_datastore(tmp_path): + """Create a minimal TaskDataStore wired to a local storage backend.""" + from metaflow.datastore.flow_datastore import FlowDataStore + from metaflow.plugins.datastores.local_storage import LocalStorage + + storage_root = str(tmp_path / "datastore") + os.makedirs(storage_root, exist_ok=True) + + flow_ds = FlowDataStore( + flow_name="TestFlow", + environment=None, + metadata=None, + event_logger=None, + monitor=None, + storage_impl=LocalStorage, + ds_root=storage_root, + ) + + task_ds = flow_ds.get_task_datastore( + run_id="1", + step_name="start", + task_id="1", + attempt=0, + mode="w", + ) + task_ds.init_task() + # Isolate from test serializers registered by other test files. + # Only use PickleSerializer (as the plugin system would provide). + task_ds._serializers = [PickleSerializer] + return task_ds + + +def test_save_load_pickle_round_trip(task_datastore): + """Standard Python objects go through PickleSerializer and round-trip.""" + artifacts = [ + ("my_dict", {"key": "value", "nested": [1, 2, 3]}), + ("my_int", 42), + ("my_str", "hello world"), + ("my_none", None), + ] + task_datastore.save_artifacts(iter(artifacts)) + + # Verify metadata + for name, _ in artifacts: + info = task_datastore._info[name] + assert "encoding" in info + assert info["encoding"] == "pickle-v4" + assert info["size"] > 0 + assert "type" in info + + # Load and verify + loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) + assert loaded["my_dict"] == {"key": "value", "nested": [1, 2, 3]} + assert loaded["my_int"] == 42 + assert loaded["my_str"] == "hello world" + assert loaded["my_none"] is None + + +def test_distinct_objects_on_load(task_datastore): + """Loading the same artifact twice yields distinct object instances.""" + shared_list = [1, 2, 3] + task_datastore.save_artifacts(iter([("a", shared_list), ("b", shared_list)])) + + loaded = dict(task_datastore.load_artifacts(["a", "b"])) + assert loaded["a"] == loaded["b"] + assert loaded["a"] is not loaded["b"] # distinct instances + + +def test_metadata_has_no_serializer_info_for_pickle(task_datastore): + """PickleSerializer returns empty serializer_info, so _info should not contain it.""" + task_datastore.save_artifacts(iter([("x", 42)])) + info = task_datastore._info["x"] + # Empty serializer_info should NOT be stored (saves space in metadata) + assert "serializer_info" not in info + + +# --------------------------------------------------------------------------- +# Test custom serializer takes priority +# --------------------------------------------------------------------------- + + +def test_custom_serializer_takes_priority(task_datastore): + """A custom serializer with lower 
PRIORITY claims matching objects over pickle.""" + + # Define and register a custom serializer inside the test + class _JsonStringSerializer(ArtifactSerializer): + TYPE = "test_json_str" + PRIORITY = 50 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_json_str" + + @classmethod + def serialize(cls, obj): + blob = json.dumps(obj).encode("utf-8") + return ( + [SerializedBlob(blob, is_reference=False)], + SerializationMetadata( + obj_type="str", + size=len(blob), + encoding="test_json_str", + serializer_info={"format": "json-utf8"}, + ), + ) + + @classmethod + def deserialize(cls, blobs, metadata): + return json.loads(blobs[0].decode("utf-8")) + + # Explicitly set serializers: custom first, then pickle fallback. + # Don't use get_ordered_serializers() to avoid pollution from other test files. + task_datastore._serializers = [_JsonStringSerializer, PickleSerializer] + + try: + task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) + + # "msg" should use our custom serializer (str → _JsonStringSerializer) + msg_info = task_datastore._info["msg"] + assert msg_info["encoding"] == "test_json_str" + assert msg_info["serializer_info"] == {"format": "json-utf8"} + + # "num" should fall through to PickleSerializer (int → not claimed by custom) + num_info = task_datastore._info["num"] + assert num_info["encoding"] == "pickle-v4" + + # Both round-trip correctly + loaded = dict(task_datastore.load_artifacts(["msg", "num"])) + assert loaded["msg"] == "hello" + assert loaded["num"] == 42 + finally: + SerializerStore._all_serializers.pop("test_json_str", None) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# Backward compat: old metadata format +# --------------------------------------------------------------------------- + + +def test_backward_compat_old_metadata(task_datastore): + """Artifacts saved with old metadata format (no serializer_info) still load.""" + # Save normally first + task_datastore.save_artifacts(iter([("old_artifact", {"a": 1})])) + + # Simulate old metadata format: no serializer_info, old encoding + task_datastore._info["old_artifact"] = { + "size": 100, + "type": "", + "encoding": "gzip+pickle-v4", + # no "serializer_info" key + } + + # Should still load via PickleSerializer (can_deserialize handles gzip+pickle-v4) + loaded = dict(task_datastore.load_artifacts(["old_artifact"])) + assert loaded["old_artifact"] == {"a": 1} + + +def test_backward_compat_no_encoding(task_datastore): + """Very old artifacts without encoding field default to gzip+pickle-v2.""" + # Save an artifact + task_datastore.save_artifacts(iter([("ancient", 99)])) + + # Simulate very old metadata: no encoding, no serializer_info + task_datastore._info["ancient"] = { + "size": 10, + "type": "", + # no "encoding" key — defaults to gzip+pickle-v2 + } + + # Should still load + loaded = dict(task_datastore.load_artifacts(["ancient"])) + assert loaded["ancient"] == 99 + + +# --------------------------------------------------------------------------- +# Dynamic registry: lazy registrations reach long-lived datastores +# --------------------------------------------------------------------------- + + +def test_post_init_registration_reaches_existing_datastore(task_datastore): + """A serializer registered AFTER the datastore was constructed must still + be visible. 
Without the dynamic ``_serializers`` property, lazy imports + (e.g. ``import torch`` after ``TaskDataStore.__init__``) would be silently + ignored for that instance. + """ + # Drop the test override so the property falls back to the live registry. + task_datastore._serializers = None + + class _PostInitSerializer(ArtifactSerializer): + TYPE = "test_post_init_registration" + PRIORITY = 5 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format="storage"): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format="storage"): + raise NotImplementedError + + try: + assert _PostInitSerializer in task_datastore._serializers + finally: + SerializerStore._all_serializers.pop("test_post_init_registration", None) + SerializerStore._ordered_cache = None
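
Reviewer context — a minimal sketch of how an extension could wire a custom serializer through the APIs these tests exercise. Everything under `my_extension` (the module layout, the `TorchTensorSerializer` name, and the `torch-save` encoding string) is hypothetical and not part of this diff; the `metaflow.datastore.artifacts` imports, the `SerializedBlob`/`SerializationMetadata` shapes, and the `register_serializer_for_type` call mirror the signatures used in the tests above.

```python
# --- my_extension/serializers.py ---
# Only imported once a torch.Tensor actually needs (de)serialization, so the
# heavy `torch` import below is deferred by the lazy registry.
import io

import torch  # optional heavy dependency (hypothetical example)

from metaflow.datastore.artifacts.serializer import (
    ArtifactSerializer,
    SerializationMetadata,
    SerializedBlob,
)


class TorchTensorSerializer(ArtifactSerializer):
    TYPE = "torch-save"      # hypothetical encoding name
    PRIORITY = 50            # lower than the PRIORITY=9999 pickle fallback, so tried first

    @classmethod
    def can_serialize(cls, obj):
        return isinstance(obj, torch.Tensor)

    @classmethod
    def can_deserialize(cls, metadata):
        return metadata.encoding == "torch-save"

    @classmethod
    def serialize(cls, obj):
        # Serialize the tensor into a single in-memory blob.
        buf = io.BytesIO()
        torch.save(obj, buf)
        blob = buf.getvalue()
        return (
            [SerializedBlob(blob)],
            SerializationMetadata("torch.Tensor", len(blob), "torch-save", {}),
        )

    @classmethod
    def deserialize(cls, blobs, metadata):
        # blobs arrive as raw bytes, matching the tests' [b.value for b in blobs] calls.
        return torch.load(io.BytesIO(blobs[0]))


# --- my_extension/__init__.py ---
# Imported eagerly by the extension machinery, but cheap: no torch import here.
from metaflow.datastore.artifacts.lazy_registry import register_serializer_for_type

register_serializer_for_type(
    canonical_type="torch.Tensor",
    serializer="my_extension.serializers.TorchTensorSerializer",
)
```

The deferred path of this sketch is the one `test_deferred_registration_fires_on_import` pins down: the config only registers once the target type's module is actually imported, and `test_post_init_registration_reaches_existing_datastore` covers the follow-on requirement that such late registrations still reach already-constructed datastores.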