diff --git a/metaflow/datastore/artifacts/__init__.py b/metaflow/datastore/artifacts/__init__.py new file mode 100644 index 00000000000..02537a21cd6 --- /dev/null +++ b/metaflow/datastore/artifacts/__init__.py @@ -0,0 +1,8 @@ +from .serializer import ( + ArtifactSerializer, + SerializationFormat, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) +from .diagnostic import list_serializer_status, SerializerState diff --git a/metaflow/datastore/artifacts/diagnostic.py b/metaflow/datastore/artifacts/diagnostic.py new file mode 100644 index 00000000000..246ef4d4d6f --- /dev/null +++ b/metaflow/datastore/artifacts/diagnostic.py @@ -0,0 +1,63 @@ +"""Per-entry diagnostic records for the artifact-serializer lifecycle.""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + + +class SerializerState(str, Enum): + KNOWN = "known" + IMPORTING = "importing" + CLASS_LOADED = "class_loaded" + IMPORTING_DEPS = "importing_deps" + ACTIVE = "active" + PENDING_ON_IMPORTS = "pending_on_imports" + BROKEN = "broken" + DISABLED = "disabled" + + +@dataclass +class SerializerRecord: + name: str + class_path: str + state: SerializerState = SerializerState.KNOWN + awaiting_modules: List[str] = field(default_factory=list) + last_error: Optional[str] = None + priority: Optional[int] = None + type: Optional[str] = None + import_trigger: Optional[str] = None + dispatch_error_count: int = 0 + # Human-readable identifier for where this serializer came from — e.g. + # ``"metaflow"`` for the core, the extension's ``package_name`` for one + # shipped by an extension. Stamped into ``serializer_info["source"]`` at + # save time so the "no deserializer claimed this artifact" load error can + # tell the user which extension to install. Serializers that set + # ``source`` in their own ``serializer_info`` are not overridden. + source: Optional[str] = None + + def as_dict(self): + return { + "name": self.name, + "class_path": self.class_path, + "state": self.state.value, + "awaiting_modules": list(self.awaiting_modules), + "last_error": self.last_error, + "priority": self.priority, + "type": self.type, + "import_trigger": self.import_trigger, + "dispatch_error_count": self.dispatch_error_count, + "source": self.source, + } + + +def list_serializer_status(): + """Return a list of per-serializer diagnostic records as dicts. + + One entry per tuple in ``ARTIFACT_SERIALIZERS_DESC`` (post-toggle), + including entries in ``pending_on_imports``, ``broken``, and + ``disabled`` states. Used for debugging "why isn't my custom + serializer active?". + """ + from .serializer import SerializerStore + + return [rec.as_dict() for rec in SerializerStore._records.values()] diff --git a/metaflow/datastore/artifacts/lazy_registry.py b/metaflow/datastore/artifacts/lazy_registry.py new file mode 100644 index 00000000000..19b1b7cd1c2 --- /dev/null +++ b/metaflow/datastore/artifacts/lazy_registry.py @@ -0,0 +1,125 @@ +""" +Import-hook plumbing that the serializer registry uses to retry a serializer's +``setup_imports`` after one of its required modules becomes importable. + +Extensions ship serializers whose implementation modules may import optional +heavy dependencies (``torch``, ``pyarrow``, ``fastavro``, ``protobuf``, ...). +Loading those modules unconditionally at ``metaflow`` import time would force +every user to pay for dependencies they may not have installed. 
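+For instance, a hypothetical extension serializer might declare a heavy
+dependency like this (``torch`` here is purely illustrative)::
+
+    class TensorSerializer(ArtifactSerializer):
+        TYPE = "tensor"
+
+        @classmethod
+        def setup_imports(cls, context=None):
+            cls.lazy_import("torch")  # raises ImportError if torch is absent
+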
When +``SerializerStore.bootstrap_entries`` encounters such a missing module, it +parks the entry in ``pending_on_imports`` state and installs a watch here. +The first time the awaited module is imported by the user's code, this +interceptor fires ``SerializerStore._on_module_imported`` so the registry can +retry activation. + +The interceptor is installed on :data:`sys.meta_path` and removes itself from +the path during its own ``find_spec`` call to avoid recursion. + +This module has no public API — extensions declare serializers through +``ARTIFACT_SERIALIZERS_DESC`` in their ``mfextinit_*`` file and interact with +the registry via the state-machine public surface in +:mod:`metaflow.datastore.artifacts.serializer`. +""" + +import importlib +import importlib.abc +import importlib.machinery +import importlib.util +import sys + + +class _WrappedLoader(importlib.abc.Loader): + """Delegating loader that fires a callback after ``exec_module``. + + Only ``create_module`` and ``exec_module`` are overridden. Other loader + attributes (``get_resource_reader``, ``get_filename``, ``get_data``, + ``is_package``, ``get_source``, ...) are forwarded to the wrapped loader + via ``__getattr__`` so importers that poke at those interfaces continue + to work transparently. + """ + + def __init__(self, original_loader, interceptor): + self._original = original_loader + self._interceptor = interceptor + + def create_module(self, spec): + return self._original.create_module(spec) + + def exec_module(self, module): + self._original.exec_module(module) + self._interceptor._on_module_imported(module) + + def __getattr__(self, name): + return getattr(self._original, name) + + +class _SerializerImportInterceptor(importlib.abc.MetaPathFinder): + """ + :class:`importlib.abc.MetaPathFinder` that watches a set of module names + and notifies :class:`SerializerStore` once each has finished executing. + """ + + def __init__(self): + # Module names to watch on behalf of SerializerStore records parked + # via _park_on_import_error. Firing calls + # SerializerStore._on_module_imported. + self._watched = set() + # Modules we have already notified about, to avoid firing twice if + # the same module gets imported through multiple paths. + self._processed = set() + + def watch(self, module_name): + """Watch ``module_name``. When it finishes executing, + :meth:`SerializerStore._on_module_imported` is called.""" + self._watched.add(module_name) + + def find_spec(self, fullname, path, target=None): + if fullname not in self._watched: + return None + # Remove ourselves from the path during the lookup below so Python's + # normal finders (not us) can resolve the real spec. Reinstall after. + was_installed = self in sys.meta_path + if was_installed: + sys.meta_path.remove(self) + try: + spec = importlib.util.find_spec(fullname) + finally: + if was_installed: + sys.meta_path.insert(0, self) + if spec is None or spec.loader is None: + return None + spec.loader = _WrappedLoader(spec.loader, self) + return spec + + def _on_module_imported(self, module): + module_name = module.__name__ + if module_name in self._processed: + return + self._processed.add(module_name) + if module_name not in self._watched: + return + try: + from .serializer import SerializerStore + + SerializerStore._on_module_imported(module_name, module) + except Exception: + # A broken callback must not break the host's import. The record + # itself will be marked BROKEN via _retry_bootstrap. 
+ pass + + +_interceptor = _SerializerImportInterceptor() + + +def _ensure_interceptor_installed(): + if _interceptor in sys.meta_path: + sys.meta_path.remove(_interceptor) + sys.meta_path.insert(0, _interceptor) + + +def _reset_for_tests(): + """Clear all module-level state. Intended for unit tests only.""" + _interceptor._watched.clear() + _interceptor._processed.clear() + if _interceptor in sys.meta_path: + sys.meta_path.remove(_interceptor) diff --git a/metaflow/datastore/artifacts/serializer.py b/metaflow/datastore/artifacts/serializer.py new file mode 100644 index 00000000000..f01790e0210 --- /dev/null +++ b/metaflow/datastore/artifacts/serializer.py @@ -0,0 +1,717 @@ +import inspect +from abc import ABCMeta, abstractmethod +from collections import namedtuple +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + + +class SerializationFormat(str, Enum): + """ + Serialization format for :class:`ArtifactSerializer`. + + ``STORAGE`` produces ``(blobs, metadata)`` for the datastore persist path; + ``WIRE`` produces a ``str`` for CLI args, protobuf payloads, and + cross-process IPC. + + This subclasses ``str`` so that existing equality checks and JSON / artifact + metadata round-trips continue to work with the underlying ``"storage"`` / + ``"wire"`` values. + """ + + STORAGE = "storage" + WIRE = "wire" + + +def _call_setup_imports(cls, context=None): + """Invoke ``cls.setup_imports`` respecting both signatures: + ``def setup_imports(cls)`` and ``def setup_imports(cls, context=None)``. + """ + func = cls.setup_imports.__func__ # unwrap classmethod + # co_argcount counts positional args: (cls,) -> 1, (cls, context) -> 2. + if func.__code__.co_argcount >= 2: + return cls.setup_imports(context) + return cls.setup_imports() + + +SerializationMetadata = namedtuple( + "SerializationMetadata", ["obj_type", "size", "encoding", "serializer_info"] +) + + +class _OrderedSet: + """ + Minimal insertion-ordered set used for ``_active_serializers``. + + Supports the subset of ``set`` API the codebase exercises: + ``add``, ``discard``, ``update``, ``clear``, ``in``, iteration, ``len``. + Iteration yields insertion order (dict keys preserve it since 3.7). + """ + + __slots__ = ("_d",) + + def __init__(self, iterable=None): + self._d = {} + if iterable is not None: + self.update(iterable) + + def add(self, x): + # Re-adding is a no-op so "last registered" reflects first-time + # insertion order rather than most recent re-assertion. + if x not in self._d: + self._d[x] = None + + def discard(self, x): + self._d.pop(x, None) + + def update(self, iterable): + for x in iterable: + self.add(x) + + def clear(self): + self._d.clear() + + def __contains__(self, x): + return x in self._d + + def __iter__(self): + return iter(self._d) + + def __len__(self): + return len(self._d) + + def __repr__(self): + return "_OrderedSet(%r)" % list(self._d) + + +class SerializedBlob(object): + """ + Represents a single blob produced by a serializer. + + A serializer may produce multiple blobs per artifact. Each blob is either: + - New bytes to be stored (is_reference=False, value is bytes) + - A reference to already-stored data (is_reference=True, value is a string key) + + Parameters + ---------- + value : Union[str, bytes] + The blob data (bytes) or a reference key (str). + is_reference : bool, optional + If None, auto-detected from value type: str -> reference, bytes -> new data. 
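+
+    For example (the key below is illustrative)::
+
+        SerializedBlob(b"raw payload")       # new bytes; needs_save is True
+        SerializedBlob("existing-blob-key")  # reference; needs_save is False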
+ """ + + def __init__(self, value, is_reference=None): + if not isinstance(value, (str, bytes)): + raise TypeError( + "SerializedBlob value must be str or bytes, got %s" + % type(value).__name__ + ) + self.value = value + if is_reference is None: + self.is_reference = isinstance(value, str) + else: + self.is_reference = is_reference + + @property + def needs_save(self): + """True if this blob contains new bytes that need to be stored.""" + return not self.is_reference + + +class SerializerStore(ABCMeta): + """ + Metaclass for ArtifactSerializer that auto-registers subclasses by TYPE. + + Provides deterministic ordering: serializers are sorted by (PRIORITY, registration_order). + Lower PRIORITY values are tried first. Registration order breaks ties. + """ + + _all_serializers = {} + # Dispatch pool — only classes whose state == ACTIVE are here. Uses an + # insertion-ordered set so that ``get_ordered_serializers`` can honor + # "last-registered wins" on PRIORITY ties regardless of how a class got + # into the pool (bootstrap, retry hook, or a direct ``.add`` by a test). + _active_serializers = _OrderedSet() + # Diagnostic records, keyed by tuple name from ARTIFACT_SERIALIZERS_DESC. + _records = {} + # Map: awaited-module-name -> list of record names waiting on it. + _pending_by_module = {} + _ordered_cache = None + + def __init__(cls, name, bases, namespace): + super().__init__(name, bases, namespace) + # Skip the abstract base and any subclass that didn't implement all + # abstract methods — registering a partially-abstract class would + # blow up only at dispatch time. + if cls.TYPE is None or inspect.isabstract(cls): + return + SerializerStore._all_serializers[cls.TYPE] = cls + SerializerStore._ordered_cache = None + + @staticmethod + def get_ordered_serializers(): + """ + Return classes in ``_active_serializers`` sorted for dispatch. + + Sort order: + - PRIORITY ascending (lower tried first) + - On PRIORITY tie: last-registered in ``_active_serializers`` order + wins (i.e., later insertion dispatched first) + - Deterministic lexicographic tertiary key on ``class_path`` for + cross-environment reproducibility + """ + if SerializerStore._ordered_cache is not None: + return SerializerStore._ordered_cache + + # Use a list from the set in deterministic registration order where + # possible. Python sets don't preserve order; we recover a stable + # order by consulting ``_records`` (dict iteration preserves insertion + # order) and picking classes whose type matches an active record. + record_order = [ + rec + for rec in SerializerStore._records.values() + if rec.state.value == "active" + ] + # Find each active record's class via its type; tolerant of records + # whose class isn't in _all_serializers (shouldn't happen in practice). + active_in_order = [] + for rec in record_order: + cls = SerializerStore._all_serializers.get(rec.type) + if cls is not None and cls in SerializerStore._active_serializers: + active_in_order.append(cls) + # Include any _active_serializers entries that don't have a record + # (e.g. inline-registered test classes) at the end of the base list. + seen = set(active_in_order) + for cls in SerializerStore._active_serializers: + if cls not in seen: + active_in_order.append(cls) + seen.add(cls) + + # Sort: PRIORITY ascending; on ties, LAST-registered wins + # (secondary key = -index); tertiary key = class path. 
+ def _sort_key(iv): + idx, cls = iv + class_path = "%s.%s" % (cls.__module__, cls.__qualname__) + return (cls.PRIORITY, -idx, class_path) + + indexed = list(enumerate(active_in_order)) + SerializerStore._ordered_cache = [ + cls for _, cls in sorted(indexed, key=_sort_key) + ] + return SerializerStore._ordered_cache + + @classmethod + def bootstrap(cls): + """Walk every extension's ARTIFACT_SERIALIZERS_DESC, apply toggles + from config, and drive each entry through the state machine. Called + once at Metaflow startup from ``metaflow/plugins/__init__.py``. + """ + entries = [] + + # Bucketed by source so each bucket can be passed to + # bootstrap_entries with its own source label. + by_source = [] # list of (source, entries) + + # Core serializers (shipped with Metaflow). Class paths in the core + # ARTIFACT_SERIALIZERS_DESC are relative to ``metaflow.plugins`` and + # must be resolved here now that ``artifact_serializer`` is no longer + # a registered plugin category (which previously handled this via + # ``_resolve_relative_paths``). + try: + from metaflow.plugins import ARTIFACT_SERIALIZERS_DESC as core_desc + + core_entries = [ + (name, cls._resolve_relative_class_path(path, "metaflow.plugins")) + for name, path in core_desc + ] + if core_entries: + by_source.append(("metaflow", core_entries)) + except ImportError: + # Possible during partial test imports; skip gracefully. + pass + + # Extension serializers — stamp each with the extension's package name + # so load errors can point at the right package to install. + try: + from metaflow.extension_support import plugins as ext_plugins + + for ext_entry in ext_plugins.get_modules("plugins"): + mod = getattr(ext_entry, "module", None) + if mod is None: + continue + ext_pkg = getattr(mod, "__package__", None) or "" + ext_desc = mod.__dict__.get("ARTIFACT_SERIALIZERS_DESC", []) + ext_entries = [ + (name, cls._resolve_relative_class_path(path, ext_pkg)) + for name, path in ext_desc + ] + if not ext_entries: + continue + source = ( + getattr(ext_entry, "package_name", None) + or getattr(ext_entry, "tl_package", None) + or ext_pkg + or None + ) + by_source.append((source, ext_entries)) + except Exception: + # Do not let extension discovery failures kill Metaflow startup. + pass + + # Resolve +/- toggles from config. + disabled_names = set() + try: + from metaflow import metaflow_config + + enabled_list = ( + getattr(metaflow_config, "ENABLED_ARTIFACT_SERIALIZER", None) or [] + ) + for n in enabled_list: + if isinstance(n, str) and n.startswith("-"): + disabled_names.add(n[1:]) + except Exception: + pass + + for source, group in by_source: + cls.bootstrap_entries(group, disabled_names=disabled_names, source=source) + + @staticmethod + def _resolve_relative_class_path(class_path, pkg_path): + """Convert a leading-dot class_path into an absolute dotted path. + + Mirrors ``metaflow.extension_support.plugins._resolve_relative_paths``'s + ``resolve_path`` inner helper: the number of leading dots indicates + how many levels to walk up from ``pkg_path``. Non-relative paths are + returned unchanged. ``pkg_path`` should be the ``__package__`` of the + module that declared the descriptor (e.g. ``metaflow.plugins``). 
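+
+        For example, with ``pkg_path="metaflow.plugins"``::
+
+            ".datastores.x.Y" -> "metaflow.plugins.datastores.x.Y"
+            "..datastore.x.Y" -> "metaflow.datastore.x.Y"
+            "a.b.C"           -> "a.b.C"  (absolute paths pass through)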
+ """ + if not class_path or class_path[0] != ".": + return class_path + pkg_components = pkg_path.split(".") if pkg_path else [] + i = 1 + while i < len(class_path) and class_path[i] == ".": + i += 1 + if i > len(pkg_components): + raise ValueError( + "Path '%s' exits out of Metaflow module at %s" % (class_path, pkg_path) + ) + prefix = ".".join(pkg_components[: -i + 1] if i > 1 else pkg_components) + return prefix + class_path[i - 1 :] + + @classmethod + def bootstrap_entries(cls, entries, disabled_names=None, source=None): + """Drive a list of (name, class_path) tuples through the state machine. + + Called once at Metaflow startup via ``bootstrap()`` and directly by + tests for isolation. + + ``disabled_names``: set of tuple names to mark DISABLED without + attempting import (from ``-name`` toggles in config). + ``source``: human-readable identifier for where these entries came + from (e.g. ``"metaflow"`` for core, an extension's ``package_name`` + for extension-shipped serializers). Stamped on each record and + auto-injected into ``serializer_info["source"]`` at save time unless + the serializer sets its own value. + """ + import importlib + from .diagnostic import SerializerRecord, SerializerState + + disabled_names = disabled_names or set() + + for name, class_path in entries: + rec = SerializerRecord(name=name, class_path=class_path, source=source) + cls._records[name] = rec + + if name in disabled_names: + rec.state = SerializerState.DISABLED + continue + + module_path, class_name = class_path.rsplit(".", 1) + + rec.state = SerializerState.IMPORTING + try: + module = importlib.import_module(module_path) + except ImportError as e: + cls._park_on_import_error(rec, e) + continue + + serializer_cls = getattr(module, class_name, None) + if serializer_cls is None: + rec.state = SerializerState.BROKEN + rec.last_error = "class '%s' not found in module '%s'" % ( + class_name, + module_path, + ) + continue + + if serializer_cls.TYPE != name: + rec.state = SerializerState.BROKEN + rec.last_error = "tuple name '%s' != class.TYPE '%s'" % ( + name, + serializer_cls.TYPE, + ) + continue + + rec.state = SerializerState.CLASS_LOADED + rec.priority = serializer_cls.PRIORITY + rec.type = serializer_cls.TYPE + + rec.state = SerializerState.IMPORTING_DEPS + try: + _call_setup_imports(serializer_cls, context=None) + except ImportError as e: + cls._park_on_import_error(rec, e) + continue + except Exception as e: + rec.state = SerializerState.BROKEN + rec.last_error = "%s: %s" % (type(e).__name__, e) + continue + + rec.state = SerializerState.ACTIVE + rec.import_trigger = "eager" + cls._active_serializers.add(serializer_cls) + cls._ordered_cache = None + + @classmethod + def _park_on_import_error(cls, rec, exc): + """Transition rec to PENDING_ON_IMPORTS and register the watched module. + + Used for ImportError / ModuleNotFoundError during module-import OR + setup_imports. Loop guard: same e.name twice in a row -> BROKEN. + + Installs a ``sys.meta_path`` retry hook for the missing module so + that the record is re-driven through the state machine when (and + if) the module is eventually imported. + """ + from .diagnostic import SerializerState + + if exc.name is None: + rec.state = SerializerState.BROKEN + rec.last_error = "ImportError with no module name: %s" % exc + return + + # Loop guard: if we've parked on this exact module name before + # (either currently awaiting or historically), the dep is not + # recoverable — mark broken to prevent infinite retries. 
+ prev_seen = getattr(rec, "_previously_awaited", set()) + if exc.name in rec.awaiting_modules or exc.name in prev_seen: + rec.state = SerializerState.BROKEN + rec.last_error = "repeated ImportError on '%s': %s" % (exc.name, exc) + return + + rec.awaiting_modules.append(exc.name) + rec.state = SerializerState.PENDING_ON_IMPORTS + rec.last_error = "%s: %s" % (type(exc).__name__, exc) + cls._pending_by_module.setdefault(exc.name, []).append(rec.name) + + # Install the sys.meta_path hook so an eventual import of exc.name + # drives _on_module_imported -> _retry_bootstrap. + from .lazy_registry import _ensure_interceptor_installed, _interceptor + + _ensure_interceptor_installed() + _interceptor.watch(exc.name) + + @classmethod + def _on_module_imported(cls, module_name, module): + """Called when a watched module finishes importing. Retries bootstrap + for every record awaiting that module. Safe to call directly for + tests; normally fired by the lazy_registry import interceptor. + """ + waiting = cls._pending_by_module.pop(module_name, []) + for record_name in waiting: + rec = cls._records.get(record_name) + if rec is None: + continue + # Clear this module from the rec's awaiting list before retry. + # Loop guard reads awaiting_modules, so we must remove the current + # module or a legitimate retry would short-circuit to BROKEN. The + # history tracker ``_previously_awaited`` preserves the fact that + # we did park on this module before, so a genuine repeat failure + # still trips the guard. + if not hasattr(rec, "_previously_awaited"): + rec._previously_awaited = set() + rec._previously_awaited.update(rec.awaiting_modules) + rec.awaiting_modules = [m for m in rec.awaiting_modules if m != module_name] + cls._retry_bootstrap(rec) + + @classmethod + def _retry_bootstrap(cls, rec): + """Re-run the import + setup_imports sequence for a single record.""" + import importlib + from .diagnostic import SerializerState + + module_path, class_name = rec.class_path.rsplit(".", 1) + + try: + module = importlib.import_module(module_path) + except ImportError as e: + cls._park_on_import_error(rec, e) + return + + serializer_cls = getattr(module, class_name, None) + if serializer_cls is None: + rec.state = SerializerState.BROKEN + rec.last_error = "class '%s' not found in '%s' on retry" % ( + class_name, + module_path, + ) + return + + try: + _call_setup_imports(serializer_cls, context=None) + except ImportError as e: + cls._park_on_import_error(rec, e) + return + except Exception as e: + rec.state = SerializerState.BROKEN + rec.last_error = "%s: %s" % (type(e).__name__, e) + return + + rec.state = SerializerState.ACTIVE + rec.import_trigger = "hook" + cls._active_serializers.add(serializer_cls) + cls._ordered_cache = None + + @classmethod + def get_source_for(cls, serializer_cls): + """Return the ``source`` label attached to a serializer's record, or + ``None`` if no record exists. Looked up by ``TYPE`` — tuple names are + validated equal to ``TYPE`` at bootstrap time.""" + target_type = getattr(serializer_cls, "TYPE", None) + if target_type is None: + return None + for rec in cls._records.values(): + if rec.type == target_type: + return rec.source + return None + + @classmethod + def _reset_for_tests(cls): + """Clear all registry state. Intended for unit tests only. 
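+
+        A sketch of the intended pytest usage (fixture name is hypothetical)::
+
+            @pytest.fixture(autouse=True)
+            def _clean_registry():
+                yield
+                SerializerStore._reset_for_tests()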
+ + Clears: + - ``_records`` (per-entry diagnostic records) + - ``_active_serializers`` (dispatch pool) + - ``_pending_by_module`` (retry watch map) + - ``_ordered_cache`` (dispatch sort cache) + - per-class ``_lazy_imported_names`` (walks MRO of every class in + ``_all_serializers`` and ``_active_serializers`` and delattr's + tracked lazy-imported attribute names) + - Interceptor's watched-module set so a future bootstrap does not + re-fire stale hooks. + + Does NOT clear ``_all_serializers`` (metaclass-populated). Tests that + need that swept should do it themselves. + """ + # Collect candidate classes from both registries (metaclass-tracked + # + active dispatch pool). + candidate_classes = set(cls._all_serializers.values()) + candidate_classes.update(cls._active_serializers) + + # For each class (and its MRO), clear lazy-imported attributes. + for serializer_cls in candidate_classes: + for base in serializer_cls.__mro__: + names = base.__dict__.get("_lazy_imported_names") + if not names: + continue + for name in list(names): + if name in base.__dict__: + try: + delattr(base, name) + except (AttributeError, TypeError): + pass + names.clear() + + # Reset registry state. + cls._records.clear() + cls._active_serializers.clear() + cls._pending_by_module.clear() + cls._ordered_cache = None + + # Clear interceptor watches so a future bootstrap does not fire on + # stale module names. + try: + from .lazy_registry import _interceptor + + if hasattr(_interceptor, "_watched"): + _interceptor._watched.clear() + except ImportError: + pass + + +class ArtifactSerializer(object, metaclass=SerializerStore): + """ + Abstract base class for artifact serializers. + + Subclasses must set TYPE to a unique string identifier and implement + all four class methods. Subclasses are auto-registered by the SerializerStore + metaclass on class definition. + + Attributes + ---------- + TYPE : str or None + Unique identifier for this serializer (e.g., "pickle", "iotype"). + Set to None in the base class to prevent registration. + PRIORITY : int + Dispatch priority. Lower values are tried first. Default 100. + PickleSerializer uses 9999 as the universal fallback. + """ + + TYPE = None + PRIORITY = 100 + + @classmethod + def setup_imports(cls, context=None): + """Perform heavy imports. Called once by SerializerStore after the + class module loads cleanly and before the class is added to the + dispatch pool. If this method raises ``ImportError`` with a named + missing module, Metaflow parks the serializer until that module + appears in ``sys.modules`` and retries. Any other exception moves + the serializer to the ``broken`` state. + + The ``context`` parameter is reserved for future use; it is always + ``None`` today. Authors may omit it entirely + (``def setup_imports(cls):``) or accept it with a default + (``def setup_imports(cls, context=None):``). + + Default: no-op. + """ + return None + + _RESERVED_LAZY_IMPORT_NAMES = frozenset( + { + "TYPE", + "PRIORITY", + "setup_imports", + "can_serialize", + "can_deserialize", + "serialize", + "deserialize", + "lazy_import", + } + ) + + @classmethod + def lazy_import(cls, module_path, alias=None): + """Import ``module_path``, stash on ``cls``, return the module. + + Alias defaults to the leaf of ``module_path`` + (``lazy_import("torch.nn.functional")`` -> ``cls.functional``). + Names starting with ``_`` and methods of ``ArtifactSerializer`` are + reserved and cannot be used as aliases. 
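+
+        A sketch of intended use inside ``setup_imports`` (module names are
+        illustrative)::
+
+            cls.lazy_import("numpy")                  # -> cls.numpy
+            cls.lazy_import("pyarrow.parquet", "pq")  # -> cls.pq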
+
+        Propagates ``ImportError`` / ``ModuleNotFoundError`` from the
+        import itself -- the state machine interprets these to install a
+        retry hook.
+        """
+        import importlib
+
+        name = alias if alias is not None else module_path.rsplit(".", 1)[-1]
+        if name in cls._RESERVED_LAZY_IMPORT_NAMES or name.startswith("_"):
+            raise ValueError("lazy_import alias '%s' is reserved or invalid" % name)
+        # Track per-class so _reset_for_tests can clean up later.
+        if "_lazy_imported_names" not in cls.__dict__:
+            cls._lazy_imported_names = set()
+        if name in cls._lazy_imported_names:
+            existing = getattr(cls, name)
+            if getattr(existing, "__name__", None) == module_path:
+                # Idempotent re-import: the alias was set to this same module
+                # in a prior setup_imports call (e.g. first attempt partially
+                # succeeded before parking on a missing dep). Return the
+                # already-stashed module.
+                return existing
+            # Same alias, different module: refuse rather than silently
+            # handing back the wrong module.
+            raise ValueError(
+                "lazy_import alias '%s' is already set to module '%s'"
+                % (name, getattr(existing, "__name__", existing))
+            )
+        mod = importlib.import_module(module_path)
+        setattr(cls, name, mod)
+        cls._lazy_imported_names.add(name)
+        return mod
+
+    @classmethod
+    @abstractmethod
+    def can_serialize(cls, obj: Any) -> bool:
+        """
+        Return True if this serializer can handle the given object.
+
+        Parameters
+        ----------
+        obj : Any
+            The Python object to serialize.
+
+        Returns
+        -------
+        bool
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def can_deserialize(cls, metadata: SerializationMetadata) -> bool:
+        """
+        Return True if this serializer can deserialize given the metadata.
+
+        Parameters
+        ----------
+        metadata : SerializationMetadata
+            Metadata stored alongside the artifact.
+
+        Returns
+        -------
+        bool
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def serialize(
+        cls,
+        obj: Any,
+        format: SerializationFormat = SerializationFormat.STORAGE,
+    ) -> Union[Tuple[List["SerializedBlob"], SerializationMetadata], str]:
+        """
+        Serialize obj. Must be side-effect-free: this method may be invoked
+        multiple times (caching, retries, parallel dispatch) and must not
+        perform I/O, mutate global state, or register the object elsewhere.
+        Side effects that need to happen at persist time belong in hooks,
+        not in the serializer.
+
+        Parameters
+        ----------
+        obj : Any
+            The Python object to serialize.
+        format : SerializationFormat
+            Either ``SerializationFormat.STORAGE`` (default) or
+            ``SerializationFormat.WIRE``.
+            - ``STORAGE`` returns a tuple ``(List[SerializedBlob], SerializationMetadata)``
+              for persisting through the datastore.
+            - ``WIRE`` returns a ``str`` representation for CLI args, protobuf
+              payloads, and cross-process IPC. Serializers that cannot provide
+              a wire encoding should raise ``NotImplementedError``.
+
+        Returns
+        -------
+        tuple or str
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def deserialize(
+        cls,
+        data: Union[List[bytes], str],
+        metadata: Optional[SerializationMetadata] = None,
+        format: SerializationFormat = SerializationFormat.STORAGE,
+    ) -> Any:
+        """
+        Deserialize back to a Python object.
+
+        Parameters
+        ----------
+        data : Union[List[bytes], str]
+            ``List[bytes]`` when ``format=STORAGE``; ``str`` when ``format=WIRE``.
+        metadata : SerializationMetadata, optional
+            Metadata stored alongside the artifact. Required for STORAGE,
+            ignored for WIRE.
+        format : SerializationFormat
+            Either ``SerializationFormat.STORAGE`` (default) or
+            ``SerializationFormat.WIRE``.
+ + Returns + ------- + Any + """ + raise NotImplementedError diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index abeeb8ea5fb..755cb2fe913 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -1,6 +1,5 @@ from collections import defaultdict import json -import pickle import sys import time @@ -15,11 +14,32 @@ from ..parameters import Parameter from ..util import Path, is_stringish, to_fileobj +from .artifacts.serializer import SerializationMetadata, SerializerStore from .exceptions import DataException, UnpicklableArtifactException _included_file_type = "" +def _record_dispatch_error(serializer_cls, exc): + """Increment dispatch_error_count + update last_error for the given serializer's + diagnostic record. No-op if the class has no record (e.g., it was not + registered via bootstrap).""" + try: + from .artifacts.serializer import SerializerStore + + target_type = getattr(serializer_cls, "TYPE", None) + if target_type is None: + return + for rec in SerializerStore._records.values(): + if rec.type == target_type: + rec.dispatch_error_count += 1 + rec.last_error = "dispatch: %s: %s" % (type(exc).__name__, exc) + return + except Exception: + # Never let diagnostic bookkeeping crash dispatch. + pass + + def only_if_not_done(f): @wraps(f) def method(self, *args, **kwargs): @@ -117,12 +137,9 @@ def __init__( self._parent = flow_datastore self._persist = persist - # The GZIP encodings are for backward compatibility - self._encodings = {"pickle-v2", "gzip+pickle-v2"} - ver = sys.version_info[0] * 10 + sys.version_info[1] - if ver >= 36: - self._encodings.add("pickle-v4") - self._encodings.add("gzip+pickle-v4") + # Optional override consumed by the ``_serializers`` property (used by + # tests to swap in a fixed dispatch list). + self._serializers_override = None self._is_done_set = False @@ -200,6 +217,24 @@ def __init__( else: raise DataException("Unknown datastore mode: '%s'" % self._mode) + @property + def _serializers(self): + # Dispatches through ``SerializerStore.get_ordered_serializers()`` on + # each access. The lookup is cheap (cached inside the store) and picks + # up serializers registered via the lazy import hook after this + # instance was constructed — otherwise long-lived datastores + # (notebooks, client sessions) would silently miss any extension + # registered after init. Tests can short-circuit this by assigning to + # ``_serializers`` directly (see the setter, which stores the override + # in ``_serializers_override``). + if self._serializers_override is not None: + return self._serializers_override + return SerializerStore.get_ordered_serializers() + + @_serializers.setter + def _serializers(self, value): + self._serializers_override = value + @property def pathspec(self): return "/".join([self.run_id, self.step_name, self.task_id]) @@ -347,38 +382,79 @@ def save_artifacts(self, artifacts_iter, len_hint=0): """ artifact_names = [] - def pickle_iter(): + def serialize_iter(): for name, obj in artifacts_iter: - encode_type = "gzip+pickle-v4" - if encode_type in self._encodings: + # Find the first serializer that can handle this object + serializer = None + for s in self._serializers: try: - blob = pickle.dumps(obj, protocol=4) - except TypeError as e: - raise UnpicklableArtifactException(name) from e - else: - try: - blob = pickle.dumps(obj, protocol=2) - encode_type = "gzip+pickle-v2" - except (SystemError, OverflowError) as e: - raise DataException( - "Artifact *%s* is very large (over 2GB). 
" - "You need to use Python 3.6 or newer if you want to " - "serialize large objects." % name - ) from e - except TypeError as e: - raise UnpicklableArtifactException(name) from e + if s.can_serialize(obj): + serializer = s + break + except Exception as e: + _record_dispatch_error(s, e) + continue + if serializer is None: + raise DataException( + "No serializer claimed artifact '%s' (type: %s). " + "The PickleSerializer fallback normally handles all " + "objects — check that it is installed and enabled." + % (name, type(obj).__name__) + ) + + try: + blobs, metadata = serializer.serialize(obj) + except TypeError as e: + raise UnpicklableArtifactException(name) from e + + # Validate the blob shape BEFORE recording anything in + # ``_info`` — a failure here must not leave ``_info[name]`` + # populated for an artifact we refused to persist. + if not blobs: + raise DataException( + "Serializer %s returned no blobs for artifact '%s'" + % (serializer.__name__, name) + ) + if len(blobs) > 1: + # The datastore currently stores a single blob per + # artifact. Silently dropping blobs[1:] would corrupt + # multi-blob IOType extensions (e.g. chunked tensors) on + # load. Fail loudly until multi-blob support lands. + raise DataException( + "Serializer %s returned %d blobs for artifact '%s'; " + "only single-blob serializers are supported at this " + "time. If you have a need for multi blob serializers, " + "please reach out to the Metaflow team." + % (serializer.__name__, len(blobs), name) + ) + + # Auto-inject ``source`` into serializer_info so the + # "no deserializer claimed artifact" load error can point at + # the extension to install. Authors who set their own + # ``source`` in the returned ``serializer_info`` are not + # overridden. Copy the dict so we don't mutate the + # serializer's returned value across calls. + merged_info = ( + dict(metadata.serializer_info) if metadata.serializer_info else {} + ) + if "source" not in merged_info: + auto_source = SerializerStore.get_source_for(serializer) + if auto_source: + merged_info["source"] = auto_source self._info[name] = { - "size": len(blob), - "type": str(type(obj)), - "encoding": encode_type, + "size": metadata.size, + "type": metadata.obj_type, + "encoding": metadata.encoding, } + if merged_info: + self._info[name]["serializer_info"] = merged_info artifact_names.append(name) - yield blob + yield blobs[0].value # Use the content-addressed store to store all artifacts - save_result = self._ca_store.save_blobs(pickle_iter(), len_hint=len_hint) + save_result = self._ca_store.save_blobs(serialize_iter(), len_hint=len_hint) for name, result in zip(artifact_names, save_result): self._objects[name] = result.key @@ -414,32 +490,53 @@ def load_artifacts(self, names): "load artifacts" % self._path ) to_load = defaultdict(list) + deserializers = {} # name -> serializer class for name in names: - info = self._info.get(name) - # We use gzip+pickle-v2 as this is the oldest/most compatible. 
- # This datastore will always include the proper encoding version so - # this is just to be able to read very old artifacts - if info: - encode_type = info.get("encoding", "gzip+pickle-v2") - else: - encode_type = "gzip+pickle-v2" - if encode_type not in self._encodings: + info = self._info.get(name, {}) + metadata = SerializationMetadata( + obj_type=info.get("type", "object"), + size=info.get("size", 0), + # Default to gzip+pickle-v2 for very old artifacts without encoding + encoding=info.get("encoding", "gzip+pickle-v2"), + serializer_info=info.get("serializer_info", {}), + ) + + # Find deserializer via metadata + deserializer = None + for s in self._serializers: + try: + if s.can_deserialize(metadata): + deserializer = s + break + except Exception as e: + _record_dispatch_error(s, e) + continue + if deserializer is None: + source_hint = "" + serializer_source = metadata.serializer_info.get("source") + if serializer_source: + source_hint = ( + " The artifact was written by '%s' — the " + "corresponding extension may not be installed." + % serializer_source + ) raise DataException( - "Python 3.6 or later is required to load artifact '%s'" % name + "No deserializer claimed artifact '%s' (encoding: %s, " + "serializer_info: %s).%s" + % (name, metadata.encoding, metadata.serializer_info, source_hint) ) - else: - to_load[self._objects[name]].append(name) - # At this point, we load what we don't have from the CAS - # We assume that if we have one "old" style artifact, all of them are - # like that which is an easy assumption to make since artifacts are all - # stored by the same implementation of the datastore for a given task. + deserializers[name] = (deserializer, metadata) + to_load[self._objects[name]].append(name) + + # Load blobs from CAS and deserialize for key, blob in self._ca_store.load_blobs(to_load.keys()): - names = to_load[key] - for name in names: - # We unpickle everytime to have fully distinct objects (the user + loaded_names = to_load[key] + for name in loaded_names: + deserializer, metadata = deserializers[name] + # Deserialize each time to have fully distinct objects (the user # would not expect two artifacts with different names to actually # be aliases of one another) - yield name, pickle.loads(blob) + yield name, deserializer.deserialize([blob], metadata) @require_mode("r") def get_artifact_sizes(self, names): diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 3fc1d3f8db6..04e7468ce1f 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -187,6 +187,12 @@ ("conda_environment_yml_parser", ".pypi.parsers.conda_environment_yml_parser"), ] +# Add artifact serializers here. Ordering is by PRIORITY (lower = tried first). +# PickleSerializer is the universal fallback (PRIORITY=9999). +ARTIFACT_SERIALIZERS_DESC = [ + ("pickle", ".datastores.serializers.pickle_serializer.PickleSerializer"), +] + process_plugins(globals()) @@ -281,3 +287,12 @@ def _import_tl_plugins(globals_dict): for name, p in TL_PLUGINS.items(): globals_dict[name] = p + + +# Drive every ARTIFACT_SERIALIZERS_DESC entry through the serializer state +# machine — this is what makes serializer classes reachable via +# ``SerializerStore.get_ordered_serializers()`` at dispatch time. 
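+# An extension participates the same way by declaring, in its mfextinit_*
+# module, something like (names illustrative):
+#
+#     ARTIFACT_SERIALIZERS_DESC = [("myfmt", ".serializers.MyFmtSerializer")]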
+from metaflow.datastore.artifacts.serializer import SerializerStore as _SerializerStore + +_SerializerStore.bootstrap() +del _SerializerStore diff --git a/metaflow/plugins/datastores/serializers/__init__.py b/metaflow/plugins/datastores/serializers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/plugins/datastores/serializers/pickle_serializer.py b/metaflow/plugins/datastores/serializers/pickle_serializer.py new file mode 100644 index 00000000000..5e7f7fee2e8 --- /dev/null +++ b/metaflow/plugins/datastores/serializers/pickle_serializer.py @@ -0,0 +1,62 @@ +import pickle + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationFormat, + SerializationMetadata, + SerializedBlob, +) + + +class PickleSerializer(ArtifactSerializer): + """ + Default serializer using Python's pickle module. + + This is the universal fallback — can_serialize always returns True. + PRIORITY is set to 9999 so custom serializers are always tried first. + Pickle produces binary bytes, so only the STORAGE format is supported; + callers that need a wire representation should pick a serializer that + implements it (e.g. an IOType- or JSON-based one in an extension). + """ + + TYPE = "pickle" + PRIORITY = 9999 + + _ENCODINGS = frozenset( + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"] + ) + + @classmethod + def can_serialize(cls, obj): + return True + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding in cls._ENCODINGS + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + raise NotImplementedError( + "PickleSerializer does not support the WIRE format; pickle " + "produces opaque binary bytes that are not safe to pass as " + "CLI args or inline IPC payloads." + ) + blob = pickle.dumps(obj, protocol=4) + return ( + [SerializedBlob(blob, is_reference=False)], + SerializationMetadata( + obj_type=str(type(obj)), + size=len(blob), + encoding="gzip+pickle-v4", + serializer_info={}, + ), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + raise NotImplementedError( + "PickleSerializer does not support the WIRE format." + ) + return pickle.loads(data[0]) diff --git a/test/unit/test_artifact_serializer.py b/test/unit/test_artifact_serializer.py new file mode 100644 index 00000000000..80b2c15dcb0 --- /dev/null +++ b/test/unit/test_artifact_serializer.py @@ -0,0 +1,666 @@ +import pytest + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationFormat, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) + + +# Snapshot the registry before this module's classes are defined. Module-level +# test serializers (_HighPrioritySerializer, ...) self-register at class +# definition time; the module-scoped fixture below removes them at teardown so +# other test modules see an unpolluted registry. 
+_PRE_IMPORT_SNAPSHOT = dict(SerializerStore._all_serializers) +_PRE_IMPORT_ACTIVE_SNAPSHOT = set(SerializerStore._active_serializers) + + +@pytest.fixture(scope="module", autouse=True) +def _restore_serializer_registry(): + yield + SerializerStore._all_serializers.clear() + SerializerStore._all_serializers.update(_PRE_IMPORT_SNAPSHOT) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(_PRE_IMPORT_ACTIVE_SNAPSHOT) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# Helpers — test serializer subclasses defined inside the test module +# --------------------------------------------------------------------------- + + +class _HighPrioritySerializer(ArtifactSerializer): + TYPE = "test_high" + PRIORITY = 10 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_high" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_high", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + return data[0].decode("utf-8") + + +class _LowPrioritySerializer(ArtifactSerializer): + TYPE = "test_low" + PRIORITY = 200 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, int) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_low" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + blob = str(obj).encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("int", len(blob), "test_low", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + return int(data[0].decode("utf-8")) + + +class _SamePrioritySerializer(ArtifactSerializer): + """Same PRIORITY as default (100), registered after _HighPriority and _LowPriority.""" + + TYPE = "test_default_priority" + PRIORITY = 100 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + +# Dispatch is now driven by _active_serializers (post-Phase-6). The metaclass +# only populates _all_serializers; tests that assert against the ordered +# dispatch list must also mark their classes as active. 
+SerializerStore._active_serializers.add(_HighPrioritySerializer) +SerializerStore._active_serializers.add(_LowPrioritySerializer) +SerializerStore._active_serializers.add(_SamePrioritySerializer) +SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# SerializerStore tests +# --------------------------------------------------------------------------- + + +def test_auto_registration(): + """Subclasses with non-None TYPE are auto-registered.""" + assert "test_high" in SerializerStore._all_serializers + assert "test_low" in SerializerStore._all_serializers + assert SerializerStore._all_serializers["test_high"] is _HighPrioritySerializer + assert SerializerStore._all_serializers["test_low"] is _LowPrioritySerializer + + +def test_base_class_not_registered(): + """ArtifactSerializer itself (TYPE=None) is not registered.""" + assert None not in SerializerStore._all_serializers + + +def test_re_registration_overwrites(): + """A second class with the same TYPE overwrites the first (notebook-friendly).""" + original = SerializerStore._all_serializers["test_high"] + try: + + class _ReplacementSerializer(ArtifactSerializer): + TYPE = "test_high" # same as _HighPrioritySerializer + PRIORITY = 1 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize( + cls, data, metadata=None, format=SerializationFormat.STORAGE + ): + raise NotImplementedError + + assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer + finally: + SerializerStore._all_serializers["test_high"] = original + SerializerStore._ordered_cache = None + + +def test_priority_ordering(): + """get_ordered_serializers returns lower PRIORITY first.""" + ordered = SerializerStore.get_ordered_serializers() + priorities = [s.PRIORITY for s in ordered] + assert priorities == sorted(priorities) + + +def test_priority_tie_last_wins(): + """When PRIORITY is equal, last-registered wins the tie.""" + + class _TieFirst(ArtifactSerializer): + TYPE = "test_tie_first" + PRIORITY = 123 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + class _TieSecond(ArtifactSerializer): + TYPE = "test_tie_second" + PRIORITY = 123 # same as _TieFirst + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + SerializerStore._active_serializers.add(_TieFirst) + SerializerStore._active_serializers.add(_TieSecond) + SerializerStore._ordered_cache = None + ordered = SerializerStore.get_ordered_serializers() + idx_first = ordered.index(_TieFirst) + idx_second = ordered.index(_TieSecond) + # _TieSecond was registered LAST, so it should appear BEFORE _TieFirst. 
+ assert idx_second < idx_first, ( + "Expected last-registered (_TieSecond) to come first; got " + "_TieFirst at index %d, _TieSecond at index %d" % (idx_first, idx_second) + ) + finally: + SerializerStore._all_serializers.pop("test_tie_first", None) + SerializerStore._all_serializers.pop("test_tie_second", None) + SerializerStore._active_serializers.discard(_TieFirst) + SerializerStore._active_serializers.discard(_TieSecond) + SerializerStore._ordered_cache = None + + +def test_deterministic_ordering(): + """Calling get_ordered_serializers twice returns the same order.""" + first = SerializerStore.get_ordered_serializers() + second = SerializerStore.get_ordered_serializers() + assert [s.TYPE for s in first] == [s.TYPE for s in second] + + +def test_high_priority_before_low(): + """_HighPrioritySerializer (PRIORITY=10) comes before _LowPrioritySerializer (PRIORITY=200).""" + ordered = SerializerStore.get_ordered_serializers() + types = [s.TYPE for s in ordered] + assert types.index("test_high") < types.index("test_low") + + +# --------------------------------------------------------------------------- +# SerializationMetadata tests +# --------------------------------------------------------------------------- + + +def test_metadata_fields(): + meta = SerializationMetadata( + obj_type="dict", + size=1024, + encoding="pickle-v4", + serializer_info={"key": "value"}, + ) + assert meta.obj_type == "dict" + assert meta.size == 1024 + assert meta.encoding == "pickle-v4" + assert meta.serializer_info == {"key": "value"} + + +def test_metadata_is_namedtuple(): + meta = SerializationMetadata("str", 10, "utf-8", {}) + assert isinstance(meta, tuple) + assert len(meta) == 4 + + +# --------------------------------------------------------------------------- +# SerializedBlob tests +# --------------------------------------------------------------------------- + + +def test_blob_bytes_auto_detect(): + """bytes value auto-detects as not a reference.""" + blob = SerializedBlob(b"hello") + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_str_auto_detect(): + """str value auto-detects as a reference.""" + blob = SerializedBlob("sha1_key_abc123") + assert blob.is_reference is True + assert blob.needs_save is False + + +def test_blob_explicit_is_reference_override(): + """Explicit is_reference overrides auto-detection.""" + # bytes but marked as reference (edge case) + blob = SerializedBlob(b"data", is_reference=True) + assert blob.is_reference is True + assert blob.needs_save is False + + # str but marked as not a reference (edge case) + blob = SerializedBlob("inline_data", is_reference=False) + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_value_preserved(): + data = b"\x00\x01\x02\x03" + blob = SerializedBlob(data) + assert blob.value is data + + key = "abc123def456" + blob = SerializedBlob(key) + assert blob.value is key + + +def test_blob_rejects_invalid_types(): + """SerializedBlob must be str or bytes — reject everything else.""" + for bad_value in [123, 3.14, None, [], {}]: + with pytest.raises(TypeError, match="must be str or bytes"): + SerializedBlob(bad_value) + + +# --------------------------------------------------------------------------- +# Wire vs storage format dispatch +# --------------------------------------------------------------------------- + + +class _DualFormatSerializer(ArtifactSerializer): + """Toy serializer that implements both formats for str objects.""" + + TYPE = "test_dual_format" + PRIORITY = 40 + + 
@classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_dual_format" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return obj + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob)], + SerializationMetadata("str", len(blob), "test_dual_format", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + if format == SerializationFormat.WIRE: + return data + return data[0].decode("utf-8") + + +def test_format_enum_values(): + assert SerializationFormat.STORAGE.value == "storage" + assert SerializationFormat.WIRE.value == "wire" + # str-backed Enum, so direct string comparison still works. + assert SerializationFormat.STORAGE == "storage" + assert SerializationFormat.WIRE == "wire" + + +def test_dual_format_storage_roundtrip(): + blobs, meta = _DualFormatSerializer.serialize("hello") + assert meta.encoding == "test_dual_format" + assert ( + _DualFormatSerializer.deserialize([b.value for b in blobs], metadata=meta) + == "hello" + ) + + +def test_dual_format_wire_roundtrip(): + wire = _DualFormatSerializer.serialize("hello", format=SerializationFormat.WIRE) + assert isinstance(wire, str) + assert ( + _DualFormatSerializer.deserialize(wire, format=SerializationFormat.WIRE) + == "hello" + ) + + +def test_pickle_serializer_rejects_wire(): + from metaflow.plugins.datastores.serializers.pickle_serializer import ( + PickleSerializer, + ) + + with pytest.raises(NotImplementedError): + PickleSerializer.serialize(42, format=SerializationFormat.WIRE) + with pytest.raises(NotImplementedError): + PickleSerializer.deserialize("42", format=SerializationFormat.WIRE) + + +def test_priority_tie_lexicographic_fallback(): + """When PRIORITY and registration index both tie (simulated), class_path lex-sort wins.""" + + # Within a single process, registration indices are always unique, so + # to actually exercise the tertiary key we construct two classes with + # identical (PRIORITY, registration_index) by manipulating the combined + # list passed to the sort logic. We do this by calling the internal + # sort key directly. + class _AClass: + __module__ = "z.module" + __qualname__ = "AClass" + PRIORITY = 100 + + class _BClass: + __module__ = "a.module" + __qualname__ = "BClass" + PRIORITY = 100 + + # Same registration index (simulated): the (priority, -idx) prefix ties. 
+ keys = [ + (_AClass.PRIORITY, 0, "%s.%s" % (_AClass.__module__, _AClass.__qualname__)), + (_BClass.PRIORITY, 0, "%s.%s" % (_BClass.__module__, _BClass.__qualname__)), + ] + sorted_keys = sorted(keys) + # "a.module.BClass" < "z.module.AClass" lexicographically + assert sorted_keys[0][2] == "a.module.BClass" + assert sorted_keys[1][2] == "z.module.AClass" + + +def test_setup_imports_default_is_noop(): + """Default setup_imports should be callable and do nothing.""" + + class _NoOverride(ArtifactSerializer): + TYPE = "test_no_override" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + result = _NoOverride.setup_imports() + assert result is None + result = _NoOverride.setup_imports(context="anything") + assert result is None + finally: + SerializerStore._all_serializers.pop("test_no_override", None) + SerializerStore._ordered_cache = None + + +def test_lazy_import_happy_path(): + """lazy_import imports the module, stashes on cls at the leaf alias, and returns it.""" + + class _LazyOk(ArtifactSerializer): + TYPE = "test_lazy_ok" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + mod = _LazyOk.lazy_import("json") + import json as _json + + assert mod is _json + assert _LazyOk.json is _json + finally: + SerializerStore._all_serializers.pop("test_lazy_ok", None) + SerializerStore._ordered_cache = None + if hasattr(_LazyOk, "json"): + delattr(_LazyOk, "json") + + +def test_lazy_import_custom_alias(): + """alias= overrides the default leaf-name stash key.""" + + class _LazyAlias(ArtifactSerializer): + TYPE = "test_lazy_alias" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + _LazyAlias.lazy_import("json", alias="j") + import json as _json + + assert _LazyAlias.j is _json + finally: + SerializerStore._all_serializers.pop("test_lazy_alias", None) + SerializerStore._ordered_cache = None + if hasattr(_LazyAlias, "j"): + delattr(_LazyAlias, "j") + + +def test_lazy_import_rejects_reserved_names(): + """Attempting to shadow TYPE / PRIORITY / dispatch methods raises.""" + + class _LazyReserved(ArtifactSerializer): + TYPE = "test_lazy_reserved" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + for bad in [ + "TYPE", + "PRIORITY", + "serialize", + "deserialize", + "can_serialize", + "can_deserialize", + "setup_imports", + 
"lazy_import", + "_secret", + ]: + with pytest.raises(ValueError, match="reserved or invalid"): + _LazyReserved.lazy_import("json", alias=bad) + finally: + SerializerStore._all_serializers.pop("test_lazy_reserved", None) + SerializerStore._ordered_cache = None + + +def test_lazy_import_rejects_double_assignment(): + """Calling lazy_import twice with the same alias on the same cls raises.""" + + class _LazyDup(ArtifactSerializer): + TYPE = "test_lazy_dup" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + _LazyDup.lazy_import("json") + with pytest.raises(ValueError, match="already set"): + _LazyDup.lazy_import("sys", alias="json") + finally: + SerializerStore._all_serializers.pop("test_lazy_dup", None) + SerializerStore._ordered_cache = None + if hasattr(_LazyDup, "json"): + delattr(_LazyDup, "json") + if hasattr(_LazyDup, "_lazy_imported_names"): + delattr(_LazyDup, "_lazy_imported_names") + + +def test_setup_imports_accepts_both_signatures(): + """Bootstrap calls setup_imports correctly whether author writes (cls) or (cls, context=None).""" + + class _OneArg(ArtifactSerializer): + TYPE = "test_setup_one_arg" + called = False + + @classmethod + def setup_imports(cls): + cls.called = True + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + class _TwoArg(ArtifactSerializer): + TYPE = "test_setup_two_arg" + called_with = "sentinel" + + @classmethod + def setup_imports(cls, context=None): + cls.called_with = context + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + from metaflow.datastore.artifacts.serializer import _call_setup_imports + + try: + _call_setup_imports(_OneArg, context=None) + assert _OneArg.called is True + + _call_setup_imports(_TwoArg, context="some-ctx") + assert _TwoArg.called_with == "some-ctx" + finally: + SerializerStore._all_serializers.pop("test_setup_one_arg", None) + SerializerStore._all_serializers.pop("test_setup_two_arg", None) + SerializerStore._ordered_cache = None diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py new file mode 100644 index 00000000000..d08206a88dd --- /dev/null +++ b/test/unit/test_pickle_serializer.py @@ -0,0 +1,197 @@ +import pickle + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + SerializationMetadata, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Registration and identity +# --------------------------------------------------------------------------- + + +def test_type_is_pickle(): + 
assert PickleSerializer.TYPE == "pickle" + + +def test_priority_is_fallback(): + assert PickleSerializer.PRIORITY == 9999 + + +def test_registered_in_store(): + assert "pickle" in SerializerStore._all_serializers + assert SerializerStore._all_serializers["pickle"] is PickleSerializer + + +def test_last_in_ordering(): + """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" + # Dispatch is driven by _active_serializers (post-Phase-6). Ensure Pickle + # is active for this test regardless of whether bootstrap() has already + # run in the current process. + was_active = PickleSerializer in SerializerStore._active_serializers + SerializerStore._active_serializers.add(PickleSerializer) + SerializerStore._ordered_cache = None + try: + ordered = SerializerStore.get_ordered_serializers() + assert ordered[-1] is PickleSerializer + finally: + if not was_active: + SerializerStore._active_serializers.discard(PickleSerializer) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# can_serialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello", + 3.14, + None, + True, + [1, 2, 3], + {"key": "value"}, + (1, "a"), + set([1, 2]), + b"bytes", + object(), + ], + ids=[ + "int", + "str", + "float", + "None", + "bool", + "list", + "dict", + "tuple", + "set", + "bytes", + "object", + ], +) +def test_can_serialize_any_object(obj): + assert PickleSerializer.can_serialize(obj) is True + + +# --------------------------------------------------------------------------- +# can_deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "encoding", + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"], +) +def test_can_deserialize_valid_encodings(encoding): + meta = SerializationMetadata("object", 100, encoding, {}) + assert PickleSerializer.can_deserialize(meta) is True + + +@pytest.mark.parametrize( + "encoding", + ["json", "iotype:text", "msgpack", "unknown", ""], +) +def test_cannot_deserialize_unknown_encodings(encoding): + meta = SerializationMetadata("object", 100, encoding, {}) + assert PickleSerializer.can_deserialize(meta) is False + + +# --------------------------------------------------------------------------- +# serialize +# --------------------------------------------------------------------------- + + +def test_serialize_returns_single_blob(): + blobs, meta = PickleSerializer.serialize({"key": "value"}) + assert len(blobs) == 1 + assert blobs[0].needs_save is True + assert blobs[0].is_reference is False + + +def test_serialize_metadata_encoding(): + _, meta = PickleSerializer.serialize(42) + assert meta.encoding == "pickle-v4" + + +def test_serialize_metadata_type(): + _, meta = PickleSerializer.serialize([1, 2, 3]) + assert "list" in meta.obj_type + + +def test_serialize_metadata_size(): + obj = {"a": 1, "b": 2} + blobs, meta = PickleSerializer.serialize(obj) + assert meta.size == len(blobs[0].value) + assert meta.size > 0 + + +def test_serialize_metadata_serializer_info_empty(): + _, meta = PickleSerializer.serialize("hello") + assert meta.serializer_info == {} + + +# --------------------------------------------------------------------------- +# Round-trip: serialize -> deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello world", + 3.14, + None, 
+ True, + False, + [1, "two", 3.0], + {"nested": {"key": [1, 2, 3]}}, + (1, 2, 3), + set([1, 2, 3]), + b"raw bytes", + ], + ids=[ + "int", + "str", + "float", + "None", + "True", + "False", + "list", + "nested_dict", + "tuple", + "set", + "bytes", + ], +) +def test_round_trip(obj): + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta) + assert result == obj + + +class _CustomObj: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return isinstance(other, _CustomObj) and self.x == other.x + + +def test_round_trip_custom_class(): + obj = _CustomObj(42) + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta) + assert result == obj + assert result.x == 42 diff --git a/test/unit/test_serializer_integration.py b/test/unit/test_serializer_integration.py new file mode 100644 index 00000000000..97ab8c42c58 --- /dev/null +++ b/test/unit/test_serializer_integration.py @@ -0,0 +1,584 @@ +""" +Integration tests for the pluggable serializer framework wired into TaskDataStore. + +Tests that: +- PickleSerializer handles standard Python objects through save/load_artifacts +- Custom serializers take priority over PickleSerializer +- Backward compat: old artifacts (without serializer_info) still load +- Metadata includes serializer_info when present +""" + +import json +import os +import shutil +import tempfile + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Test PickleSerializer round-trip through save/load artifacts +# --------------------------------------------------------------------------- + + +@pytest.fixture +def task_datastore(tmp_path): + """Create a minimal TaskDataStore wired to a local storage backend.""" + from metaflow.datastore.flow_datastore import FlowDataStore + from metaflow.plugins.datastores.local_storage import LocalStorage + + storage_root = str(tmp_path / "datastore") + os.makedirs(storage_root, exist_ok=True) + + flow_ds = FlowDataStore( + flow_name="TestFlow", + environment=None, + metadata=None, + event_logger=None, + monitor=None, + storage_impl=LocalStorage, + ds_root=storage_root, + ) + + task_ds = flow_ds.get_task_datastore( + run_id="1", + step_name="start", + task_id="1", + attempt=0, + mode="w", + ) + task_ds.init_task() + # Isolate from test serializers registered by other test files. + # Only use PickleSerializer (as the plugin system would provide). 
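+    # NOTE: assigning to ``_serializers`` pins this instance's dispatch list
+    # for the duration of a test; setting it back to ``None`` (exercised by
+    # the dynamic-registry test below) restores the fallback to the live
+    # SerializerStore registry.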
+ task_ds._serializers = [PickleSerializer] + return task_ds + + +def test_save_load_pickle_round_trip(task_datastore): + """Standard Python objects go through PickleSerializer and round-trip.""" + artifacts = [ + ("my_dict", {"key": "value", "nested": [1, 2, 3]}), + ("my_int", 42), + ("my_str", "hello world"), + ("my_none", None), + ] + task_datastore.save_artifacts(iter(artifacts)) + + # Verify metadata + for name, _ in artifacts: + info = task_datastore._info[name] + assert "encoding" in info + assert info["encoding"] == "pickle-v4" + assert info["size"] > 0 + assert "type" in info + + # Load and verify + loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) + assert loaded["my_dict"] == {"key": "value", "nested": [1, 2, 3]} + assert loaded["my_int"] == 42 + assert loaded["my_str"] == "hello world" + assert loaded["my_none"] is None + + +def test_distinct_objects_on_load(task_datastore): + """Loading the same artifact twice yields distinct object instances.""" + shared_list = [1, 2, 3] + task_datastore.save_artifacts(iter([("a", shared_list), ("b", shared_list)])) + + loaded = dict(task_datastore.load_artifacts(["a", "b"])) + assert loaded["a"] == loaded["b"] + assert loaded["a"] is not loaded["b"] # distinct instances + + +def test_metadata_auto_populates_source_for_pickle(task_datastore): + """PickleSerializer returns empty serializer_info, but save_artifacts + auto-injects ``source`` from the bootstrap-time record so load errors + can tell the user which package provides the missing serializer.""" + task_datastore.save_artifacts(iter([("x", 42)])) + info = task_datastore._info["x"] + assert info.get("serializer_info", {}).get("source") == "metaflow" + + +def test_author_source_is_not_overridden(task_datastore): + """A serializer that sets its own ``source`` in serializer_info should + not have it overridden by the auto-injected bootstrap source.""" + from metaflow.datastore.artifacts import SerializationFormat, SerializerStore + from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, + ) + + class _ExplicitSourceSerializer(ArtifactSerializer): + TYPE = "test_explicit_source" + PRIORITY = 1 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_explicit_source" + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + blob = obj.encode("utf-8") + return ( + [SerializedBlob(blob, is_reference=False)], + SerializationMetadata( + obj_type="str", + size=len(blob), + encoding="test_explicit_source", + serializer_info={"source": "i-picked-this-myself"}, + ), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + return data[0].decode("utf-8") + + # Seed a record so get_source_for would try to inject "some-extension" + # — the author's explicit source should still win. 
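+    # Illustrative shapes implied by this test's assertions: the injection
+    # path would have produced ``{"source": "some-extension"}`` from the
+    # record seeded below, while the author's explicit
+    # ``{"source": "i-picked-this-myself"}`` must survive unchanged.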
+    rec = SerializerRecord(
+        name="test_explicit_source",
+        class_path="inline.ExplicitSourceSerializer",
+        state=SerializerState.ACTIVE,
+        type="test_explicit_source",
+        source="some-extension",
+    )
+    SerializerStore._records["test_explicit_source"] = rec
+    SerializerStore._active_serializers.add(_ExplicitSourceSerializer)
+
+    task_datastore._serializers = [_ExplicitSourceSerializer, PickleSerializer]
+
+    try:
+        task_datastore.save_artifacts(iter([("hello", "world")]))
+        info = task_datastore._info["hello"]
+        assert info["serializer_info"]["source"] == "i-picked-this-myself"
+    finally:
+        SerializerStore._records.pop("test_explicit_source", None)
+        SerializerStore._active_serializers.discard(_ExplicitSourceSerializer)
+        SerializerStore._all_serializers.pop("test_explicit_source", None)
+        SerializerStore._ordered_cache = None
+
+
+# ---------------------------------------------------------------------------
+# Test custom serializer takes priority
+# ---------------------------------------------------------------------------
+
+
+def test_custom_serializer_takes_priority(task_datastore):
+    """A custom serializer with lower PRIORITY claims matching objects over pickle."""
+
+    # Define and register a custom serializer inside the test
+    class _JsonStringSerializer(ArtifactSerializer):
+        TYPE = "test_json_str"
+        PRIORITY = 50
+
+        @classmethod
+        def can_serialize(cls, obj):
+            return isinstance(obj, str)
+
+        @classmethod
+        def can_deserialize(cls, metadata):
+            return metadata.encoding == "test_json_str"
+
+        @classmethod
+        def serialize(cls, obj, format="storage"):
+            blob = json.dumps(obj).encode("utf-8")
+            return (
+                [SerializedBlob(blob, is_reference=False)],
+                SerializationMetadata(
+                    obj_type="str",
+                    size=len(blob),
+                    encoding="test_json_str",
+                    serializer_info={"format": "json-utf8"},
+                ),
+            )
+
+        @classmethod
+        def deserialize(cls, data, metadata=None, format="storage"):
+            return json.loads(data[0].decode("utf-8"))
+
+    # Explicitly set serializers: custom first, then pickle fallback.
+    # Don't use get_ordered_serializers() to avoid pollution from other test files.
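+    # Lower PRIORITY wins: 50 (_JsonStringSerializer) is consulted before
+    # 9999 (PickleSerializer), matching the ordering that
+    # get_ordered_serializers() would produce for these two classes.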
+ task_datastore._serializers = [_JsonStringSerializer, PickleSerializer] + + try: + task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) + + # "msg" should use our custom serializer (str → _JsonStringSerializer) + msg_info = task_datastore._info["msg"] + assert msg_info["encoding"] == "test_json_str" + assert msg_info["serializer_info"] == {"format": "json-utf8"} + + # "num" should fall through to PickleSerializer (int → not claimed by custom) + num_info = task_datastore._info["num"] + assert num_info["encoding"] == "pickle-v4" + + # Both round-trip correctly + loaded = dict(task_datastore.load_artifacts(["msg", "num"])) + assert loaded["msg"] == "hello" + assert loaded["num"] == 42 + finally: + SerializerStore._all_serializers.pop("test_json_str", None) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# Backward compat: old metadata format +# --------------------------------------------------------------------------- + + +def test_backward_compat_old_metadata(task_datastore): + """Artifacts saved with old metadata format (no serializer_info) still load.""" + # Save normally first + task_datastore.save_artifacts(iter([("old_artifact", {"a": 1})])) + + # Simulate old metadata format: no serializer_info, old encoding + task_datastore._info["old_artifact"] = { + "size": 100, + "type": "", + "encoding": "gzip+pickle-v4", + # no "serializer_info" key + } + + # Should still load via PickleSerializer (can_deserialize handles gzip+pickle-v4) + loaded = dict(task_datastore.load_artifacts(["old_artifact"])) + assert loaded["old_artifact"] == {"a": 1} + + +def test_backward_compat_no_encoding(task_datastore): + """Very old artifacts without encoding field default to gzip+pickle-v2.""" + # Save an artifact + task_datastore.save_artifacts(iter([("ancient", 99)])) + + # Simulate very old metadata: no encoding, no serializer_info + task_datastore._info["ancient"] = { + "size": 10, + "type": "", + # no "encoding" key — defaults to gzip+pickle-v2 + } + + # Should still load + loaded = dict(task_datastore.load_artifacts(["ancient"])) + assert loaded["ancient"] == 99 + + +# --------------------------------------------------------------------------- +# Dynamic registry: lazy registrations reach long-lived datastores +# --------------------------------------------------------------------------- + + +def test_post_init_registration_reaches_existing_datastore(task_datastore): + """A serializer registered AFTER the datastore was constructed must still + be visible. Without the dynamic ``_serializers`` property, lazy imports + (e.g. ``import torch`` after ``TaskDataStore.__init__``) would be silently + ignored for that instance. + """ + # Drop the test override so the property falls back to the live registry. + task_datastore._serializers = None + + class _PostInitSerializer(ArtifactSerializer): + TYPE = "test_post_init_registration" + PRIORITY = 5 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format="storage"): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format="storage"): + raise NotImplementedError + + # Dispatch reads from _active_serializers now (post-Phase-6). 
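+    # Invariant used throughout these tests: any direct mutation of
+    # _active_serializers is paired with ``_ordered_cache = None`` so that
+    # get_ordered_serializers() recomputes the priority ordering.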
+ SerializerStore._active_serializers.add(_PostInitSerializer) + SerializerStore._ordered_cache = None + try: + assert _PostInitSerializer in task_datastore._serializers + finally: + SerializerStore._all_serializers.pop("test_post_init_registration", None) + SerializerStore._active_serializers.discard(_PostInitSerializer) + SerializerStore._ordered_cache = None + + +# --------------------------------------------------------------------------- +# Blob-count validation must happen before ``_info`` is mutated +# --------------------------------------------------------------------------- + + +def test_info_not_populated_when_serializer_returns_no_blobs(task_datastore): + """ + Regression for the "_info[name] poisoned on validation failure" bug: if a + serializer returns an empty blob list, ``save_artifacts`` must raise + without leaving partial metadata in ``_info``. + """ + from metaflow.datastore.exceptions import DataException + + class _EmptyBlobSerializer(ArtifactSerializer): + TYPE = "test_empty_blob" + PRIORITY = 5 + + @classmethod + def can_serialize(cls, obj): + return True + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format="storage"): + return ( + [], + SerializationMetadata("x", 0, "test_empty_blob", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format="storage"): + raise NotImplementedError + + task_datastore._serializers = [_EmptyBlobSerializer, PickleSerializer] + try: + with pytest.raises(DataException, match="returned no blobs"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info + finally: + SerializerStore._all_serializers.pop("test_empty_blob", None) + SerializerStore._ordered_cache = None + + +def test_info_not_populated_when_serializer_returns_multi_blob(task_datastore): + """Same guarantee as above for the multi-blob rejection path.""" + from metaflow.datastore.exceptions import DataException + + class _MultiBlobSerializer(ArtifactSerializer): + TYPE = "test_multi_blob" + PRIORITY = 5 + + @classmethod + def can_serialize(cls, obj): + return True + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format="storage"): + return ( + [SerializedBlob(b"a"), SerializedBlob(b"b")], + SerializationMetadata("x", 2, "test_multi_blob", {}), + ) + + @classmethod + def deserialize(cls, data, metadata=None, format="storage"): + raise NotImplementedError + + task_datastore._serializers = [_MultiBlobSerializer, PickleSerializer] + try: + with pytest.raises(DataException, match="single-blob serializers"): + task_datastore.save_artifacts(iter([("bad", object())])) + assert "bad" not in task_datastore._info + finally: + SerializerStore._all_serializers.pop("test_multi_blob", None) + SerializerStore._ordered_cache = None + + +def test_can_serialize_exception_falls_through_to_pickle(task_datastore): + """A buggy custom serializer's can_serialize exception must NOT crash + save_artifacts. 
The buggy serializer is skipped; pickle fallback handles + the artifact; dispatch_error_count is incremented.""" + from metaflow.datastore.artifacts import SerializationFormat + from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, + ) + + class _BuggyCanSerialize(ArtifactSerializer): + TYPE = "test_buggy_cs" + PRIORITY = 1 # tried first + + @classmethod + def can_serialize(cls, obj): + raise RuntimeError("intentional bug in can_serialize") + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + # Seed a diagnostic record so dispatch_error_count has somewhere to go. + rec = SerializerRecord( + name="test_buggy_cs", + class_path="test.inline.BuggyCanSerialize", + state=SerializerState.ACTIVE, + type="test_buggy_cs", + priority=1, + ) + SerializerStore._records["test_buggy_cs"] = rec + SerializerStore._active_serializers.add(_BuggyCanSerialize) + + task_datastore._serializers = [_BuggyCanSerialize, PickleSerializer] + + try: + # Must NOT raise. + task_datastore.save_artifacts(iter([("x", 42)])) + assert task_datastore._info["x"]["encoding"] == "pickle-v4" + assert rec.dispatch_error_count == 1 + assert rec.last_error is not None + assert "RuntimeError" in rec.last_error + finally: + SerializerStore._all_serializers.pop("test_buggy_cs", None) + SerializerStore._active_serializers.discard(_BuggyCanSerialize) + SerializerStore._records.pop("test_buggy_cs", None) + SerializerStore._ordered_cache = None + + +def test_can_deserialize_exception_falls_through(task_datastore): + """Same guarantee for can_deserialize during load_artifacts.""" + from metaflow.datastore.artifacts import SerializationFormat + from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, + ) + + class _BuggyCanDeserialize(ArtifactSerializer): + TYPE = "test_buggy_cd" + PRIORITY = 1 + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + raise RuntimeError("intentional bug in can_deserialize") + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + rec = SerializerRecord( + name="test_buggy_cd", + class_path="test.inline.BuggyCanDeserialize", + state=SerializerState.ACTIVE, + type="test_buggy_cd", + priority=1, + ) + SerializerStore._records["test_buggy_cd"] = rec + SerializerStore._active_serializers.add(_BuggyCanDeserialize) + + # First save an artifact normally via pickle so load has something to load. + task_datastore._serializers = [PickleSerializer] + task_datastore.save_artifacts(iter([("y", "hello")])) + + # Now install the buggy serializer and try to load — buggy can_deserialize + # should be skipped and pickle should take over. 
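+    # Mirrors the can_serialize case above: the asserts below expect per-call
+    # error accounting on the diagnostic record (dispatch_error_count
+    # increments, last_error captures the exception type) while the load
+    # itself still succeeds through the pickle fallback.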
+ task_datastore._serializers = [_BuggyCanDeserialize, PickleSerializer] + + try: + loaded = dict(task_datastore.load_artifacts(["y"])) + assert loaded["y"] == "hello" + assert rec.dispatch_error_count == 1 + assert rec.last_error is not None + assert "RuntimeError" in rec.last_error + finally: + SerializerStore._all_serializers.pop("test_buggy_cd", None) + SerializerStore._active_serializers.discard(_BuggyCanDeserialize) + SerializerStore._records.pop("test_buggy_cd", None) + SerializerStore._ordered_cache = None + + +def test_subclass_lazy_import_stashes_on_child_not_parent(): + """lazy_import on a subclass should set attrs on the subclass, not the parent. + Parent and children should each have their own _lazy_imported_names set.""" + from metaflow.datastore.artifacts import ( + ArtifactSerializer, + SerializationFormat, + SerializerStore, + ) + + class _ParentSer(ArtifactSerializer): + TYPE = "test_inherit_parent" + + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("json") + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + class _ChildSer(_ParentSer): + TYPE = "test_inherit_child" + + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("sys") + + try: + _ParentSer.setup_imports() + _ChildSer.setup_imports() + import json as _json + import sys as _sys + + # Parent has json; child has sys + assert _ParentSer.json is _json + assert _ChildSer.sys is _sys + + # Each class should have its OWN _lazy_imported_names set + # (not a shared inherited one) + parent_names = _ParentSer.__dict__.get("_lazy_imported_names", set()) + child_names = _ChildSer.__dict__.get("_lazy_imported_names", set()) + assert parent_names == {"json"} + assert child_names == {"sys"} + finally: + for t in ("test_inherit_parent", "test_inherit_child"): + SerializerStore._all_serializers.pop(t, None) + SerializerStore._ordered_cache = None + for c, attr in ((_ParentSer, "json"), (_ChildSer, "sys")): + if attr in c.__dict__: + delattr(c, attr) + for c in (_ParentSer, _ChildSer): + if "_lazy_imported_names" in c.__dict__: + delattr(c, "_lazy_imported_names") diff --git a/test/unit/test_serializer_lifecycle.py b/test/unit/test_serializer_lifecycle.py new file mode 100644 index 00000000000..d201f06587a --- /dev/null +++ b/test/unit/test_serializer_lifecycle.py @@ -0,0 +1,785 @@ +"""Tests for the serializer lifecycle (state machine, bootstrap, diagnostics).""" + +import sys +import types + +import pytest + +from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, +) + + +def test_serializer_record_default_fields(): + rec = SerializerRecord(name="pickle", class_path="m.pkl.Pickle") + assert rec.name == "pickle" + assert rec.class_path == "m.pkl.Pickle" + assert rec.state == SerializerState.KNOWN + assert rec.awaiting_modules == [] + assert rec.last_error is None + assert rec.priority is None + assert rec.type is None + assert rec.import_trigger is None + assert rec.dispatch_error_count == 0 + + +def test_serializer_record_as_dict(): + rec = SerializerRecord( + name="torch", + class_path="ext.t.TorchSerializer", + state=SerializerState.ACTIVE, + awaiting_modules=[], + last_error=None, + priority=50, + type="torch", + import_trigger="eager", 
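+        # "eager" marks records activated directly at bootstrap; records
+        # activated later by the import hook carry "hook" instead (see the
+        # retry tests in test_serializer_lifecycle.py).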
+ dispatch_error_count=0, + ) + d = rec.as_dict() + assert d["name"] == "torch" + assert d["class_path"] == "ext.t.TorchSerializer" + assert d["state"] == "active" + assert d["priority"] == 50 + assert d["type"] == "torch" + assert d["import_trigger"] == "eager" + + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializerStore, + SerializationFormat, +) + + +def test_store_separates_all_vs_active(): + """_all_serializers is the known-classes index; _active_serializers is dispatch pool.""" + + class _Known(ArtifactSerializer): + TYPE = "test_known_not_active" + + @classmethod + def can_serialize(cls, obj): + return False + + @classmethod + def can_deserialize(cls, metadata): + return False + + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError + + try: + # Metaclass registers on class body execution + assert _Known in SerializerStore._all_serializers.values() + # But without bootstrap, it is NOT in the dispatch pool + assert _Known not in SerializerStore._active_serializers + # _records is an empty dict initially (for entries from DESC tuples) + assert isinstance(SerializerStore._records, dict) + finally: + SerializerStore._all_serializers.pop("test_known_not_active", None) + SerializerStore._active_serializers.discard(_Known) + SerializerStore._ordered_cache = None + + +def test_bootstrap_activates_dependency_free_serializer(): + """bootstrap_entries with an in-process serializer moves it to ACTIVE.""" + + mod = types.ModuleType("_test_bootstrap_mod") + source = """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _BootProbe(ArtifactSerializer): + TYPE = "test_bootstrap_probe" + PRIORITY = 60 + + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""" + exec(source, mod.__dict__) + sys.modules["_test_bootstrap_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [ + ("test_bootstrap_probe", "_test_bootstrap_mod._BootProbe"), + ] + ) + rec = SerializerStore._records["test_bootstrap_probe"] + assert rec.state == SerializerState.ACTIVE + assert rec.priority == 60 + assert rec.type == "test_bootstrap_probe" + assert rec.import_trigger == "eager" + assert mod._BootProbe in SerializerStore._active_serializers + finally: + SerializerStore._all_serializers.pop("test_bootstrap_probe", None) + SerializerStore._active_serializers.discard(mod._BootProbe) + SerializerStore._records.pop("test_bootstrap_probe", None) + SerializerStore._ordered_cache = None + del sys.modules["_test_bootstrap_mod"] + + +def test_bootstrap_rejects_name_type_mismatch(): + """Tuple first element MUST equal class.TYPE.""" + + mod = types.ModuleType("_test_mismatch_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _Mismatch(ArtifactSerializer): + TYPE = "actual_type" + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def 
deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + mod.__dict__, + ) + sys.modules["_test_mismatch_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [ + ("declared_name", "_test_mismatch_mod._Mismatch"), + ] + ) + rec = SerializerStore._records["declared_name"] + assert rec.state == SerializerState.BROKEN + assert "tuple name" in rec.last_error + assert "actual_type" in rec.last_error + finally: + SerializerStore._all_serializers.pop("actual_type", None) + SerializerStore._records.pop("declared_name", None) + del sys.modules["_test_mismatch_mod"] + + +def test_bootstrap_missing_module_parks_entry(): + """ModuleNotFoundError during import_module moves entry to PENDING_ON_IMPORTS.""" + + SerializerStore.bootstrap_entries( + [ + ("test_absent", "_never_created_module._Absent"), + ] + ) + try: + rec = SerializerStore._records["test_absent"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "_never_created_module" in rec.awaiting_modules + # _pending_by_module should track it + assert "test_absent" in SerializerStore._pending_by_module.get( + "_never_created_module", [] + ) + finally: + SerializerStore._records.pop("test_absent", None) + SerializerStore._pending_by_module.pop("_never_created_module", None) + + +def test_bootstrap_missing_class_in_module_broken(): + """getattr failure after successful module import moves to BROKEN.""" + mod = types.ModuleType("_test_no_class_mod") + # Intentionally empty — no class inside + sys.modules["_test_no_class_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [ + ("test_no_class", "_test_no_class_mod._Missing"), + ] + ) + rec = SerializerStore._records["test_no_class"] + assert rec.state == SerializerState.BROKEN + assert "class" in rec.last_error.lower() + assert "_Missing" in rec.last_error + finally: + SerializerStore._records.pop("test_no_class", None) + del sys.modules["_test_no_class_mod"] + + +def test_bootstrap_setup_imports_missing_dep_parks_entry(): + """ImportError inside setup_imports parks on the missing module name.""" + mod = types.ModuleType("_test_setup_missing_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _WantsMissing(ArtifactSerializer): + TYPE = "test_setup_wants_missing" + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("absent_at_setup_time_xyz") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + mod.__dict__, + ) + sys.modules["_test_setup_missing_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [ + ("test_setup_wants_missing", "_test_setup_missing_mod._WantsMissing"), + ] + ) + rec = SerializerStore._records["test_setup_wants_missing"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "absent_at_setup_time_xyz" in rec.awaiting_modules + finally: + SerializerStore._all_serializers.pop("test_setup_wants_missing", None) + SerializerStore._records.pop("test_setup_wants_missing", None) + SerializerStore._pending_by_module.pop("absent_at_setup_time_xyz", None) + del sys.modules["_test_setup_missing_mod"] + + +def test_bootstrap_setup_imports_other_exception_broken(): + """Non-ImportError from setup_imports moves entry to 
BROKEN.""" + mod = types.ModuleType("_test_setup_boom_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _Boom(ArtifactSerializer): + TYPE = "test_boom" + @classmethod + def setup_imports(cls, context=None): + raise RuntimeError("explicit boom from test") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + mod.__dict__, + ) + sys.modules["_test_setup_boom_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [ + ("test_boom", "_test_setup_boom_mod._Boom"), + ] + ) + rec = SerializerStore._records["test_boom"] + assert rec.state == SerializerState.BROKEN + assert "RuntimeError" in rec.last_error + assert "explicit boom" in rec.last_error + finally: + SerializerStore._all_serializers.pop("test_boom", None) + SerializerStore._records.pop("test_boom", None) + del sys.modules["_test_setup_boom_mod"] + + +def test_bootstrap_disabled_toggle(): + """Entries whose name is in disabled_names land in DISABLED state.""" + mod = types.ModuleType("_test_disable_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _DisableMe(ArtifactSerializer): + TYPE = "test_disable" + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + mod.__dict__, + ) + sys.modules["_test_disable_mod"] = mod + try: + SerializerStore.bootstrap_entries( + [("test_disable", "_test_disable_mod._DisableMe")], + disabled_names={"test_disable"}, + ) + rec = SerializerStore._records["test_disable"] + assert rec.state == SerializerState.DISABLED + # Class should NOT be in active pool + assert mod._DisableMe not in SerializerStore._active_serializers + finally: + SerializerStore._all_serializers.pop("test_disable", None) + SerializerStore._active_serializers.discard(mod._DisableMe) + SerializerStore._records.pop("test_disable", None) + del sys.modules["_test_disable_mod"] + + +def test_retry_activates_pending_record_on_module_import(): + """When a pending record's awaited module imports, the record retries to ACTIVE.""" + ser_mod = types.ModuleType("_test_retry_ser_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _Pending(ArtifactSerializer): + TYPE = "test_retry_pending" + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("retry_dep_mod_name") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + ser_mod.__dict__, + ) + sys.modules["_test_retry_ser_mod"] = ser_mod + sys.modules.pop("retry_dep_mod_name", None) + + try: + SerializerStore.bootstrap_entries( + [ + ("test_retry_pending", "_test_retry_ser_mod._Pending"), + ] + ) + rec = 
SerializerStore._records["test_retry_pending"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + assert "retry_dep_mod_name" in rec.awaiting_modules + assert ( + "test_retry_pending" + in SerializerStore._pending_by_module["retry_dep_mod_name"] + ) + + # Simulate the dep becoming available, then fire the retry hook. + dep_mod = types.ModuleType("retry_dep_mod_name") + sys.modules["retry_dep_mod_name"] = dep_mod + SerializerStore._on_module_imported("retry_dep_mod_name", dep_mod) + + assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._Pending in SerializerStore._active_serializers + # _pending_by_module should no longer list this record under that module + assert "test_retry_pending" not in SerializerStore._pending_by_module.get( + "retry_dep_mod_name", [] + ) + finally: + SerializerStore._all_serializers.pop("test_retry_pending", None) + SerializerStore._active_serializers.discard(ser_mod._Pending) + SerializerStore._records.pop("test_retry_pending", None) + SerializerStore._pending_by_module.pop("retry_dep_mod_name", None) + SerializerStore._ordered_cache = None + sys.modules.pop("_test_retry_ser_mod", None) + sys.modules.pop("retry_dep_mod_name", None) + + +def test_retry_hits_loop_guard_after_repeated_failure(): + """Calling _on_module_imported when setup_imports still fails on the same + module name should transition to BROKEN via the loop guard.""" + ser_mod = types.ModuleType("_test_loop_ser_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _Loopy(ArtifactSerializer): + TYPE = "test_loopy" + @classmethod + def setup_imports(cls, context=None): + # Always raises on the same module name even if retried. + cls.lazy_import("never_resolves_mod") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + ser_mod.__dict__, + ) + sys.modules["_test_loop_ser_mod"] = ser_mod + sys.modules.pop("never_resolves_mod", None) + + try: + SerializerStore.bootstrap_entries( + [ + ("test_loopy", "_test_loop_ser_mod._Loopy"), + ] + ) + rec = SerializerStore._records["test_loopy"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Fake the dep appearing but NOT actually installing it — the retry + # will re-run setup_imports, which will ImportError again on the same name. + dep_mod = types.ModuleType("never_resolves_mod") + # We deliberately DO NOT put dep_mod in sys.modules, so lazy_import + # inside setup_imports will raise ModuleNotFoundError again. 
+ SerializerStore._on_module_imported("never_resolves_mod", dep_mod) + + assert rec.state == SerializerState.BROKEN + assert "repeated" in rec.last_error.lower() + finally: + SerializerStore._all_serializers.pop("test_loopy", None) + SerializerStore._records.pop("test_loopy", None) + SerializerStore._pending_by_module.pop("never_resolves_mod", None) + sys.modules.pop("_test_loop_ser_mod", None) + sys.modules.pop("never_resolves_mod", None) + + +def test_retry_fires_via_real_import_hook(tmp_path, monkeypatch): + """End-to-end: park a serializer on a missing module, install the hook, + actually import the module, verify the serializer activates.""" + pkg_dir = tmp_path / "fixture_retry_dep" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("VALUE = 42\n") + # NOTE: we prepend syspath AFTER bootstrap_entries below so the first + # lazy_import() fails and the record parks on the missing module. + + ser_mod = types.ModuleType("_test_e2e_ser_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _E2E(ArtifactSerializer): + TYPE = "test_e2e_retry" + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("fixture_retry_dep") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + ser_mod.__dict__, + ) + sys.modules["_test_e2e_ser_mod"] = ser_mod + # Make sure the dep module isn't pre-imported + sys.modules.pop("fixture_retry_dep", None) + + try: + SerializerStore.bootstrap_entries( + [ + ("test_e2e_retry", "_test_e2e_ser_mod._E2E"), + ] + ) + rec = SerializerStore._records["test_e2e_retry"] + assert rec.state == SerializerState.PENDING_ON_IMPORTS + + # Now make the dep discoverable on sys.path and actually import it. + monkeypatch.syspath_prepend(str(tmp_path)) + import fixture_retry_dep # noqa: F401 + + # After real import, the hook chain should have fired. 
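+        # Expected chain (a sketch inferred from the cleanup below, not an
+        # API guarantee): bootstrap parked the record and added
+        # "fixture_retry_dep" to _interceptor._watched; the real import above
+        # triggers _on_module_imported, which re-runs setup_imports and
+        # stamps import_trigger="hook".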
+ assert rec.state == SerializerState.ACTIVE + assert rec.import_trigger == "hook" + assert ser_mod._E2E in SerializerStore._active_serializers + finally: + SerializerStore._all_serializers.pop("test_e2e_retry", None) + SerializerStore._active_serializers.discard(ser_mod._E2E) + SerializerStore._records.pop("test_e2e_retry", None) + SerializerStore._pending_by_module.pop("fixture_retry_dep", None) + SerializerStore._ordered_cache = None + sys.modules.pop("_test_e2e_ser_mod", None) + sys.modules.pop("fixture_retry_dep", None) + # Clean up interceptor state + from metaflow.datastore.artifacts.lazy_registry import _interceptor + + _interceptor._watched.discard("fixture_retry_dep") + _interceptor._processed.discard("fixture_retry_dep") + + +def test_bootstrap_with_no_extensions_still_runs_core(): + """bootstrap() reads core ARTIFACT_SERIALIZERS_DESC from metaflow.plugins + and activates PickleSerializer.""" + from metaflow.plugins.datastores.serializers.pickle_serializer import ( + PickleSerializer, + ) + + # Snapshot and clear state + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + SerializerStore._active_serializers.clear() + SerializerStore._records.clear() + + try: + SerializerStore.bootstrap() + # PickleSerializer should be in active pool (core entry) + pickle_active = any( + r.type == "pickle" and r.state == SerializerState.ACTIVE + for r in SerializerStore._records.values() + ) + assert pickle_active + finally: + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + + +def test_bootstrap_stamps_core_source_on_record(): + """Core serializers bootstrap with source='metaflow' on their records.""" + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + SerializerStore._active_serializers.clear() + SerializerStore._records.clear() + + try: + SerializerStore.bootstrap() + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.type == "pickle"), + None, + ) + assert pickle_rec is not None + assert pickle_rec.source == "metaflow" + finally: + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + + +def test_bootstrap_entries_accepts_source_override(): + """bootstrap_entries accepts ``source`` and attaches it to each record.""" + import sys as _sys + import types as _types + + saved_records = dict(SerializerStore._records) + saved_active = set(SerializerStore._active_serializers) + + ser_mod = _types.ModuleType("_test_source_ser_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _SourceProbe(ArtifactSerializer): + TYPE = "test_source_probe" + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + ser_mod.__dict__, + ) + _sys.modules["_test_source_ser_mod"] = ser_mod + + try: + SerializerStore.bootstrap_entries( + [("test_source_probe", "_test_source_ser_mod._SourceProbe")], + source="fake-extension", + ) + rec = 
SerializerStore._records["test_source_probe"] + assert rec.source == "fake-extension" + assert SerializerStore.get_source_for(ser_mod._SourceProbe) == "fake-extension" + finally: + SerializerStore._all_serializers.pop("test_source_probe", None) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._ordered_cache = None + _sys.modules.pop("_test_source_ser_mod", None) + + +def test_bootstrap_applies_disabled_toggle(monkeypatch): + """bootstrap() respects -name toggles in ENABLED_ARTIFACT_SERIALIZER config.""" + from metaflow import metaflow_config + + saved_active = set(SerializerStore._active_serializers) + saved_records = dict(SerializerStore._records) + SerializerStore._active_serializers.clear() + SerializerStore._records.clear() + + monkeypatch.setattr( + metaflow_config, + "ENABLED_ARTIFACT_SERIALIZER", + ["-pickle"], + raising=False, + ) + try: + SerializerStore.bootstrap() + pickle_rec = next( + (r for r in SerializerStore._records.values() if r.name == "pickle"), + None, + ) + assert pickle_rec is not None + assert pickle_rec.state == SerializerState.DISABLED + finally: + SerializerStore._active_serializers.clear() + SerializerStore._active_serializers.update(saved_active) + SerializerStore._records.clear() + SerializerStore._records.update(saved_records) + + +def test_list_serializer_status_returns_dicts(): + """list_serializer_status returns one dict per _records entry, with the + documented shape.""" + from metaflow.datastore.artifacts import list_serializer_status + from metaflow.datastore.artifacts.diagnostic import ( + SerializerRecord, + SerializerState, + ) + + # Seed a fake record + rec = SerializerRecord( + name="fake_test_serializer", + class_path="inline.FakeSer", + state=SerializerState.ACTIVE, + priority=42, + type="fake_test_serializer", + import_trigger="eager", + ) + SerializerStore._records["fake_test_serializer"] = rec + + try: + status = list_serializer_status() + assert isinstance(status, list) + match = next((s for s in status if s["name"] == "fake_test_serializer"), None) + assert match is not None + assert match["state"] == "active" + assert match["priority"] == 42 + assert match["type"] == "fake_test_serializer" + assert match["import_trigger"] == "eager" + assert match["class_path"] == "inline.FakeSer" + for key in ( + "name", + "class_path", + "state", + "awaiting_modules", + "last_error", + "priority", + "type", + "import_trigger", + "dispatch_error_count", + ): + assert key in match, "missing key '%s' in status dict" % key + finally: + SerializerStore._records.pop("fake_test_serializer", None) + + +def test_reset_for_tests_clears_registry_state(): + """SerializerStore._reset_for_tests clears _records, _active_serializers, + _pending_by_module, _ordered_cache, and per-class _lazy_imported_names.""" + import sys as _sys + import types as _types + from metaflow.datastore.artifacts.lazy_registry import _interceptor + + # Seed state by bootstrapping an inline serializer + ser_mod = _types.ModuleType("_test_reset_ser_mod") + exec( + """ +from metaflow.datastore.artifacts import ArtifactSerializer, SerializationFormat + +class _ResetProbe(ArtifactSerializer): + TYPE = "test_reset_probe" + @classmethod + def setup_imports(cls, context=None): + cls.lazy_import("json") + @classmethod + def can_serialize(cls, obj): return False + @classmethod + def can_deserialize(cls, metadata): return False + @classmethod + def 
serialize(cls, obj, format=SerializationFormat.STORAGE): + raise NotImplementedError + @classmethod + def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE): + raise NotImplementedError +""", + ser_mod.__dict__, + ) + _sys.modules["_test_reset_ser_mod"] = ser_mod + + try: + SerializerStore.bootstrap_entries( + [ + ("test_reset_probe", "_test_reset_ser_mod._ResetProbe"), + ] + ) + # Also seed a pending record to exercise _pending_by_module + SerializerStore.bootstrap_entries( + [ + ("test_reset_pending", "_never_exists_mod._Absent"), + ] + ) + + # Snapshot pre-reset state + assert "test_reset_probe" in SerializerStore._records + assert ser_mod._ResetProbe in SerializerStore._active_serializers + assert "_never_exists_mod" in SerializerStore._pending_by_module + # ResetProbe should have _lazy_imported_names populated + assert "json" in ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) + + # Pre-reset, the interceptor should be watching _never_exists_mod. + assert "_never_exists_mod" in _interceptor._watched + + # Call reset + SerializerStore._reset_for_tests() + + # Post-reset: all registry state empty + assert SerializerStore._records == {} + assert len(SerializerStore._active_serializers) == 0 + assert SerializerStore._pending_by_module == {} + assert SerializerStore._ordered_cache is None + + # The probe class should no longer have stashed attrs + assert "json" not in ser_mod._ResetProbe.__dict__ + assert ser_mod._ResetProbe.__dict__.get("_lazy_imported_names", set()) == set() + + # Interceptor watches should also be cleared + assert "_never_exists_mod" not in _interceptor._watched + finally: + # In case reset didn't clean up (e.g., test failed mid-way) + SerializerStore._all_serializers.pop("test_reset_probe", None) + SerializerStore._records.pop("test_reset_probe", None) + SerializerStore._records.pop("test_reset_pending", None) + SerializerStore._active_serializers.discard(ser_mod._ResetProbe) + SerializerStore._pending_by_module.clear() + SerializerStore._ordered_cache = None + _sys.modules.pop("_test_reset_ser_mod", None) + for attr in ("json",): + if attr in ser_mod._ResetProbe.__dict__: + delattr(ser_mod._ResetProbe, attr) + # Re-bootstrap so subsequent tests see the normal active pool + # (e.g. PickleSerializer in _active_serializers + _records). 
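+        # bootstrap() is assumed to be safe to invoke repeatedly here: the
+        # lifecycle tests above call it after clearing _records and
+        # _active_serializers, and this call mirrors that clean-slate pattern.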
+ SerializerStore.bootstrap() diff --git a/test/unit/test_serializer_public_api.py b/test/unit/test_serializer_public_api.py new file mode 100644 index 00000000000..50a3e7e8634 --- /dev/null +++ b/test/unit/test_serializer_public_api.py @@ -0,0 +1,55 @@ +"""Smoke tests guarding the public surface of metaflow.datastore.artifacts.""" + + +def test_register_serializer_for_type_not_public(): + """Imperative per-type registration is not a public API.""" + import metaflow.datastore.artifacts as mda + + assert not hasattr(mda, "register_serializer_for_type") + + +def test_serializer_config_not_public(): + """SerializerConfig is not a public export.""" + import metaflow.datastore.artifacts as mda + + assert not hasattr(mda, "SerializerConfig") + + +def test_register_serializer_config_not_public(): + import metaflow.datastore.artifacts as mda + + assert not hasattr(mda, "register_serializer_config") + + +def test_iter_registered_configs_not_public(): + import metaflow.datastore.artifacts as mda + + assert not hasattr(mda, "iter_registered_configs") + + +def test_load_serializer_class_not_public(): + import metaflow.datastore.artifacts as mda + + assert not hasattr(mda, "load_serializer_class") + + +def test_plugins_has_no_artifact_serializers_global(): + """metaflow.plugins does not expose a resolved ARTIFACT_SERIALIZERS global. + Dispatch reads directly from SerializerStore.get_ordered_serializers().""" + import metaflow.plugins as mplugins + + assert not hasattr( + mplugins, "ARTIFACT_SERIALIZERS" + ), "Expected ARTIFACT_SERIALIZERS to be absent; still present" + + +def test_pickle_serializer_is_active_after_import(): + """After import metaflow, PickleSerializer should be in ACTIVE state.""" + from metaflow.datastore.artifacts import list_serializer_status + + status = list_serializer_status() + pickle_rec = next((r for r in status if r.get("type") == "pickle"), None) + assert pickle_rec is not None, "PickleSerializer record missing; status=%r" % status + assert pickle_rec["state"] == "active", ( + "Expected pickle active; got %r" % pickle_rec + )
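+
+
+# A minimal sketch of the serializer contract, kept here as executable
+# documentation. Everything below is illustrative: _SketchSerializer and its
+# encoding string are invented for this test; only the imports and the
+# registry-cleanup idiom are taken from the other test files in this change.
+def test_minimal_serializer_contract_sketch():
+    """Illustrative sketch, not a behavioral guarantee: the smallest
+    serializer that satisfies the public contract exercised elsewhere."""
+    from metaflow.datastore.artifacts import (
+        ArtifactSerializer,
+        SerializationFormat,
+        SerializationMetadata,
+        SerializedBlob,
+        SerializerStore,
+    )
+
+    class _SketchSerializer(ArtifactSerializer):
+        TYPE = "test_public_api_sketch"
+        PRIORITY = 100
+
+        @classmethod
+        def can_serialize(cls, obj):
+            return isinstance(obj, str)
+
+        @classmethod
+        def can_deserialize(cls, metadata):
+            return metadata.encoding == "test_public_api_sketch"
+
+        @classmethod
+        def serialize(cls, obj, format=SerializationFormat.STORAGE):
+            blob = obj.encode("utf-8")
+            return (
+                [SerializedBlob(blob, is_reference=False)],
+                SerializationMetadata(
+                    obj_type="str",
+                    size=len(blob),
+                    encoding="test_public_api_sketch",
+                    serializer_info={},
+                ),
+            )
+
+        @classmethod
+        def deserialize(cls, data, metadata=None, format=SerializationFormat.STORAGE):
+            return data[0].decode("utf-8")
+
+    try:
+        blobs, meta = _SketchSerializer.serialize("hello")
+        assert meta.encoding == "test_public_api_sketch"
+        assert _SketchSerializer.deserialize([b.value for b in blobs], meta) == "hello"
+    finally:
+        # The metaclass registered the class on definition; undo that so this
+        # sketch leaves no trace in the shared registry.
+        SerializerStore._all_serializers.pop("test_public_api_sketch", None)
+        SerializerStore._ordered_cache = None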