diff --git a/.gitignore b/.gitignore index c92c1689ddf..df051fb791a 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ stubs/version.py # claude code .claude/ +workflow.yaml diff --git a/metaflow/algospec.py b/metaflow/algospec.py new file mode 100644 index 00000000000..4b782f3d079 --- /dev/null +++ b/metaflow/algospec.py @@ -0,0 +1,103 @@ +""" +AlgoSpec -- a single-computation unit within Metaflow. + +AlgoSpec is a FlowSpec subclass with a single step: the call method, +marked as @step(start=True, end=True). FlowGraph._identify_start_end +picks it up via the is_start/is_end attributes — no special-casing +needed in graph traversal, lint, runtime, or Maestro. + +Lifecycle: + init() -- called once per worker (model loading) + call() -- called per row or batch (computation) +""" + +import atexit + +from .flowspec import FlowSpec, FlowSpecMeta + + +class AlgoSpecMeta(FlowSpecMeta): + """Metaclass for AlgoSpec. + + Marks call() as @step(start=True, end=True) before FlowSpecMeta + builds the graph. The step name is the class name lowercased. + """ + + _registry = [] + _atexit_registered = False + + def __init__(cls, name, bases, attrs): + if name == "AlgoSpec": + super().__init__(name, bases, attrs) + return + + from .decorators import step + + call_fn = attrs.get("call") + if call_fn is not None and callable(call_fn): + attrs["call"] = step(call_fn, start=True, end=True) + attrs["call"].name = name.lower() + attrs["call"].__name__ = name.lower() + cls.call = attrs["call"] + cls._algo_step_name = name.lower() + + super().__init__(name, bases, attrs) + + if not hasattr(cls.call, "is_step"): + from .exception import MetaflowException + + raise MetaflowException( + "%s must implement call(). " + "AlgoSpec subclasses require a call() method." % name + ) + + AlgoSpecMeta._registry.append(cls) + + if not AlgoSpecMeta._atexit_registered: + atexit.register(AlgoSpecMeta._on_exit) + AlgoSpecMeta._atexit_registered = True + + def _init_graph(cls): + from .graph import FlowGraph + + cls._graph = FlowGraph(cls) + # The method is cls.call but node.name is the class-derived name. + if cls._graph.is_algo_spec: + cls._steps = [cls.call] + else: + cls._steps = [getattr(cls, node.name) for node in cls._graph] + + @staticmethod + def _on_exit(): + AlgoSpecMeta._registry.clear() + + +class AlgoSpec(FlowSpec, metaclass=AlgoSpecMeta): + """Base class for single-computation algo specifications.""" + + is_algo_spec = True + + _EPHEMERAL = FlowSpec._EPHEMERAL | {"is_algo_spec"} + + _NON_PARAMETERS = FlowSpec._NON_PARAMETERS | { + "init", + "call", + "is_algo_spec", + } + + def init(self): + """Called once per worker before any call() invocations.""" + pass + + def call(self): + """Main computation. 
Must be overridden.""" + raise NotImplementedError("Subclasses must implement call()") + + def __call__(self, *args, **kwargs): + return self.call(*args, **kwargs) + + def __getattr__(self, name): + # Resolve the class-derived step name to the call method + if name == getattr(self.__class__, "_algo_step_name", None): + return self.call + return super().__getattr__(name) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index e6852b71901..d6c0aa9a1e9 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -2121,8 +2121,9 @@ def parent_steps(self) -> Iterator["Step"]: Parent step """ graph_info = self.task["_graph_info"].data + start_step = graph_info.get("start_step", "start") - if self.id != "start": + if self.id != start_step: flow, run, _ = self.path_components for node_name, attributes in graph_info["steps"].items(): if self.id in attributes["next"]: @@ -2139,8 +2140,9 @@ def child_steps(self) -> Iterator["Step"]: Child step """ graph_info = self.task["_graph_info"].data + end_step = graph_info.get("end_step", "end") - if self.id != "end": + if self.id != end_step: flow, run, _ = self.path_components for next_step in graph_info["steps"][self.id]["next"]: yield Step(f"{flow}/{run}/{next_step}", _namespace_check=False) @@ -2176,6 +2178,25 @@ def _iter_filter(self, x): # exclude _parameters step return x.id[0] != "_" + _cached_endpoints = None + + @property + def _graph_endpoints(self): + """ + Returns (start_step_name, end_step_name) from _parameters metadata. + Falls back to ("start", "end") for runs that predate this change. + """ + if self._cached_endpoints is None: + start, end = "start", "end" + try: + params_meta = self["_parameters"].task.metadata_dict + start = params_meta.get("start_step", "start") + end = params_meta.get("end_step", "end") + except Exception: + pass + self._cached_endpoints = (start, end) + return self._cached_endpoints + def steps(self, *tags: str) -> Iterator[Step]: """ [Legacy function - do not use] @@ -2298,17 +2319,18 @@ def finished_at(self) -> Optional[datetime]: @property def end_task(self) -> Optional[Task]: """ - Returns the Task corresponding to the 'end' step. + Returns the Task corresponding to the terminal step. - This returns None if the end step does not yet exist. + This returns None if the terminal step does not yet exist. 
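As a usage sketch (the run id is hypothetical; CustomNamedFlow is the test fixture added below, whose terminal step is named "finish"), the client now resolves endpoints by name instead of assuming "start"/"end":

from metaflow import Run

run = Run("CustomNamedFlow/1234")   # hypothetical run id
task = run.end_task                 # resolves the "finish" task, or None if it does not exist yet
trigger = run.trigger               # reads execution-trigger metadata from the "begin" step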
Returns ------- Task, optional - The 'end' task + The terminal step's task """ try: - end_step = self["end"] + _, end_step_name = self._graph_endpoints + end_step = self[end_step_name] except KeyError: return None @@ -2481,8 +2503,9 @@ def trigger(self) -> Optional[Trigger]: Trigger, optional Container of triggering events """ - if "start" in self and self["start"].task: - meta = self["start"].task.metadata_dict.get("execution-triggers") + start_step, _ = self._graph_endpoints + if start_step in self and self[start_step].task: + meta = self[start_step].task.metadata_dict.get("execution-triggers") if meta: return Trigger(json.loads(meta)) return None diff --git a/metaflow/datastore/artifacts/__init__.py b/metaflow/datastore/artifacts/__init__.py new file mode 100644 index 00000000000..6ec150a8d97 --- /dev/null +++ b/metaflow/datastore/artifacts/__init__.py @@ -0,0 +1,6 @@ +from .serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) diff --git a/metaflow/datastore/artifacts/serializer.py b/metaflow/datastore/artifacts/serializer.py new file mode 100644 index 00000000000..07b2418f032 --- /dev/null +++ b/metaflow/datastore/artifacts/serializer.py @@ -0,0 +1,173 @@ +from abc import ABCMeta, abstractmethod +from collections import namedtuple + + +SerializationMetadata = namedtuple( + "SerializationMetadata", ["type", "size", "encoding", "serializer_info"] +) + + +class SerializedBlob(object): + """ + Represents a single blob produced by a serializer. + + A serializer may produce multiple blobs per artifact. Each blob is either: + - New bytes to be stored (is_reference=False, value is bytes) + - A reference to already-stored data (is_reference=True, value is a string key) + + Parameters + ---------- + value : Union[str, bytes] + The blob data (bytes) or a reference key (str). + is_reference : bool, optional + If None, auto-detected from value type: str -> reference, bytes -> new data. + compress_method : str + Compression method for new blobs. Ignored for references. Default "gzip". + NOTE: Not yet wired into the save path — ContentAddressedStore currently + always applies gzip. This field is forward-looking for when per-blob + compression control is needed (e.g., multi-blob IOType support). + """ + + def __init__(self, value, is_reference=None, compress_method="gzip"): + if not isinstance(value, (str, bytes)): + raise TypeError( + "SerializedBlob value must be str or bytes, got %s" % type(value).__name__ + ) + self.value = value + self.compress_method = compress_method + if is_reference is None: + self.is_reference = isinstance(value, str) + else: + self.is_reference = is_reference + + @property + def needs_save(self): + """True if this blob contains new bytes that need to be stored.""" + return not self.is_reference + + +class SerializerStore(ABCMeta): + """ + Metaclass for ArtifactSerializer that auto-registers subclasses by TYPE. + + Provides deterministic ordering: serializers are sorted by (PRIORITY, registration_order). + Lower PRIORITY values are tried first. Registration order breaks ties. 
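A dispatch sketch, assuming the built-in serializers have already been registered (importing metaflow.plugins resolves the artifact_serializer plugins):

from metaflow.datastore.artifacts.serializer import SerializerStore
import metaflow.plugins  # registers PickleSerializer, IOTypeSerializer, and extension serializers

def pick_serializer(obj):
    # The first serializer (lowest PRIORITY, ties broken by registration order)
    # whose can_serialize() accepts the object wins.
    for s in SerializerStore.get_ordered_serializers():
        if s.can_serialize(obj):
            return s

print(pick_serializer({"a": 1}).TYPE)  # "pickle" -- a plain dict falls through to the fallback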
+ """ + + _all_serializers = {} + _registration_order = [] + + def __init__(cls, name, bases, namespace): + super().__init__(name, bases, namespace) + if cls.TYPE is not None: + if cls.TYPE not in SerializerStore._all_serializers: + SerializerStore._registration_order.append(cls.TYPE) + SerializerStore._all_serializers[cls.TYPE] = cls + + @staticmethod + def get_ordered_serializers(): + """ + Return serializer classes sorted by (PRIORITY, registration_order). + + This ordering is deterministic for a given set of loaded serializers. + """ + order = SerializerStore._registration_order + return sorted( + SerializerStore._all_serializers.values(), + key=lambda s: (s.PRIORITY, order.index(s.TYPE)), + ) + + +class ArtifactSerializer(object, metaclass=SerializerStore): + """ + Abstract base class for artifact serializers. + + Subclasses must set TYPE to a unique string identifier and implement + all four class methods. Subclasses are auto-registered by the SerializerStore + metaclass on class definition. + + Attributes + ---------- + TYPE : str or None + Unique identifier for this serializer (e.g., "pickle", "iotype"). + Set to None in the base class to prevent registration. + PRIORITY : int + Dispatch priority. Lower values are tried first. Default 100. + PickleSerializer uses 9999 as the universal fallback. + """ + + TYPE = None + PRIORITY = 100 + + @classmethod + @abstractmethod + def can_serialize(cls, obj): + """ + Return True if this serializer can handle the given object. + + Parameters + ---------- + obj : Any + The Python object to serialize. + + Returns + ------- + bool + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def can_deserialize(cls, metadata): + """ + Return True if this serializer can deserialize given the metadata. + + Parameters + ---------- + metadata : SerializationMetadata + Metadata stored alongside the artifact. + + Returns + ------- + bool + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def serialize(cls, obj): + """ + Serialize obj to blobs and metadata. Must be side-effect-free. + + Parameters + ---------- + obj : Any + The Python object to serialize. + + Returns + ------- + tuple + (List[SerializedBlob], SerializationMetadata) + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def deserialize(cls, blobs, metadata, context): + """ + Deserialize blobs back to a Python object. + + Parameters + ---------- + blobs : List[bytes] + The raw blob data. + metadata : SerializationMetadata + Metadata stored alongside the artifact. + context : Any + Optional context for deserialization (e.g., task vs client loading). + + Returns + ------- + Any + """ + raise NotImplementedError diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index abeeb8ea5fb..fa390f8b7a2 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -1,6 +1,5 @@ from collections import defaultdict import json -import pickle import sys import time @@ -117,12 +116,12 @@ def __init__( self._parent = flow_datastore self._persist = persist - # The GZIP encodings are for backward compatibility - self._encodings = {"pickle-v2", "gzip+pickle-v2"} - ver = sys.version_info[0] * 10 + sys.version_info[1] - if ver >= 36: - self._encodings.add("pickle-v4") - self._encodings.add("gzip+pickle-v4") + # Import here to ensure serializers are registered via the plugin system. 
+ # The import of metaflow.plugins triggers PickleSerializer (and any + # extension serializers) to register with SerializerStore. + from .artifacts.serializer import SerializerStore + + self._serializers = SerializerStore.get_ordered_serializers() self._is_done_set = False @@ -347,38 +346,42 @@ def save_artifacts(self, artifacts_iter, len_hint=0): """ artifact_names = [] - def pickle_iter(): + def serialize_iter(): for name, obj in artifacts_iter: - encode_type = "gzip+pickle-v4" - if encode_type in self._encodings: - try: - blob = pickle.dumps(obj, protocol=4) - except TypeError as e: - raise UnpicklableArtifactException(name) from e - else: - try: - blob = pickle.dumps(obj, protocol=2) - encode_type = "gzip+pickle-v2" - except (SystemError, OverflowError) as e: - raise DataException( - "Artifact *%s* is very large (over 2GB). " - "You need to use Python 3.6 or newer if you want to " - "serialize large objects." % name - ) from e - except TypeError as e: - raise UnpicklableArtifactException(name) from e + # Find the first serializer that can handle this object + serializer = None + for s in self._serializers: + if s.can_serialize(obj): + serializer = s + break + if serializer is None: + raise UnpicklableArtifactException(name) + + try: + blobs, metadata = serializer.serialize(obj) + except TypeError as e: + raise UnpicklableArtifactException(name) from e self._info[name] = { - "size": len(blob), - "type": str(type(obj)), - "encoding": encode_type, + "size": metadata.size, + "type": metadata.type, + "encoding": metadata.encoding, } + if metadata.serializer_info: + self._info[name]["serializer_info"] = metadata.serializer_info + # For now, serializers produce a single blob per artifact. + # Multi-blob support will be added when IOType lands. + if not blobs: + raise DataException( + "Serializer %s returned no blobs for artifact '%s'" + % (serializer.__name__, name) + ) artifact_names.append(name) - yield blob + yield blobs[0].value # Use the content-addressed store to store all artifacts - save_result = self._ca_store.save_blobs(pickle_iter(), len_hint=len_hint) + save_result = self._ca_store.save_blobs(serialize_iter(), len_hint=len_hint) for name, result in zip(artifact_names, save_result): self._objects[name] = result.key @@ -408,38 +411,50 @@ def load_artifacts(self, names): Iterator[(string, object)] : An iterator over objects retrieved. """ + from .artifacts.serializer import SerializationMetadata + if not self._info: raise DataException( "Datastore for task '%s' does not have the required metadata to " "load artifacts" % self._path ) to_load = defaultdict(list) + deserializers = {} # name -> serializer class for name in names: - info = self._info.get(name) - # We use gzip+pickle-v2 as this is the oldest/most compatible. 
- # This datastore will always include the proper encoding version so - # this is just to be able to read very old artifacts - if info: - encode_type = info.get("encoding", "gzip+pickle-v2") - else: - encode_type = "gzip+pickle-v2" - if encode_type not in self._encodings: + info = self._info.get(name, {}) + metadata = SerializationMetadata( + type=info.get("type", "object"), + size=info.get("size", 0), + # Default to gzip+pickle-v2 for very old artifacts without encoding + encoding=info.get("encoding", "gzip+pickle-v2"), + serializer_info=info.get("serializer_info", {}), + ) + + # Find deserializer via metadata + deserializer = None + for s in self._serializers: + if s.can_deserialize(metadata): + deserializer = s + break + if deserializer is None: raise DataException( - "Python 3.6 or later is required to load artifact '%s'" % name + "No deserializer found for artifact '%s' " + "(encoding: %s)" % (name, metadata.encoding) ) - else: - to_load[self._objects[name]].append(name) - # At this point, we load what we don't have from the CAS - # We assume that if we have one "old" style artifact, all of them are - # like that which is an easy assumption to make since artifacts are all - # stored by the same implementation of the datastore for a given task. + deserializers[name] = (deserializer, metadata) + to_load[self._objects[name]].append(name) + + # Load blobs from CAS and deserialize for key, blob in self._ca_store.load_blobs(to_load.keys()): - names = to_load[key] - for name in names: - # We unpickle everytime to have fully distinct objects (the user + loaded_names = to_load[key] + for name in loaded_names: + deserializer, metadata = deserializers[name] + # Deserialize each time to have fully distinct objects (the user # would not expect two artifacts with different names to actually # be aliases of one another) - yield name, pickle.loads(blob) + yield name, deserializer.deserialize( + [blob], metadata, context=None + ) @require_mode("r") def get_artifact_sizes(self, names): diff --git a/metaflow/decorators.py b/metaflow/decorators.py index b5c1f334eb6..ab2601ba729 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -1003,9 +1003,7 @@ def step( ) -> Callable[[FlowSpecDerived, Any, StepFlag], None]: ... -def step( - f: Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]], -): +def step(f=None, *, start=False, end=False): """ Marks a method in a FlowSpec as a Metaflow Step. Note that this decorator needs to be placed as close to the method as possible (ie: @@ -1029,20 +1027,34 @@ def foo(self): Parameters ---------- - f : Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]] - Function to make into a Metaflow Step + f : callable, optional + Function to make into a Metaflow Step. + start : bool, default False + Mark this step as the entry point of the flow. + end : bool, default False + Mark this step as the terminal step of the flow. 
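For illustration, the flow below mirrors the custom_named_flow.py fixture added in this change; the new keywords replace the reserved "start"/"end" names:

from metaflow import FlowSpec, step

class CustomNamedFlow(FlowSpec):
    @step(start=True)      # entry point despite not being named "start"
    def begin(self):
        self.x = 1
        self.next(self.finish)

    @step(end=True)        # terminal step despite not being named "end"
    def finish(self):
        print(self.x)

if __name__ == "__main__":
    CustomNamedFlow()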
Returns ------- - Union[Callable[[FlowSpecDerived, StepFlag], None], Callable[[FlowSpecDerived, Any, StepFlag], None]] + callable Function that is a Metaflow Step """ - f.is_step = True - f.decorators = [] - f.config_decorators = [] - f.wrappers = [] - f.name = f.__name__ - return f + + def decorator(fn): + fn.is_step = True + fn.is_start = start + fn.is_end = end + fn.decorators = [] + fn.config_decorators = [] + fn.wrappers = [] + fn.name = fn.__name__ + return fn + + if f is not None: + # Called as @step (no args) — backward compatible + return decorator(f) + # Called as @step(start=True, end=True) + return decorator def _import_plugin_decorators(globals_dict): diff --git a/metaflow/extension_support/plugins.py b/metaflow/extension_support/plugins.py index 62ed434fd63..d93b37a6c79 100644 --- a/metaflow/extension_support/plugins.py +++ b/metaflow/extension_support/plugins.py @@ -218,6 +218,7 @@ def resolve_plugins(category, path_only=False): ), "runner_cli": lambda x: x.name, "tl_plugin": None, + "artifact_serializer": lambda x: x.TYPE, } diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index 08e8dc35e54..762cbba89d9 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -516,6 +516,8 @@ def _set_constants(self, graph, kwargs, config_options): graph_info = { "file": os.path.basename(os.path.abspath(sys.argv[0])), + "start_step": graph.start_step, + "end_step": graph.end_step, "parameters": parameters_info, "constants": constants_info, "steps": steps_info, diff --git a/metaflow/graph.py b/metaflow/graph.py index 1c5dd53bfc5..87d5edc49d0 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -48,9 +48,20 @@ def deindent_docstring(doc): class DAGNode(object): def __init__( - self, func_ast, decos, wrappers, config_decorators, doc, source_file, lineno + self, + func_ast, + decos, + wrappers, + config_decorators, + doc, + source_file, + lineno, + is_start=False, + is_end=False, ): self.name = func_ast.name + self.is_start = is_start + self.is_end = is_end self.source_file = source_file # lineno is the start line of decorators in source_file # func_ast.lineno is lines from decorators start to def of function @@ -140,10 +151,9 @@ def _parse(self, func_ast, lineno): self.num_args = len(func_ast.args.args) tail = func_ast.body[-1] - # end doesn't need a transition - if self.name == "end": - # TYPE: end - self.type = "end" + # Note: type assignment for start/end steps is handled by + # FlowGraph._identify_start_end() based on @step attributes + # and name fallback, not by name here. # ensure that the tail an expression if not isinstance(tail, ast.Expr): @@ -212,10 +222,10 @@ def _parse(self, func_ast, lineno): self.type = "split" self.invalid_tail_next = False elif len(self.out_funcs) == 1: - # TYPE: linear - if self.name == "start": - self.type = "start" - elif self.num_args > 1: + # TYPE: linear (or join) + # Note: "start" type is assigned later by + # _identify_start_end() based on structure. + if self.num_args > 1: self.type = "join" else: self.type = "linear" @@ -255,10 +265,12 @@ def __str__(self): class FlowGraph(object): def __init__(self, flow): self.name = flow.__name__ + self.is_algo_spec = getattr(flow, "is_algo_spec", False) self.nodes = self._create_nodes(flow) self.doc = deindent_docstring(flow.__doc__) # nodes sorted in topological order. 
self.sorted_nodes = [] + self._identify_start_end() self._traverse_graph() self._postprocess() @@ -281,10 +293,54 @@ def _create_nodes(self, flow): func.__doc__, source_file, lineno, + is_start=getattr(func, "is_start", False), + is_end=getattr(func, "is_end", False), ) nodes[element] = node + + # AlgoSpec: re-key from "call" to class-derived step name + if self.is_algo_spec and "call" in nodes: + node = nodes.pop("call") + node.name = getattr(flow, "_algo_step_name", "call") + nodes[node.name] = node + return nodes + def _identify_start_end(self): + """Determine start and end steps. + + Tier 1: steps named "start"/"end" (backward compat). + Tier 2: @step(start=True) / @step(end=True) (explicit). + Tier 3: None — lint will catch it. + + Also assigns "start"/"end" node types based on structure. + """ + # Tier 1: name-based + if "start" in self.nodes: + self.start_step = "start" + else: + # Tier 2: explicit attribute + starts = [n for n, node in self.nodes.items() if node.is_start] + self.start_step = starts[0] if len(starts) == 1 else None + + if "end" in self.nodes: + self.end_step = "end" + else: + ends = [n for n, node in self.nodes.items() if node.is_end] + self.end_step = ends[0] if len(ends) == 1 else None + + # Assign types + if self.start_step and self.start_step == self.end_step: + # Single-step flow (e.g. AlgoSpec) + self.nodes[self.start_step].type = "end" + else: + if self.start_step: + node = self.nodes[self.start_step] + if node.type in (None, "linear"): + node.type = "start" + if self.end_step: + self.nodes[self.end_step].type = "end" + def _postprocess(self): # any node who has a foreach as any of its split parents # has is_inside_foreach=True *unless* all of those `foreach`s @@ -338,8 +394,8 @@ def traverse(node, seen, split_parents, split_branches): split_branches + ([n] if add_split_branch else []), ) - if "start" in self: - traverse(self["start"], [], [], []) + if self.start_step and self.start_step in self: + traverse(self[self.start_step], [], [], []) # fix the order of in_funcs for node in self.nodes.values(): @@ -493,9 +549,15 @@ def populate_block(start_name, end_name): break return resulting_list - graph_structure = populate_block("start", "end") + if self.start_step == self.end_step: + # Single-step flow + graph_structure = [] + else: + graph_structure = populate_block(self.start_step, self.end_step) - steps_info["end"] = node_to_dict("end", self.nodes["end"]) - graph_structure.append("end") + steps_info[self.end_step] = node_to_dict( + self.end_step, self.nodes[self.end_step] + ) + graph_structure.append(self.end_step) return steps_info, graph_structure diff --git a/metaflow/io_types/__init__.py b/metaflow/io_types/__init__.py new file mode 100644 index 00000000000..4079040e1f5 --- /dev/null +++ b/metaflow/io_types/__init__.py @@ -0,0 +1,7 @@ +from .base import IOType +from .scalars import Bool, Float32, Float64, Int32, Int64, Text +from .json_type import Json +from .enum_type import Enum +from .struct_type import Struct +from .collections import List, Map +from .tensor_type import Tensor diff --git a/metaflow/io_types/base.py b/metaflow/io_types/base.py new file mode 100644 index 00000000000..90631f9a7ac --- /dev/null +++ b/metaflow/io_types/base.py @@ -0,0 +1,121 @@ +from abc import ABCMeta, abstractmethod + + +_UNSET = object() + + +class IOType(object, metaclass=ABCMeta): + """ + Base class for typed Metaflow artifacts. + + IOType serves dual purposes: + - **Type descriptor** (no value): ``Int64`` describes an int64 field in a spec. 
+ - **Value wrapper** (with value): ``Int64(42)`` wraps a value for typed serialization. + + Both support ``to_spec()`` for JSON schema generation. + + Subclasses implement four internal operations unified behind a ``format`` + parameter: + + - ``format='wire'``: string-based (for CLI args, protobuf, external APIs) + - ``format='storage'``: blob-based (for S3/disk persistence via datastore) + + Storage byte order is little-endian. + """ + + type_name = None # e.g., "text", "json", "int64" — set by subclasses + + def __init__(self, value=_UNSET): + self._value = value + + @property + def value(self): + """The wrapped Python value, or _UNSET if this is a descriptor only.""" + return self._value + + # --- Public API (UX sugar) --- + + def serialize(self, format="wire"): + """ + Serialize the wrapped value. + + Parameters + ---------- + format : str + ``'wire'`` for string output, ``'storage'`` for blob output. + + Returns + ------- + str or tuple + Wire: a string. Storage: ``(List[SerializedBlob], dict)``. + """ + if format == "wire": + return self._wire_serialize() + elif format == "storage": + return self._storage_serialize() + raise ValueError("format must be 'wire' or 'storage', got %r" % format) + + @classmethod + def deserialize(cls, data, format="wire", **kwargs): + """ + Deserialize data into an IOType instance. + + Parameters + ---------- + data : str or List[bytes] + Wire: a string. Storage: list of byte blobs. + format : str + ``'wire'`` or ``'storage'``. + + Returns + ------- + IOType + """ + if format == "wire": + return cls._wire_deserialize(data) + elif format == "storage": + return cls._storage_deserialize(data, **kwargs) + raise ValueError("format must be 'wire' or 'storage', got %r" % format) + + # --- Four internal operations (subclasses implement these) --- + + @abstractmethod + def _wire_serialize(self): + """Value -> string (for CLI args, protobuf, external APIs).""" + raise NotImplementedError + + @classmethod + @abstractmethod + def _wire_deserialize(cls, s): + """String -> IOType instance.""" + raise NotImplementedError + + @abstractmethod + def _storage_serialize(self): + """Value -> (List[SerializedBlob], metadata_dict). Side-effect-free.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def _storage_deserialize(cls, blobs, **kwargs): + """(List[bytes], metadata) -> IOType instance.""" + raise NotImplementedError + + # --- Spec generation --- + + def to_spec(self): + """JSON type spec. Works with or without a wrapped value.""" + return {"type": self.type_name} + + def __repr__(self): + if self._value is _UNSET: + return "%s()" % self.__class__.__name__ + return "%s(%r)" % (self.__class__.__name__, self._value) + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self._value == other._value + + def __hash__(self): + return hash((type(self), self._value)) diff --git a/metaflow/io_types/collections.py b/metaflow/io_types/collections.py new file mode 100644 index 00000000000..923c2f543dd --- /dev/null +++ b/metaflow/io_types/collections.py @@ -0,0 +1,110 @@ +import json + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType + + +class List(IOType): + """ + Typed list. Wire: JSON string. Storage: JSON UTF-8 bytes. + + element_type is used for spec generation (to_spec) only. Serde uses JSON + for the entire list — individual elements are not delegated to their IOType + serializers. 
This is intentional: JSON is sufficient for list storage, and + per-element typed serde can be added later if needed. + + Parameters + ---------- + value : list, optional + The wrapped list value. + element_type : IOType subclass, optional + The IOType class for list elements (for spec generation). + """ + + type_name = "list" + + def __init__(self, value=None, element_type=None): + self._element_type = element_type + super().__init__(value) + + def _wire_serialize(self): + return json.dumps(self._value, separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps(self._value, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + meta = {} + if self._element_type is not None: + meta["element_type"] = self._element_type.type_name + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(json.loads(blobs[0].decode("utf-8"))) + + def to_spec(self): + spec = {"type": self.type_name} + if self._element_type is not None: + spec["element_type"] = self._element_type().to_spec() + return spec + + +class Map(IOType): + """ + Typed map (dict with typed keys and values). Wire: JSON string. Storage: JSON UTF-8 bytes. + + key_type/value_type are used for spec generation (to_spec) only. Serde uses + JSON for the entire map — individual entries are not delegated to their IOType + serializers. Same rationale as List. + + Parameters + ---------- + value : dict, optional + The wrapped dict value. + key_type : IOType subclass, optional + The IOType class for map keys (for spec generation). + value_type : IOType subclass, optional + The IOType class for map values (for spec generation). + """ + + type_name = "map" + + def __init__(self, value=None, key_type=None, value_type=None): + self._key_type = key_type + self._value_type = value_type + super().__init__(value) + + def _wire_serialize(self): + return json.dumps(self._value, separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps(self._value, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + meta = {} + if self._key_type is not None: + meta["key_type"] = self._key_type.type_name + if self._value_type is not None: + meta["value_type"] = self._value_type.type_name + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(json.loads(blobs[0].decode("utf-8"))) + + def to_spec(self): + spec = {"type": self.type_name} + if self._key_type is not None: + spec["key_type"] = self._key_type().to_spec() + if self._value_type is not None: + spec["value_type"] = self._value_type().to_spec() + return spec diff --git a/metaflow/io_types/enum_type.py b/metaflow/io_types/enum_type.py new file mode 100644 index 00000000000..c3c8c2cb2a4 --- /dev/null +++ b/metaflow/io_types/enum_type.py @@ -0,0 +1,65 @@ +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType + + +class Enum(IOType): + """ + Enum type — string value constrained to allowed values. + + Wire: string. Storage: UTF-8 bytes. + + Parameters + ---------- + value : str, optional + The enum value. Validated against allowed_values if provided. + allowed_values : list of str + The set of valid values for this enum. 
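A short wire-format round trip (the values are illustrative):

from metaflow.io_types import Enum

color = Enum("red", allowed_values=["red", "green", "blue"])
color.to_spec()                                   # {"type": "enum", "allowed_values": ["red", "green", "blue"]}
wire = color.serialize(format="wire")             # -> "red"
restored = Enum.deserialize(wire, format="wire")  # allowed_values are not carried on the wire
assert restored.value == "red"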
+ """ + + type_name = "enum" + + def __init__(self, value=None, allowed_values=None): + if allowed_values is not None: + self._allowed_values = list(allowed_values) + else: + self._allowed_values = [] + if value is not None and self._allowed_values: + if value not in self._allowed_values: + raise ValueError( + "Enum value %r not in allowed values %r" + % (value, self._allowed_values) + ) + super().__init__(value) + + def _wire_serialize(self): + return str(self._value) + + @classmethod + def _wire_deserialize(cls, s): + return cls(s) + + def _storage_serialize(self): + blob = str(self._value).encode("utf-8") + meta = {} + if self._allowed_values: + meta["allowed_values"] = self._allowed_values + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + metadata = kwargs.get("metadata", {}) + allowed = metadata.get("allowed_values") + return cls(blobs[0].decode("utf-8"), allowed_values=allowed) + + def to_spec(self): + spec = {"type": self.type_name} + if self._allowed_values: + spec["allowed_values"] = self._allowed_values + return spec + + def __repr__(self): + from .base import _UNSET + + if self._value is _UNSET: + return "Enum(allowed_values=%r)" % self._allowed_values + return "Enum(%r, allowed_values=%r)" % (self._value, self._allowed_values) diff --git a/metaflow/io_types/json_type.py b/metaflow/io_types/json_type.py new file mode 100644 index 00000000000..7b33bf55683 --- /dev/null +++ b/metaflow/io_types/json_type.py @@ -0,0 +1,27 @@ +import json + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType + + +class Json(IOType): + """JSON type (dict or list). Wire: JSON string. Storage: UTF-8 JSON bytes.""" + + type_name = "json" + + def _wire_serialize(self): + return json.dumps(self._value, separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps(self._value, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + return [SerializedBlob(blob)], {} + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(json.loads(blobs[0].decode("utf-8"))) diff --git a/metaflow/io_types/scalars.py b/metaflow/io_types/scalars.py new file mode 100644 index 00000000000..ed8131b45ae --- /dev/null +++ b/metaflow/io_types/scalars.py @@ -0,0 +1,150 @@ +import struct + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType + + +class Text(IOType): + """String type. Wire: identity. Storage: UTF-8 bytes.""" + + type_name = "text" + + def _wire_serialize(self): + return str(self._value) + + @classmethod + def _wire_deserialize(cls, s): + return cls(s) + + def _storage_serialize(self): + blob = str(self._value).encode("utf-8") + return [SerializedBlob(blob)], {} + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + return cls(blobs[0].decode("utf-8")) + + +class Bool(IOType): + """Boolean type. Wire: "true"/"false". 
Storage: 1 byte (0/1).""" + + type_name = "bool" + + def _wire_serialize(self): + return "true" if self._value else "false" + + @classmethod + def _wire_deserialize(cls, s): + if s.lower() == "true": + return cls(True) + elif s.lower() == "false": + return cls(False) + raise ValueError("Bool expects 'true' or 'false', got %r" % s) + + def _storage_serialize(self): + blob = b"\x01" if self._value else b"\x00" + return [SerializedBlob(blob)], {} + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + if len(blobs[0]) != 1 or blobs[0] not in (b"\x00", b"\x01"): + raise ValueError( + "Bool storage expects exactly 1 byte (0x00 or 0x01), got %r" + % blobs[0] + ) + return cls(blobs[0] == b"\x01") + + +class Int32(IOType): + """32-bit signed integer. Wire: str(int). Storage: 4-byte little-endian.""" + + type_name = "int32" + + _MIN = -(2**31) + _MAX = 2**31 - 1 + + def __init__(self, value=None): + if value is not None and not (self._MIN <= value <= self._MAX): + raise ValueError( + "Int32 value %d out of range [%d, %d]" % (value, self._MIN, self._MAX) + ) + super().__init__(value) + + def _wire_serialize(self): + return str(self._value) + + @classmethod + def _wire_deserialize(cls, s): + return cls(int(s)) + + def _storage_serialize(self): + blob = struct.pack(" IOType mapping for @dataclass field inference +PYTHON_TO_IOTYPE = {} # populated after scalar imports to avoid circular deps + + +def _init_python_to_iotype(): + """Lazy init to avoid circular imports.""" + if PYTHON_TO_IOTYPE: + return + from .scalars import Bool, Float64, Int64, Text + + PYTHON_TO_IOTYPE.update( + { + str: Text, + int: Int64, + float: Float64, + bool: Bool, + } + ) + + +def _iotype_for_annotation(annotation): + """ + Resolve a Python type annotation to an IOType class. + + Handles bare types (str, int, float, bool) via PYTHON_TO_IOTYPE. + IOType subclasses pass through directly. + """ + _init_python_to_iotype() + if isinstance(annotation, type) and issubclass(annotation, IOType): + return annotation + iotype = PYTHON_TO_IOTYPE.get(annotation) + if iotype is None: + raise TypeError( + "Cannot infer IOType for annotation %r. " + "Use an explicit IOType (e.g., Json, List, Struct)." % (annotation,) + ) + return iotype + + +class Struct(IOType): + """ + Structured type mapping to Python @dataclass. + + Wire: JSON string. Storage: JSON UTF-8 bytes. + + Wraps a @dataclass instance. Fields are inferred from dataclass annotations + with implicit scalar mapping (str->Text, int->Int64, float->Float64, bool->Bool). + + Parameters + ---------- + value : dataclass instance or dict, optional + The wrapped value. Dataclass instances are serialized via dataclasses.asdict. + Plain dicts are serialized directly as JSON. + dataclass_type : type, optional + The @dataclass class for type descriptor use (no value). 
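A sketch of spec generation and the storage round trip (the Point dataclass is hypothetical):

import dataclasses
from metaflow.io_types import Struct

@dataclasses.dataclass
class Point:
    x: int
    y: float
    label: str

p = Struct(Point(x=1, y=2.5, label="origin"))
p.to_spec()
# {"type": "struct", "fields": [{"name": "x", "type": "int64"},
#                               {"name": "y", "type": "float64"},
#                               {"name": "label", "type": "text"}]}
blobs, meta = p.serialize(format="storage")   # JSON bytes plus dataclass module/class metadata
restored = Struct.deserialize([blobs[0].value], format="storage", metadata=meta)
assert restored.value == Point(1, 2.5, "origin")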
+ """ + + type_name = "struct" + + def __init__(self, value=None, dataclass_type=None): + if value is not None and dataclasses.is_dataclass(value): + self._dataclass_type = type(value) + elif dataclass_type is not None: + self._dataclass_type = dataclass_type + else: + self._dataclass_type = None + super().__init__(value) + + def _to_dict(self): + """Convert value to dict, handling both dataclass and plain dict.""" + if dataclasses.is_dataclass(self._value): + return dataclasses.asdict(self._value) + if isinstance(self._value, dict): + return self._value + raise TypeError( + "Struct value must be a dataclass instance or dict, got %s" + % type(self._value).__name__ + ) + + def _wire_serialize(self): + return json.dumps(self._to_dict(), separators=(",", ":"), sort_keys=True) + + @classmethod + def _wire_deserialize(cls, s): + return cls(json.loads(s)) + + def _storage_serialize(self): + blob = json.dumps( + self._to_dict(), separators=(",", ":"), sort_keys=True + ).encode("utf-8") + meta = {} + if self._dataclass_type is not None: + meta["dataclass_module"] = self._dataclass_type.__module__ + meta["dataclass_class"] = self._dataclass_type.__name__ + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + data = json.loads(blobs[0].decode("utf-8")) + metadata = kwargs.get("metadata", {}) + dc_module = metadata.get("dataclass_module") + dc_class = metadata.get("dataclass_class") + if dc_module and dc_class: + import importlib + + mod = importlib.import_module(dc_module) + dc_type = getattr(mod, dc_class) + # Security: only allow actual dataclasses, not arbitrary classes + if not dataclasses.is_dataclass(dc_type): + raise ValueError( + "Struct metadata references '%s.%s' which is not a dataclass" + % (dc_module, dc_class) + ) + return cls(dc_type(**data), dataclass_type=dc_type) + # Fallback: return as plain dict wrapped in Struct + return cls(data) + + def to_spec(self): + spec = {"type": self.type_name} + if self._dataclass_type is not None and dataclasses.is_dataclass( + self._dataclass_type + ): + # Use typing.get_type_hints() to resolve string annotations + # (handles `from __future__ import annotations`) + try: + hints = typing.get_type_hints(self._dataclass_type) + except Exception: + hints = {} + fields = [] + for f in dataclasses.fields(self._dataclass_type): + annotation = hints.get(f.name, f.type) + try: + field_iotype = _iotype_for_annotation(annotation) + field_spec = field_iotype().to_spec() + except TypeError: + field_spec = {"type": str(annotation)} + fields.append({"name": f.name, **field_spec}) + spec["fields"] = fields + return spec diff --git a/metaflow/io_types/tensor_type.py b/metaflow/io_types/tensor_type.py new file mode 100644 index 00000000000..757c4f072a6 --- /dev/null +++ b/metaflow/io_types/tensor_type.py @@ -0,0 +1,126 @@ +import base64 +import json + +from ..datastore.artifacts.serializer import SerializedBlob +from .base import IOType + + +def _to_little_endian(arr): + """Ensure array is little-endian and contiguous.""" + np = _require_numpy() + arr = np.ascontiguousarray(arr) + if arr.dtype.byteorder == ">" or ( + arr.dtype.byteorder == "=" and import_sys_byteorder() == "big" + ): + arr = arr.byteswap().view(arr.dtype.newbyteorder("<")) + return arr + + +def import_sys_byteorder(): + import sys + + return sys.byteorder + + +class Tensor(IOType): + """ + N-dimensional array type backed by numpy ndarray. + + Wire: base64-encoded string (with shape/dtype in JSON prefix). 
+ Storage: raw little-endian bytes blob + shape/dtype metadata dict. + + All serialization normalizes to little-endian byte order regardless of + the host platform's native endianness. + + Parameters + ---------- + value : numpy.ndarray, optional + The wrapped array value. + """ + + type_name = "tensor" + + def _wire_serialize(self): + np = _require_numpy() + arr = _to_little_endian(self._value) + raw = arr.tobytes() + header = json.dumps( + {"dtype": arr.dtype.str, "shape": list(arr.shape)}, + separators=(",", ":"), + ) + return header + "|" + base64.b64encode(raw).decode("ascii") + + @classmethod + def _wire_deserialize(cls, s): + np = _require_numpy() + header_str, b64_data = s.split("|", 1) + header = json.loads(header_str) + raw = base64.b64decode(b64_data) + arr = np.frombuffer(raw, dtype=np.dtype(header["dtype"])).reshape( + header["shape"] + ) + return cls(arr.copy()) + + def _storage_serialize(self): + np = _require_numpy() + arr = _to_little_endian(self._value) + blob = arr.tobytes() + meta = { + "dtype": arr.dtype.str, + "shape": list(arr.shape), + } + return [SerializedBlob(blob)], meta + + @classmethod + def _storage_deserialize(cls, blobs, **kwargs): + np = _require_numpy() + metadata = kwargs.get("metadata", {}) + dtype = np.dtype(metadata["dtype"]) + shape = tuple(metadata["shape"]) + arr = np.frombuffer(blobs[0], dtype=dtype).reshape(shape) + return cls(arr.copy()) + + def to_spec(self): + spec = {"type": self.type_name} + if self._value is not None: + try: + import numpy as np + + if isinstance(self._value, np.ndarray): + spec["dtype"] = str(self._value.dtype) + spec["shape"] = list(self._value.shape) + except ImportError: + pass + return spec + + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + try: + import numpy as np + + if isinstance(self._value, np.ndarray) and isinstance( + other._value, np.ndarray + ): + return ( + self._value.shape == other._value.shape + and self._value.dtype == other._value.dtype + and np.array_equal(self._value, other._value) + ) + except ImportError: + pass + return self._value == other._value + + # ndarray is not hashable — match numpy's behavior + __hash__ = None + + +def _require_numpy(): + try: + import numpy as np + + return np + except ImportError: + raise ImportError( + "Tensor type requires numpy. Install it with: pip install numpy" + ) diff --git a/metaflow/lint.py b/metaflow/lint.py index 8f82c813b0d..fb8a555e6d5 100644 --- a/metaflow/lint.py +++ b/metaflow/lint.py @@ -58,23 +58,33 @@ def check_reserved_words(graph): @linter.ensure_fundamentals @linter.check def check_basic_steps(graph): - msg = "Add %s *%s* step in your flow." - for prefix, node in (("a", "start"), ("an", "end")): - if node not in graph: - raise LintWarn(msg % (prefix, node)) + if graph.start_step is None: + raise LintWarn( + "Your flow must have a step named 'start', or exactly one " + "step decorated with @step(start=True)." + ) + if graph.end_step is None: + raise LintWarn( + "Your flow must have a step named 'end', or exactly one " + "step decorated with @step(end=True)." + ) @linter.ensure_static_graph @linter.check def check_that_end_is_end(graph): - msg0 = "The *end* step should not have a step.next() transition. " "Just remove it." + if graph.end_step is None or graph.start_step == graph.end_step: + return # single-step graph or missing end — handled elsewhere + node = graph[graph.end_step] + msg0 = ( + "The terminal step *%s* should not have a self.next() transition. " + "Just remove it." 
% graph.end_step + ) msg1 = ( - "The *end* step should not be a join step (it gets an extra " - "argument). Add a join step before it." + "The terminal step *%s* should not be a join step (it gets an extra " + "argument). Add a join step before it." % graph.end_step ) - node = graph["end"] - if node.has_tail_next or node.invalid_tail_next: raise LintWarn(msg0, node.tail_next_lineno, node.source_file) if node.num_args > 1: @@ -96,6 +106,8 @@ def check_step_names(graph): @linter.ensure_fundamentals @linter.check def check_num_args(graph): + if graph.start_step == graph.end_step: + return # single-step graph msg0 = ( "Step {0.name} has too many arguments. Normal steps take only " "'self' as an argument. Join steps take 'self' and 'inputs'." @@ -192,11 +204,13 @@ def check_path(node, seen): @linter.ensure_static_graph @linter.check def check_for_orphans(graph): + if graph.start_step is None: + return msg = ( - "Step *{0.name}* is unreachable from the start step. Add " - "self.next({0.name}) in another step or remove *{0.name}*." + "Step *{0.name}* is unreachable from the entry step *%s*. Add " + "self.next({0.name}) in another step or remove *{0.name}*." % graph.start_step ) - seen = set(["start"]) + seen = set([graph.start_step]) def traverse(node): for n in node.out_funcs: @@ -204,7 +218,7 @@ def traverse(node): seen.add(n) traverse(graph[n]) - traverse(graph["start"]) + traverse(graph[graph.start_step]) nodeset = frozenset(n.name for n in graph) orphans = nodeset - seen if orphans: @@ -215,9 +229,11 @@ def traverse(node): @linter.ensure_static_graph @linter.check def check_split_join_balance(graph): + if graph.start_step is None or graph.end_step is None: + return msg0 = ( - "Step *end* reached before a split started at step(s) *{roots}* " - "were joined. Add a join step before *end*." + "The terminal step *{end}* was reached before a split started at " + "step(s) *{roots}* were joined. Add a join step before *{end}*." ) msg1 = ( "Step *{0.name}* seems like a join step (it takes an extra input " @@ -253,7 +269,9 @@ def traverse(node, split_stack): _, split_roots = split_stack.pop() roots = ", ".join(split_roots) raise LintWarn( - msg0.format(roots=roots), node.func_lineno, node.source_file + msg0.format(roots=roots, end=graph.end_step), + node.func_lineno, + node.source_file, ) elif node.type == "join": new_stack = split_stack @@ -301,7 +319,7 @@ def parents(n): continue traverse(graph[n], new_stack) - traverse(graph["start"], []) + traverse(graph[graph.start_step], []) @linter.ensure_static_graph diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 3fc1d3f8db6..d5846fb80c1 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -187,6 +187,13 @@ ("conda_environment_yml_parser", ".pypi.parsers.conda_environment_yml_parser"), ] +# Add artifact serializers here. Ordering is by PRIORITY (lower = tried first). +# PickleSerializer is the universal fallback (PRIORITY=9999). 
+ARTIFACT_SERIALIZERS_DESC = [ + ("iotype", ".datastores.serializers.iotype_serializer.IOTypeSerializer"), + ("pickle", ".datastores.serializers.pickle_serializer.PickleSerializer"), +] + process_plugins(globals()) @@ -228,6 +235,7 @@ def get_runner_cli_path(): DEPLOYER_IMPL_PROVIDERS = resolve_plugins("deployer_impl_provider") TL_PLUGINS = resolve_plugins("tl_plugin") +ARTIFACT_SERIALIZERS = resolve_plugins("artifact_serializer") from .cards.card_modules import MF_EXTERNAL_CARDS diff --git a/metaflow/plugins/datastores/serializers/__init__.py b/metaflow/plugins/datastores/serializers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/plugins/datastores/serializers/iotype_serializer.py b/metaflow/plugins/datastores/serializers/iotype_serializer.py new file mode 100644 index 00000000000..89e163744ea --- /dev/null +++ b/metaflow/plugins/datastores/serializers/iotype_serializer.py @@ -0,0 +1,62 @@ +import importlib + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, +) +from metaflow.io_types.base import IOType + + +class IOTypeSerializer(ArtifactSerializer): + """ + Bridge between the IOType system and the pluggable serializer framework. + + Auto-detects IOType instances via isinstance check. On deserialization, + reconstructs the original IOType subclass from metadata (module + class name). + + PRIORITY is 50: higher than default (100) but lower than domain-specific + serializers. Always before PickleSerializer (9999). + """ + + TYPE = "iotype" + PRIORITY = 50 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, IOType) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding.startswith("iotype:") + + @classmethod + def serialize(cls, obj): + blobs, meta_dict = obj.serialize(format="storage") + return ( + blobs, + SerializationMetadata( + type=obj.type_name, + size=sum( + len(b.value) for b in blobs if isinstance(b.value, bytes) + ), + encoding="iotype:%s" % obj.type_name, + serializer_info={ + "iotype_module": obj.__class__.__module__, + "iotype_class": obj.__class__.__name__, + **meta_dict, + }, + ), + ) + + @classmethod + def deserialize(cls, blobs, metadata, context): + info = metadata.serializer_info + mod = importlib.import_module(info["iotype_module"]) + iotype_cls = getattr(mod, info["iotype_class"]) + # Security: only allow actual IOType subclasses, not arbitrary classes + if not (isinstance(iotype_cls, type) and issubclass(iotype_cls, IOType)): + raise ValueError( + "IOTypeSerializer metadata references '%s.%s' which is not an " + "IOType subclass" % (info["iotype_module"], info["iotype_class"]) + ) + return iotype_cls.deserialize(blobs, format="storage", metadata=info) diff --git a/metaflow/plugins/datastores/serializers/pickle_serializer.py b/metaflow/plugins/datastores/serializers/pickle_serializer.py new file mode 100644 index 00000000000..01320b9898c --- /dev/null +++ b/metaflow/plugins/datastores/serializers/pickle_serializer.py @@ -0,0 +1,49 @@ +import pickle + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, +) + + +class PickleSerializer(ArtifactSerializer): + """ + Default serializer using Python's pickle module. + + This is the universal fallback — can_serialize always returns True. + PRIORITY is set to 9999 so custom serializers are always tried first. 
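For example, an extension-provided serializer might look like the following sketch (the JSON-dict serializer is illustrative, not part of this change):

import json

from metaflow.datastore.artifacts.serializer import (
    ArtifactSerializer,
    SerializationMetadata,
    SerializedBlob,
)

class JsonDictSerializer(ArtifactSerializer):
    # Auto-registered by the SerializerStore metaclass on class definition;
    # tried before PickleSerializer because 80 < 9999.
    TYPE = "jsondict"
    PRIORITY = 80

    @classmethod
    def can_serialize(cls, obj):
        return isinstance(obj, dict)

    @classmethod
    def can_deserialize(cls, metadata):
        return metadata.encoding == "jsondict"

    @classmethod
    def serialize(cls, obj):
        blob = json.dumps(obj, sort_keys=True).encode("utf-8")
        return (
            [SerializedBlob(blob)],
            SerializationMetadata(
                type=str(type(obj)),
                size=len(blob),
                encoding="jsondict",
                serializer_info={},
            ),
        )

    @classmethod
    def deserialize(cls, blobs, metadata, context):
        return json.loads(blobs[0].decode("utf-8"))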
+ """ + + TYPE = "pickle" + PRIORITY = 9999 + + _ENCODINGS = frozenset( + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"] + ) + + @classmethod + def can_serialize(cls, obj): + return True + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding in cls._ENCODINGS + + @classmethod + def serialize(cls, obj): + blob = pickle.dumps(obj, protocol=4) + encoding = "pickle-v4" + return ( + [SerializedBlob(blob, is_reference=False, compress_method="gzip")], + SerializationMetadata( + type=str(type(obj)), + size=len(blob), + encoding=encoding, + serializer_info={}, + ), + ) + + @classmethod + def deserialize(cls, blobs, metadata, context): + return pickle.loads(blobs[0]) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 2365892fca6..2c7a9444558 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -238,7 +238,7 @@ def _format_input_paths(task_pathspec, attempt): if self._input_paths: return self._input_paths - if self._step_func.name == "start": + if self._step_func.name == self._graph.start_step: from metaflow import Step flow_name, run_id, _, _ = self._spin_pathspec.split("/") @@ -514,6 +514,28 @@ def persist_constants(self, task_id=None): if not self._params_task.is_cloned: self._params_task.persist(self._flow) + # Register start/end step metadata so the client can determine + # graph endpoints without loading _graph_info. + self._metadata.register_metadata( + self._run_id, + "_parameters", + self._params_task.task_id, + [ + MetaDatum( + field="start_step", + value=self._graph.start_step, + type="graph_structure", + tags=[], + ), + MetaDatum( + field="end_step", + value=self._graph.end_step, + type="graph_structure", + tags=[], + ), + ], + ) + self._is_cloned[self._params_task.path] = self._params_task.is_cloned def should_skip_clone_only_execution(self): @@ -766,10 +788,11 @@ def execute(self): self._run_queue = [] self._active_tasks[0] = 0 else: + first_step = self._graph.start_step if self._params_task: - self._queue_push("start", {"input_paths": [self._params_task.path]}) + self._queue_push(first_step, {"input_paths": [self._params_task.path]}) else: - self._queue_push("start", {}) + self._queue_push(first_step, {}) progress_tstamp = time.time() with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: @@ -938,8 +961,9 @@ def execute(self): deco.runtime_finished(exception) self._run_exit_hooks() - # assert that end was executed and it was successful - if ("end", (), ()) in self._finished: + # assert that the terminal step was executed and it was successful + end = self._graph.end_step + if (end, (), ()) in self._finished: if self._run_url: self._logger( "Done! See the run in the UI at %s" % self._run_url, @@ -956,7 +980,8 @@ def execute(self): self._params_task.mark_resume_done() else: raise MetaflowInternalError( - "The *end* step was not successful by the end of flow." + "The terminal step *%s* was not successful by the end of flow." 
+ % self._graph.end_step ) def _run_exit_hooks(self): @@ -966,7 +991,11 @@ def _run_exit_hooks(self): if not exit_hook_decos: return - successful = ("end", (), ()) in self._finished or self._clone_only + successful = ( + self._graph.end_step, + (), + (), + ) in self._finished or self._clone_only pathspec = f"{self._graph.name}/{self._run_id}" flow_file = self._environment.get_environment_info()["script"] diff --git a/metaflow/task.py b/metaflow/task.py index cc6304e20e6..5cd3e9d25eb 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -324,7 +324,7 @@ def _init_foreach(self, step_name, join_type, inputs, split_index): # then used later to write the foreach-stack metadata for that task # case 1) - reset the stack - if step_name == "start": + if step_name == self.flow._graph.start_step: self.flow._foreach_stack = [] # case 2) - this is a join step diff --git a/test/unit/graph_inference/__init__.py b/test/unit/graph_inference/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit/graph_inference/conftest.py b/test/unit/graph_inference/conftest.py new file mode 100644 index 00000000000..b409981980c --- /dev/null +++ b/test/unit/graph_inference/conftest.py @@ -0,0 +1,49 @@ +import os +import pytest +from metaflow import Runner, Flow + +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") + + +def create_flow_fixture(flow_name, flow_file, **runner_kwargs): + """Factory function to create flow fixtures that run via Runner.""" + + def flow_fixture(request): + flow_path = os.path.join(FLOWS_DIR, flow_file) + with Runner(flow_path, cwd=FLOWS_DIR, **runner_kwargs).run() as running: + return running.run + + return flow_fixture + + +custom_named_run = pytest.fixture(scope="session")( + create_flow_fixture("CustomNamedFlow", "custom_named_flow.py") +) + +single_step_run = pytest.fixture(scope="session")( + create_flow_fixture("SingleStepFlow", "single_step_flow.py") +) + +standard_run = pytest.fixture(scope="session")( + create_flow_fixture("StandardFlow", "standard_flow.py") +) + +algo_spec_run = pytest.fixture(scope="session")( + create_flow_fixture("SquareModel", "algo_spec_flow.py") +) + +config_algo_spec_run = pytest.fixture(scope="session")( + create_flow_fixture( + "ConfigAlgoSpec", + "algo_spec_config_flow.py", + environment="conda", + ) +) + +project_algo_spec_run = pytest.fixture(scope="session")( + create_flow_fixture( + "ProjectAlgoSpec", + "algo_spec_project_flow.py", + environment="conda", + ) +) diff --git a/test/unit/graph_inference/flows/__init__.py b/test/unit/graph_inference/flows/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit/graph_inference/flows/algo_spec_config_flow.py b/test/unit/graph_inference/flows/algo_spec_config_flow.py new file mode 100644 index 00000000000..5ea27b02df2 --- /dev/null +++ b/test/unit/graph_inference/flows/algo_spec_config_flow.py @@ -0,0 +1,19 @@ +from metaflow import Parameter, Config, conda_base +from metaflow.algospec import AlgoSpec + + +@conda_base(python="3.10") +class ConfigAlgoSpec(AlgoSpec): + """AlgoSpec with Config and @conda_base flow decorator.""" + + config = Config("config", default="config.json") + + multiplier = Parameter("multiplier", type=float, default=2.0) + + def call(self): + scale = self.config["scale"] + self.result = 5**2 * self.multiplier * scale + + +if __name__ == "__main__": + ConfigAlgoSpec() diff --git a/test/unit/graph_inference/flows/algo_spec_flow.py b/test/unit/graph_inference/flows/algo_spec_flow.py new file mode 100644 index 
00000000000..8c5ab840be5 --- /dev/null +++ b/test/unit/graph_inference/flows/algo_spec_flow.py @@ -0,0 +1,18 @@ +from metaflow import Parameter +from metaflow.algospec import AlgoSpec + + +class SquareModel(AlgoSpec): + """AlgoSpec test — step name should be 'squaremodel'.""" + + multiplier = Parameter("multiplier", type=float, default=1.0) + + def init(self): + pass + + def call(self): + self.result = 5**2 * self.multiplier + + +if __name__ == "__main__": + SquareModel() diff --git a/test/unit/graph_inference/flows/algo_spec_project_flow.py b/test/unit/graph_inference/flows/algo_spec_project_flow.py new file mode 100644 index 00000000000..f547eb88c54 --- /dev/null +++ b/test/unit/graph_inference/flows/algo_spec_project_flow.py @@ -0,0 +1,20 @@ +from metaflow import Parameter, project, pypi_base +from metaflow.algospec import AlgoSpec + + +@project(name="test_algo_project") +@pypi_base(packages={"requests": "2.31.0"}) +class ProjectAlgoSpec(AlgoSpec): + """AlgoSpec with @project and @pypi_base flow decorators.""" + + value = Parameter("value", type=int, default=7) + + def call(self): + import requests + + self.result = self.value**2 + self.requests_version = requests.__version__ + + +if __name__ == "__main__": + ProjectAlgoSpec() diff --git a/test/unit/graph_inference/flows/config.json b/test/unit/graph_inference/flows/config.json new file mode 100644 index 00000000000..907056938e2 --- /dev/null +++ b/test/unit/graph_inference/flows/config.json @@ -0,0 +1 @@ +{"scale": 3.0} diff --git a/test/unit/graph_inference/flows/custom_named_flow.py b/test/unit/graph_inference/flows/custom_named_flow.py new file mode 100644 index 00000000000..8c2cc6903a7 --- /dev/null +++ b/test/unit/graph_inference/flows/custom_named_flow.py @@ -0,0 +1,23 @@ +from metaflow import FlowSpec, step + + +class CustomNamedFlow(FlowSpec): + """Flow with non-standard step names using @step(start/end=True).""" + + @step(start=True) + def begin(self): + self.x = 1 + self.next(self.middle) + + @step + def middle(self): + self.x += 1 + self.next(self.finish) + + @step(end=True) + def finish(self): + self.x += 1 + + +if __name__ == "__main__": + CustomNamedFlow() diff --git a/test/unit/graph_inference/flows/single_step_flow.py b/test/unit/graph_inference/flows/single_step_flow.py new file mode 100644 index 00000000000..3ba8ff723de --- /dev/null +++ b/test/unit/graph_inference/flows/single_step_flow.py @@ -0,0 +1,13 @@ +from metaflow import FlowSpec, step + + +class SingleStepFlow(FlowSpec): + """Flow with a single step — start==end.""" + + @step(start=True, end=True) + def only(self): + self.x = 42 + + +if __name__ == "__main__": + SingleStepFlow() diff --git a/test/unit/graph_inference/flows/standard_flow.py b/test/unit/graph_inference/flows/standard_flow.py new file mode 100644 index 00000000000..f3ba3e1604a --- /dev/null +++ b/test/unit/graph_inference/flows/standard_flow.py @@ -0,0 +1,18 @@ +from metaflow import FlowSpec, step + + +class StandardFlow(FlowSpec): + """Standard flow with start/end names — backward compat test.""" + + @step + def start(self): + self.x = 1 + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + StandardFlow() diff --git a/test/unit/graph_inference/test_graph_inference.py b/test/unit/graph_inference/test_graph_inference.py new file mode 100644 index 00000000000..fd60c8dc30e --- /dev/null +++ b/test/unit/graph_inference/test_graph_inference.py @@ -0,0 +1,195 @@ +""" +Unit tests for start_step/end_step inference and @step(start/end) kwargs. 
+ +Tests graph construction only — no execution needed. +""" + +import pytest +from metaflow.graph import FlowGraph + + +class TestStandardFlow: + """Backward compat: steps named start/end are detected by name.""" + + def setup_method(self): + from .flows.standard_flow import StandardFlow + + self.graph = StandardFlow._graph + + def test_start_step(self): + assert self.graph.start_step == "start" + + def test_end_step(self): + assert self.graph.end_step == "end" + + def test_start_type(self): + assert self.graph.nodes["start"].type == "start" + + def test_end_type(self): + assert self.graph.nodes["end"].type == "end" + + def test_sorted_nodes(self): + assert self.graph.sorted_nodes == ["start", "end"] + + +class TestCustomNamedFlow: + """@step(start=True) / @step(end=True) with non-standard names.""" + + def setup_method(self): + from .flows.custom_named_flow import CustomNamedFlow + + self.graph = CustomNamedFlow._graph + + def test_start_step(self): + assert self.graph.start_step == "begin" + + def test_end_step(self): + assert self.graph.end_step == "finish" + + def test_start_type(self): + assert self.graph.nodes["begin"].type == "start" + + def test_end_type(self): + assert self.graph.nodes["finish"].type == "end" + + def test_middle_type(self): + assert self.graph.nodes["middle"].type == "linear" + + def test_sorted_nodes(self): + assert self.graph.sorted_nodes == ["begin", "middle", "finish"] + + def test_output_steps(self): + steps_info, structure = self.graph.output_steps() + assert "begin" in steps_info + assert "middle" in steps_info + assert "finish" in steps_info + assert structure[-1] == "finish" + + +class TestSingleStepFlow: + """@step(start=True, end=True) — start == end.""" + + def setup_method(self): + from .flows.single_step_flow import SingleStepFlow + + self.graph = SingleStepFlow._graph + + def test_start_equals_end(self): + assert self.graph.start_step == "only" + assert self.graph.end_step == "only" + assert self.graph.start_step == self.graph.end_step + + def test_type_is_end(self): + # Single-step: terminal node + assert self.graph.nodes["only"].type == "end" + + def test_sorted_nodes(self): + assert self.graph.sorted_nodes == ["only"] + + def test_output_steps(self): + steps_info, structure = self.graph.output_steps() + assert list(steps_info.keys()) == ["only"] + assert structure == ["only"] + + +class TestAlgoSpec: + """AlgoSpec: call method becomes a step named after the class.""" + + def setup_method(self): + from .flows.algo_spec_flow import SquareModel + + self.graph = SquareModel._graph + self.cls = SquareModel + + def test_start_step_is_class_name(self): + assert self.graph.start_step == "squaremodel" + + def test_start_equals_end(self): + assert self.graph.start_step == self.graph.end_step + + def test_node_exists(self): + assert "squaremodel" in self.graph.nodes + + def test_type_is_end(self): + assert self.graph.nodes["squaremodel"].type == "end" + + def test_is_algo_spec(self): + assert self.graph.is_algo_spec is True + + def test_call_is_step(self): + assert hasattr(self.cls.call, "is_step") + assert self.cls.call.is_step is True + assert self.cls.call.is_start is True + assert self.cls.call.is_end is True + + def test_output_steps(self): + steps_info, structure = self.graph.output_steps() + assert list(steps_info.keys()) == ["squaremodel"] + assert structure == ["squaremodel"] + + def test_direct_callable(self): + model = self.cls(use_cli=False) + model.multiplier = 3.0 + model.call() + assert model.result == 75.0 # 5^2 * 3.0 + + +class 
TestConfigAlgoSpec: + """AlgoSpec with Config and @conda_base flow decorator.""" + + def setup_method(self): + from .flows.algo_spec_config_flow import ConfigAlgoSpec + + self.cls = ConfigAlgoSpec + self.graph = ConfigAlgoSpec._graph + + def test_step_name(self): + assert self.graph.start_step == "configalgospec" + + def test_start_equals_end(self): + assert self.graph.start_step == self.graph.end_step + + def test_has_config(self): + assert hasattr(self.cls, "config") + + def test_has_parameters(self): + params = dict(self.cls._get_parameters()) + assert "multiplier" in params + + def test_conda_base_applied(self): + from metaflow.flowspec import FlowStateItems + + flow_decos = self.cls._flow_state.get(FlowStateItems.FLOW_DECORATORS, {}) + assert "conda_base" in flow_decos + + +class TestProjectAlgoSpec: + """AlgoSpec with @project and @pypi_base flow decorators.""" + + def setup_method(self): + from .flows.algo_spec_project_flow import ProjectAlgoSpec + + self.cls = ProjectAlgoSpec + self.graph = ProjectAlgoSpec._graph + + def test_step_name(self): + assert self.graph.start_step == "projectalgospec" + + def test_start_equals_end(self): + assert self.graph.start_step == self.graph.end_step + + def test_project_applied(self): + from metaflow.flowspec import FlowStateItems + + flow_decos = self.cls._flow_state.get(FlowStateItems.FLOW_DECORATORS, {}) + assert "project" in flow_decos + + def test_pypi_base_applied(self): + from metaflow.flowspec import FlowStateItems + + flow_decos = self.cls._flow_state.get(FlowStateItems.FLOW_DECORATORS, {}) + assert "pypi_base" in flow_decos + + def test_has_parameter(self): + params = dict(self.cls._get_parameters()) + assert "value" in params diff --git a/test/unit/graph_inference/test_graph_integration.py b/test/unit/graph_inference/test_graph_integration.py new file mode 100644 index 00000000000..b87ee7b9c9f --- /dev/null +++ b/test/unit/graph_inference/test_graph_integration.py @@ -0,0 +1,194 @@ +""" +Integration tests for start_step/end_step — actually executes flows. 
+ +Verifies: +- Flows run to completion +- _graph_info contains start_step/end_step +- _parameters metadata contains start_step/end_step +- Client APIs (end_task, parent_steps, child_steps) work correctly +- Artifacts are persisted and readable +""" + + +class TestStandardFlowIntegration: + """Standard flow with start/end names — backward compat.""" + + def test_flow_completes(self, standard_run): + assert standard_run.successful + assert standard_run.finished + + def test_graph_info_endpoints(self, standard_run): + graph_info = standard_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "start" + assert graph_info["end_step"] == "end" + + def test_parameters_metadata(self, standard_run): + meta = standard_run["_parameters"].task.metadata_dict + assert meta.get("start_step") == "start" + assert meta.get("end_step") == "end" + + def test_end_task(self, standard_run): + assert standard_run.end_task is not None + + def test_steps_present(self, standard_run): + step_names = {s.id for s in standard_run} + assert step_names == {"start", "end"} + + +class TestCustomNamedFlowIntegration: + """Flow with @step(start=True) / @step(end=True) and custom names.""" + + def test_flow_completes(self, custom_named_run): + assert custom_named_run.successful + assert custom_named_run.finished + + def test_graph_info_endpoints(self, custom_named_run): + graph_info = custom_named_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "begin" + assert graph_info["end_step"] == "finish" + + def test_parameters_metadata(self, custom_named_run): + meta = custom_named_run["_parameters"].task.metadata_dict + assert meta.get("start_step") == "begin" + assert meta.get("end_step") == "finish" + + def test_graph_endpoints_property(self, custom_named_run): + start, end = custom_named_run._graph_endpoints + assert start == "begin" + assert end == "finish" + + def test_end_task(self, custom_named_run): + end_task = custom_named_run.end_task + assert end_task is not None + assert end_task["x"].data == 3 + + def test_steps_present(self, custom_named_run): + step_names = {s.id for s in custom_named_run} + assert step_names == {"begin", "middle", "finish"} + + def test_parent_steps(self, custom_named_run): + begin_parents = list(custom_named_run["begin"].parent_steps) + assert begin_parents == [] + + middle_parents = [s.id for s in custom_named_run["middle"].parent_steps] + assert middle_parents == ["begin"] + + finish_parents = [s.id for s in custom_named_run["finish"].parent_steps] + assert finish_parents == ["middle"] + + def test_child_steps(self, custom_named_run): + begin_children = [s.id for s in custom_named_run["begin"].child_steps] + assert begin_children == ["middle"] + + middle_children = [s.id for s in custom_named_run["middle"].child_steps] + assert middle_children == ["finish"] + + finish_children = list(custom_named_run["finish"].child_steps) + assert finish_children == [] + + +class TestSingleStepFlowIntegration: + """Single step where start == end.""" + + def test_flow_completes(self, single_step_run): + assert single_step_run.successful + assert single_step_run.finished + + def test_graph_info_start_equals_end(self, single_step_run): + graph_info = single_step_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "only" + assert graph_info["end_step"] == "only" + + def test_parameters_metadata(self, single_step_run): + meta = single_step_run["_parameters"].task.metadata_dict + assert meta.get("start_step") == "only" + assert 
meta.get("end_step") == "only" + + def test_end_task(self, single_step_run): + end_task = single_step_run.end_task + assert end_task is not None + assert end_task["x"].data == 42 + + def test_single_step_present(self, single_step_run): + step_names = {s.id for s in single_step_run} + assert step_names == {"only"} + + def test_parent_child_empty(self, single_step_run): + parents = list(single_step_run["only"].parent_steps) + children = list(single_step_run["only"].child_steps) + assert parents == [] + assert children == [] + + +class TestAlgoSpecIntegration: + """AlgoSpec — step named after the class.""" + + def test_flow_completes(self, algo_spec_run): + assert algo_spec_run.successful + assert algo_spec_run.finished + + def test_graph_info_endpoints(self, algo_spec_run): + graph_info = algo_spec_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "squaremodel" + assert graph_info["end_step"] == "squaremodel" + + def test_parameters_metadata(self, algo_spec_run): + meta = algo_spec_run["_parameters"].task.metadata_dict + assert meta.get("start_step") == "squaremodel" + assert meta.get("end_step") == "squaremodel" + + def test_end_task_data(self, algo_spec_run): + end_task = algo_spec_run.end_task + assert end_task is not None + assert end_task["result"].data == 25.0 # 5^2 * 1.0 (default multiplier) + + def test_single_step(self, algo_spec_run): + step_names = {s.id for s in algo_spec_run} + assert step_names == {"squaremodel"} + + +class TestConfigAlgoSpecIntegration: + """AlgoSpec with Config and @conda_base — full execution.""" + + def test_flow_completes(self, config_algo_spec_run): + assert config_algo_spec_run.successful + assert config_algo_spec_run.finished + + def test_config_used_in_computation(self, config_algo_spec_run): + """Config scale=3.0, multiplier=2.0: 25 * 2.0 * 3.0 = 150.0""" + assert config_algo_spec_run.end_task["result"].data == 150.0 + + def test_graph_endpoints(self, config_algo_spec_run): + graph_info = config_algo_spec_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "configalgospec" + assert graph_info["end_step"] == "configalgospec" + + def test_parameters_metadata(self, config_algo_spec_run): + meta = config_algo_spec_run["_parameters"].task.metadata_dict + assert meta.get("start_step") == "configalgospec" + assert meta.get("end_step") == "configalgospec" + + +class TestProjectAlgoSpecIntegration: + """AlgoSpec with @project and @pypi_base — runs in pypi env.""" + + def test_flow_completes(self, project_algo_spec_run): + assert project_algo_spec_run.successful + assert project_algo_spec_run.finished + + def test_computation_correct(self, project_algo_spec_run): + """value=7 (default): 7^2 = 49""" + assert project_algo_spec_run.end_task["result"].data == 49 + + def test_pypi_package_available(self, project_algo_spec_run): + """requests was installed via @pypi_base and used in call().""" + assert project_algo_spec_run.end_task["requests_version"].data == "2.31.0" + + def test_graph_endpoints(self, project_algo_spec_run): + graph_info = project_algo_spec_run["_parameters"].task["_graph_info"].data + assert graph_info["start_step"] == "projectalgospec" + assert graph_info["end_step"] == "projectalgospec" + + def test_single_step(self, project_algo_spec_run): + step_names = {s.id for s in project_algo_spec_run} + assert step_names == {"projectalgospec"} diff --git a/test/unit/io_types/__init__.py b/test/unit/io_types/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git 
a/test/unit/io_types/test_collections.py b/test/unit/io_types/test_collections.py new file mode 100644 index 00000000000..8ade0cb1819 --- /dev/null +++ b/test/unit/io_types/test_collections.py @@ -0,0 +1,103 @@ +from metaflow.io_types import List, Map, Text, Int64 + + +# --------------------------------------------------------------------------- +# List +# --------------------------------------------------------------------------- + + +def test_list_wire_round_trip(): + l = List([1, 2, 3]) + s = l.serialize(format="wire") + l2 = List.deserialize(s, format="wire") + assert l2.value == [1, 2, 3] + + +def test_list_storage_round_trip(): + l = List(["a", "b", "c"], element_type=Text) + blobs, meta = l.serialize(format="storage") + assert meta["element_type"] == "text" + l2 = List.deserialize([b.value for b in blobs], format="storage") + assert l2.value == ["a", "b", "c"] + + +def test_list_to_spec(): + l = List(element_type=Text) + spec = l.to_spec() + assert spec == {"type": "list", "element_type": {"type": "text"}} + + +def test_list_to_spec_no_element_type(): + l = List() + assert l.to_spec() == {"type": "list"} + + +def test_list_nested(): + l = List([[1, 2], [3, 4]]) + blobs, _ = l.serialize(format="storage") + l2 = List.deserialize([b.value for b in blobs], format="storage") + assert l2.value == [[1, 2], [3, 4]] + + +def test_list_empty(): + l = List([]) + blobs, _ = l.serialize(format="storage") + l2 = List.deserialize([b.value for b in blobs], format="storage") + assert l2.value == [] + + +def test_list_mixed_types(): + l = List([1, "two", None, True]) + s = l.serialize(format="wire") + l2 = List.deserialize(s, format="wire") + assert l2.value == [1, "two", None, True] + + +# --------------------------------------------------------------------------- +# Map +# --------------------------------------------------------------------------- + + +def test_map_wire_round_trip(): + m = Map({"a": 1, "b": 2}) + s = m.serialize(format="wire") + m2 = Map.deserialize(s, format="wire") + assert m2.value == {"a": 1, "b": 2} + + +def test_map_storage_round_trip(): + m = Map({"x": 10, "y": 20}, key_type=Text, value_type=Int64) + blobs, meta = m.serialize(format="storage") + assert meta["key_type"] == "text" + assert meta["value_type"] == "int64" + m2 = Map.deserialize([b.value for b in blobs], format="storage") + assert m2.value == {"x": 10, "y": 20} + + +def test_map_to_spec(): + m = Map(key_type=Text, value_type=Int64) + spec = m.to_spec() + assert spec == { + "type": "map", + "key_type": {"type": "text"}, + "value_type": {"type": "int64"}, + } + + +def test_map_to_spec_no_types(): + m = Map() + assert m.to_spec() == {"type": "map"} + + +def test_map_nested_values(): + m = Map({"a": {"nested": [1, 2]}, "b": {"nested": [3]}}) + blobs, _ = m.serialize(format="storage") + m2 = Map.deserialize([b.value for b in blobs], format="storage") + assert m2.value == {"a": {"nested": [1, 2]}, "b": {"nested": [3]}} + + +def test_map_empty(): + m = Map({}) + blobs, _ = m.serialize(format="storage") + m2 = Map.deserialize([b.value for b in blobs], format="storage") + assert m2.value == {} diff --git a/test/unit/io_types/test_enum_type.py b/test/unit/io_types/test_enum_type.py new file mode 100644 index 00000000000..b035c748cc5 --- /dev/null +++ b/test/unit/io_types/test_enum_type.py @@ -0,0 +1,45 @@ +import pytest + +from metaflow.io_types import Enum + + +def test_enum_wire_round_trip(): + e = Enum("red", allowed_values=["red", "green", "blue"]) + s = e.serialize(format="wire") + assert s == "red" + e2 = 
Enum.deserialize(s, format="wire") + assert e2.value == "red" + + +def test_enum_storage_round_trip(): + e = Enum("green", allowed_values=["red", "green", "blue"]) + blobs, meta = e.serialize(format="storage") + assert meta["allowed_values"] == ["red", "green", "blue"] + e2 = Enum.deserialize( + [b.value for b in blobs], format="storage", metadata=meta + ) + assert e2.value == "green" + assert e2._allowed_values == ["red", "green", "blue"] + + +def test_enum_validation(): + with pytest.raises(ValueError, match="not in allowed values"): + Enum("yellow", allowed_values=["red", "green", "blue"]) + + +def test_enum_no_allowed_values(): + e = Enum("anything") + assert e.value == "anything" + s = e.serialize(format="wire") + assert s == "anything" + + +def test_enum_to_spec_with_values(): + e = Enum(allowed_values=["a", "b", "c"]) + spec = e.to_spec() + assert spec == {"type": "enum", "allowed_values": ["a", "b", "c"]} + + +def test_enum_to_spec_without_values(): + e = Enum() + assert e.to_spec() == {"type": "enum"} diff --git a/test/unit/io_types/test_iotype_serializer.py b/test/unit/io_types/test_iotype_serializer.py new file mode 100644 index 00000000000..19d532b8e1e --- /dev/null +++ b/test/unit/io_types/test_iotype_serializer.py @@ -0,0 +1,243 @@ +""" +End-to-end integration tests for IOTypeSerializer bridging IOType to +the pluggable serializer framework via TaskDataStore. +""" + +import os +from dataclasses import dataclass + +import pytest + +from metaflow.datastore.artifacts.serializer import SerializerStore +from metaflow.io_types import Bool, Enum, Float64, Int32, Int64, Json, List, Map, Text +from metaflow.io_types.struct_type import Struct +from metaflow.plugins.datastores.serializers.iotype_serializer import IOTypeSerializer +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# IOTypeSerializer unit tests +# --------------------------------------------------------------------------- + + +def test_can_serialize_iotype(): + assert IOTypeSerializer.can_serialize(Text("hello")) is True + assert IOTypeSerializer.can_serialize(Int64(42)) is True + assert IOTypeSerializer.can_serialize(Json({"a": 1})) is True + + +def test_cannot_serialize_plain_python(): + assert IOTypeSerializer.can_serialize("hello") is False + assert IOTypeSerializer.can_serialize(42) is False + assert IOTypeSerializer.can_serialize({"a": 1}) is False + + +def test_can_deserialize_iotype_encoding(): + from metaflow.datastore.artifacts.serializer import SerializationMetadata + + meta = SerializationMetadata("text", 5, "iotype:text", {}) + assert IOTypeSerializer.can_deserialize(meta) is True + + meta = SerializationMetadata("json", 10, "iotype:json", {}) + assert IOTypeSerializer.can_deserialize(meta) is True + + +def test_cannot_deserialize_pickle_encoding(): + from metaflow.datastore.artifacts.serializer import SerializationMetadata + + meta = SerializationMetadata("dict", 100, "pickle-v4", {}) + assert IOTypeSerializer.can_deserialize(meta) is False + + +def test_serialize_returns_correct_metadata(): + blobs, meta = IOTypeSerializer.serialize(Text("hello")) + assert meta.encoding == "iotype:text" + assert meta.type == "text" + assert meta.serializer_info["iotype_class"] == "Text" + assert "iotype_module" in meta.serializer_info + + +def test_round_trip_text(): + original = Text("hello world") + blobs, meta = IOTypeSerializer.serialize(original) + raw_blobs = [b.value for b in blobs] + result = 
IOTypeSerializer.deserialize(raw_blobs, meta, context=None) + assert isinstance(result, Text) + assert result.value == "hello world" + + +def test_round_trip_int64(): + original = Int64(999) + blobs, meta = IOTypeSerializer.serialize(original) + raw_blobs = [b.value for b in blobs] + result = IOTypeSerializer.deserialize(raw_blobs, meta, context=None) + assert isinstance(result, Int64) + assert result.value == 999 + + +def test_round_trip_json(): + original = Json({"nested": [1, 2, {"deep": True}]}) + blobs, meta = IOTypeSerializer.serialize(original) + raw_blobs = [b.value for b in blobs] + result = IOTypeSerializer.deserialize(raw_blobs, meta, context=None) + assert isinstance(result, Json) + assert result.value == {"nested": [1, 2, {"deep": True}]} + + +def test_round_trip_bool(): + for val in [True, False]: + original = Bool(val) + blobs, meta = IOTypeSerializer.serialize(original) + result = IOTypeSerializer.deserialize([b.value for b in blobs], meta, None) + assert result.value is val + + +def test_round_trip_enum(): + original = Enum("red", allowed_values=["red", "green", "blue"]) + blobs, meta = IOTypeSerializer.serialize(original) + result = IOTypeSerializer.deserialize([b.value for b in blobs], meta, None) + assert isinstance(result, Enum) + assert result.value == "red" + + +def test_round_trip_list(): + original = List([1, 2, 3], element_type=Int64) + blobs, meta = IOTypeSerializer.serialize(original) + result = IOTypeSerializer.deserialize([b.value for b in blobs], meta, None) + assert isinstance(result, List) + assert result.value == [1, 2, 3] + + +def test_round_trip_map(): + original = Map({"a": 1, "b": 2}, key_type=Text, value_type=Int64) + blobs, meta = IOTypeSerializer.serialize(original) + result = IOTypeSerializer.deserialize([b.value for b in blobs], meta, None) + assert isinstance(result, Map) + assert result.value == {"a": 1, "b": 2} + + +@dataclass +class _TestData: + name: str + value: int + + +def test_round_trip_struct(): + original = Struct(_TestData(name="test", value=42)) + blobs, meta = IOTypeSerializer.serialize(original) + result = IOTypeSerializer.deserialize([b.value for b in blobs], meta, None) + assert isinstance(result, Struct) + assert result.value == _TestData(name="test", value=42) + + +# --------------------------------------------------------------------------- +# Priority ordering +# --------------------------------------------------------------------------- + + +def test_deserialize_rejects_non_iotype_class(): + """Metadata pointing to a non-IOType class should be rejected.""" + from metaflow.datastore.artifacts.serializer import SerializationMetadata + + meta = SerializationMetadata( + type="text", + size=5, + encoding="iotype:text", + serializer_info={ + "iotype_module": "subprocess", + "iotype_class": "Popen", + }, + ) + with pytest.raises(ValueError, match="not an IOType subclass"): + IOTypeSerializer.deserialize([b"hello"], meta, context=None) + + +def test_iotype_priority_before_pickle(): + assert IOTypeSerializer.PRIORITY < PickleSerializer.PRIORITY + + +def test_registered_in_store(): + assert "iotype" in SerializerStore._all_serializers + + +# --------------------------------------------------------------------------- +# TaskDataStore integration +# --------------------------------------------------------------------------- + + +@pytest.fixture +def task_datastore(tmp_path): + from metaflow.datastore.flow_datastore import FlowDataStore + from metaflow.plugins.datastores.local_storage import LocalStorage + + storage_root = str(tmp_path / 
"datastore") + os.makedirs(storage_root, exist_ok=True) + + flow_ds = FlowDataStore( + flow_name="TestFlow", + environment=None, + metadata=None, + event_logger=None, + monitor=None, + storage_impl=LocalStorage, + ds_root=storage_root, + ) + + task_ds = flow_ds.get_task_datastore( + run_id="1", + step_name="start", + task_id="1", + attempt=0, + mode="w", + ) + task_ds.init_task() + # Use only IOTypeSerializer + PickleSerializer (isolate from test pollution) + task_ds._serializers = [IOTypeSerializer, PickleSerializer] + return task_ds + + +def test_iotype_through_datastore(task_datastore): + """IOType artifacts go through IOTypeSerializer, plain objects through pickle.""" + artifacts = [ + ("typed_json", Json({"key": "value"})), + ("typed_int", Int64(42)), + ("plain_dict", {"key": "value"}), # not wrapped in IOType + ] + task_datastore.save_artifacts(iter(artifacts)) + + # IOType artifacts use iotype encoding + assert task_datastore._info["typed_json"]["encoding"] == "iotype:json" + assert task_datastore._info["typed_int"]["encoding"] == "iotype:int64" + # Plain dict falls through to pickle + assert task_datastore._info["plain_dict"]["encoding"] == "pickle-v4" + + # All round-trip correctly + loaded = dict(task_datastore.load_artifacts(["typed_json", "typed_int", "plain_dict"])) + assert isinstance(loaded["typed_json"], Json) + assert loaded["typed_json"].value == {"key": "value"} + assert isinstance(loaded["typed_int"], Int64) + assert loaded["typed_int"].value == 42 + assert loaded["plain_dict"] == {"key": "value"} # plain dict, not IOType + + +def test_mixed_iotypes_through_datastore(task_datastore): + """Multiple IOType varieties in a single save/load cycle.""" + artifacts = [ + ("t", Text("hello")), + ("b", Bool(True)), + ("i32", Int32(100)), + ("f64", Float64(3.14)), + ("j", Json([1, 2, 3])), + ("e", Enum("red", allowed_values=["red", "green"])), + ("l", List([1, 2], element_type=Int64)), + ] + task_datastore.save_artifacts(iter(artifacts)) + loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) + + assert loaded["t"].value == "hello" + assert loaded["b"].value is True + assert loaded["i32"].value == 100 + assert loaded["f64"].value == 3.14 + assert loaded["j"].value == [1, 2, 3] + assert loaded["e"].value == "red" + assert loaded["l"].value == [1, 2] diff --git a/test/unit/io_types/test_json_type.py b/test/unit/io_types/test_json_type.py new file mode 100644 index 00000000000..5dd6cbe7942 --- /dev/null +++ b/test/unit/io_types/test_json_type.py @@ -0,0 +1,49 @@ +from metaflow.io_types import Json + + +def test_json_wire_round_trip_dict(): + j = Json({"key": "value", "nested": [1, 2, 3]}) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == {"key": "value", "nested": [1, 2, 3]} + + +def test_json_wire_round_trip_list(): + j = Json([1, "two", None, True]) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == [1, "two", None, True] + + +def test_json_storage_round_trip(): + j = Json({"a": {"b": [1, 2]}, "c": None}) + blobs, meta = j.serialize(format="storage") + assert len(blobs) == 1 + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == j.value + + +def test_json_to_spec(): + assert Json().to_spec() == {"type": "json"} + + +def test_json_empty_dict(): + j = Json({}) + blobs, _ = j.serialize(format="storage") + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == {} + + +def test_json_empty_list(): + j = 
Json([]) + blobs, _ = j.serialize(format="storage") + j2 = Json.deserialize([b.value for b in blobs], format="storage") + assert j2.value == [] + + +def test_json_deeply_nested(): + data = {"a": {"b": {"c": {"d": [1, 2, {"e": True}]}}}} + j = Json(data) + s = j.serialize(format="wire") + j2 = Json.deserialize(s, format="wire") + assert j2.value == data diff --git a/test/unit/io_types/test_scalars.py b/test/unit/io_types/test_scalars.py new file mode 100644 index 00000000000..5321d08e9f1 --- /dev/null +++ b/test/unit/io_types/test_scalars.py @@ -0,0 +1,234 @@ +import struct + +import pytest + +from metaflow.io_types import Bool, Float32, Float64, Int32, Int64, Text + + +# --------------------------------------------------------------------------- +# Text +# --------------------------------------------------------------------------- + + +def test_text_wire_round_trip(): + t = Text("hello world") + s = t.serialize(format="wire") + assert s == "hello world" + t2 = Text.deserialize(s, format="wire") + assert t2.value == "hello world" + + +def test_text_storage_round_trip(): + t = Text("utf-8 test: cafe\u0301") + blobs, meta = t.serialize(format="storage") + assert len(blobs) == 1 + t2 = Text.deserialize([b.value for b in blobs], format="storage") + assert t2.value == t.value + + +def test_text_to_spec(): + assert Text().to_spec() == {"type": "text"} + + +def test_text_empty_string(): + t = Text("") + assert t.serialize(format="wire") == "" + blobs, _ = t.serialize(format="storage") + t2 = Text.deserialize([b.value for b in blobs], format="storage") + assert t2.value == "" + + +# --------------------------------------------------------------------------- +# Bool +# --------------------------------------------------------------------------- + + +def test_bool_wire_round_trip(): + assert Bool(True).serialize(format="wire") == "true" + assert Bool(False).serialize(format="wire") == "false" + assert Bool.deserialize("true", format="wire").value is True + assert Bool.deserialize("false", format="wire").value is False + + +def test_bool_wire_case_insensitive(): + assert Bool.deserialize("TRUE", format="wire").value is True + assert Bool.deserialize("False", format="wire").value is False + + +def test_bool_wire_invalid(): + with pytest.raises(ValueError, match="true.*false"): + Bool.deserialize("yes", format="wire") + + +def test_bool_storage_round_trip(): + for val in [True, False]: + b = Bool(val) + blobs, _ = b.serialize(format="storage") + assert len(blobs[0].value) == 1 + b2 = Bool.deserialize([blobs[0].value], format="storage") + assert b2.value is val + + +def test_bool_storage_rejects_invalid_bytes(): + with pytest.raises(ValueError, match="0x00 or 0x01"): + Bool.deserialize([b"\x02"], format="storage") + with pytest.raises(ValueError, match="0x00 or 0x01"): + Bool.deserialize([b"\x00\x01"], format="storage") + + +def test_bool_to_spec(): + assert Bool().to_spec() == {"type": "bool"} + + +# --------------------------------------------------------------------------- +# Int32 +# --------------------------------------------------------------------------- + + +def test_int32_wire_round_trip(): + i = Int32(42) + s = i.serialize(format="wire") + assert s == "42" + i2 = Int32.deserialize(s, format="wire") + assert i2.value == 42 + + +def test_int32_storage_round_trip(): + i = Int32(-1000) + blobs, _ = i.serialize(format="storage") + assert len(blobs[0].value) == 4 + i2 = Int32.deserialize([blobs[0].value], format="storage") + assert i2.value == -1000 + + +def test_int32_storage_little_endian(): + i = 
Int32(1) + blobs, _ = i.serialize(format="storage") + assert blobs[0].value == b"\x01\x00\x00\x00" # little-endian + + +def test_int32_range_check(): + Int32(2**31 - 1) # max + Int32(-(2**31)) # min + with pytest.raises(ValueError, match="out of range"): + Int32(2**31) + with pytest.raises(ValueError, match="out of range"): + Int32(-(2**31) - 1) + + +def test_int32_zero(): + i = Int32(0) + blobs, _ = i.serialize(format="storage") + i2 = Int32.deserialize([blobs[0].value], format="storage") + assert i2.value == 0 + + +def test_int32_to_spec(): + assert Int32().to_spec() == {"type": "int32"} + + +# --------------------------------------------------------------------------- +# Int64 +# --------------------------------------------------------------------------- + + +def test_int64_wire_round_trip(): + i = Int64(2**40) + s = i.serialize(format="wire") + i2 = Int64.deserialize(s, format="wire") + assert i2.value == 2**40 + + +def test_int64_storage_round_trip(): + for val in [0, -1, 2**60, -(2**60)]: + i = Int64(val) + blobs, _ = i.serialize(format="storage") + assert len(blobs[0].value) == 8 + i2 = Int64.deserialize([blobs[0].value], format="storage") + assert i2.value == val + + +def test_int64_storage_little_endian(): + i = Int64(1) + blobs, _ = i.serialize(format="storage") + assert blobs[0].value == b"\x01\x00\x00\x00\x00\x00\x00\x00" + + +def test_int64_to_spec(): + assert Int64().to_spec() == {"type": "int64"} + + +# --------------------------------------------------------------------------- +# Float32 +# --------------------------------------------------------------------------- + + +def test_float32_wire_round_trip(): + f = Float32(3.14) + s = f.serialize(format="wire") + f2 = Float32.deserialize(s, format="wire") + assert abs(f2.value - 3.14) < 0.01 # float32 precision + + +def test_float32_storage_round_trip(): + f = Float32(1.5) # exactly representable in float32 + blobs, _ = f.serialize(format="storage") + assert len(blobs[0].value) == 4 + f2 = Float32.deserialize([blobs[0].value], format="storage") + assert f2.value == 1.5 + + +def test_float32_to_spec(): + assert Float32().to_spec() == {"type": "float32"} + + +# --------------------------------------------------------------------------- +# Float64 +# --------------------------------------------------------------------------- + + +def test_float64_wire_round_trip(): + f = Float64(3.141592653589793) + s = f.serialize(format="wire") + f2 = Float64.deserialize(s, format="wire") + assert f2.value == 3.141592653589793 + + +def test_float64_storage_round_trip(): + f = Float64(2.718281828459045) + blobs, _ = f.serialize(format="storage") + assert len(blobs[0].value) == 8 + f2 = Float64.deserialize([blobs[0].value], format="storage") + assert f2.value == 2.718281828459045 + + +def test_float64_to_spec(): + assert Float64().to_spec() == {"type": "float64"} + + +# --------------------------------------------------------------------------- +# IOType base behavior +# --------------------------------------------------------------------------- + + +def test_repr_with_value(): + assert repr(Int64(42)) == "Int64(42)" + assert repr(Text("hi")) == "Text('hi')" + + +def test_repr_without_value(): + assert repr(Int64()) == "Int64()" + + +def test_equality(): + assert Int64(42) == Int64(42) + assert Int64(42) != Int64(43) + assert Text("a") == Text("a") + assert Text("a") != Int64(0) # different types + + +def test_invalid_format(): + with pytest.raises(ValueError, match="format must be"): + Int64(42).serialize(format="invalid") + with 
pytest.raises(ValueError, match="format must be"): + Int64.deserialize("42", format="invalid") diff --git a/test/unit/io_types/test_struct_type.py b/test/unit/io_types/test_struct_type.py new file mode 100644 index 00000000000..387076f3437 --- /dev/null +++ b/test/unit/io_types/test_struct_type.py @@ -0,0 +1,102 @@ +import dataclasses +from dataclasses import dataclass + +import pytest + +from metaflow.io_types import Struct + + +@dataclass +class SimpleData: + name: str + count: int + score: float + active: bool + + +@dataclass +class NestedData: + label: str + sub: dict # not auto-inferred, but works with JSON serde + + +def test_struct_wire_round_trip(): + s = Struct(SimpleData(name="test", count=5, score=3.14, active=True)) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + # Wire deserializes to dict (no dataclass type info in wire format) + assert s2.value == {"name": "test", "count": 5, "score": 3.14, "active": True} + + +def test_struct_storage_round_trip(): + original = SimpleData(name="test", count=5, score=3.14, active=True) + s = Struct(original) + blobs, meta = s.serialize(format="storage") + assert "dataclass_module" in meta + assert "dataclass_class" in meta + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert s2.value == original + assert type(s2.value) is SimpleData + + +def test_struct_storage_without_dataclass_type(): + """When metadata lacks dataclass info, falls back to dict.""" + s = Struct(SimpleData(name="x", count=1, score=0.0, active=False)) + blobs, _ = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata={}) + assert isinstance(s2.value, dict) + assert s2.value["name"] == "x" + + +def test_struct_to_spec(): + s = Struct(dataclass_type=SimpleData) + spec = s.to_spec() + assert spec["type"] == "struct" + assert len(spec["fields"]) == 4 + field_names = [f["name"] for f in spec["fields"]] + assert field_names == ["name", "count", "score", "active"] + # Check implicit mapping + name_field = next(f for f in spec["fields"] if f["name"] == "name") + assert name_field["type"] == "text" + count_field = next(f for f in spec["fields"] if f["name"] == "count") + assert count_field["type"] == "int64" + + +def test_struct_to_spec_no_dataclass(): + s = Struct() + assert s.to_spec() == {"type": "struct"} + + +def test_struct_nested_data(): + nd = NestedData(label="test", sub={"key": [1, 2, 3]}) + s = Struct(nd) + blobs, meta = s.serialize(format="storage") + s2 = Struct.deserialize([b.value for b in blobs], format="storage", metadata=meta) + assert s2.value == nd + + +def test_struct_wire_deserialize_then_reserialize(): + """Wire round-trip: deserialize returns dict, re-serialize should work on dict.""" + s = Struct(SimpleData(name="test", count=5, score=3.14, active=True)) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + # s2 wraps a dict — re-serializing should work + wire2 = s2.serialize(format="wire") + s3 = Struct.deserialize(wire2, format="wire") + assert s3.value == s2.value + + +def test_struct_security_rejects_non_dataclass(): + """Metadata pointing to a non-dataclass class should be rejected.""" + blobs = [b'{"cmd": "echo pwned"}'] + meta = {"dataclass_module": "subprocess", "dataclass_class": "Popen"} + with pytest.raises(ValueError, match="not a dataclass"): + Struct.deserialize(blobs, format="storage", metadata=meta) + + +def test_struct_plain_dict_value(): + """Struct wrapping a plain dict works for 
serde.""" + s = Struct({"x": 1, "y": "hello"}) + wire = s.serialize(format="wire") + s2 = Struct.deserialize(wire, format="wire") + assert s2.value == {"x": 1, "y": "hello"} diff --git a/test/unit/io_types/test_tensor_type.py b/test/unit/io_types/test_tensor_type.py new file mode 100644 index 00000000000..4f32330e788 --- /dev/null +++ b/test/unit/io_types/test_tensor_type.py @@ -0,0 +1,101 @@ +import pytest + +np = pytest.importorskip("numpy") + +from metaflow.io_types import Tensor + + +def test_tensor_wire_round_trip(): + arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + t = Tensor(arr) + s = t.serialize(format="wire") + assert "|" in s # header|base64 + t2 = Tensor.deserialize(s, format="wire") + np.testing.assert_array_equal(t2.value, arr) + + +def test_tensor_storage_round_trip(): + arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + t = Tensor(arr) + blobs, meta = t.serialize(format="storage") + assert meta["dtype"] in (" 1: + order = SerializerStore._registration_order + indices = [order.index(s.TYPE) for s in priority_100] + assert indices == sorted(indices) + + +def test_deterministic_ordering(): + """Calling get_ordered_serializers twice returns the same order.""" + first = SerializerStore.get_ordered_serializers() + second = SerializerStore.get_ordered_serializers() + assert [s.TYPE for s in first] == [s.TYPE for s in second] + + +def test_high_priority_before_low(): + """_HighPrioritySerializer (PRIORITY=10) comes before _LowPrioritySerializer (PRIORITY=200).""" + ordered = SerializerStore.get_ordered_serializers() + types = [s.TYPE for s in ordered] + assert types.index("test_high") < types.index("test_low") + + +# --------------------------------------------------------------------------- +# SerializationMetadata tests +# --------------------------------------------------------------------------- + + +def test_metadata_fields(): + meta = SerializationMetadata( + type="dict", + size=1024, + encoding="pickle-v4", + serializer_info={"key": "value"}, + ) + assert meta.type == "dict" + assert meta.size == 1024 + assert meta.encoding == "pickle-v4" + assert meta.serializer_info == {"key": "value"} + + +def test_metadata_is_namedtuple(): + meta = SerializationMetadata("str", 10, "utf-8", {}) + assert isinstance(meta, tuple) + assert len(meta) == 4 + + +# --------------------------------------------------------------------------- +# SerializedBlob tests +# --------------------------------------------------------------------------- + + +def test_blob_bytes_auto_detect(): + """bytes value auto-detects as not a reference.""" + blob = SerializedBlob(b"hello") + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_str_auto_detect(): + """str value auto-detects as a reference.""" + blob = SerializedBlob("sha1_key_abc123") + assert blob.is_reference is True + assert blob.needs_save is False + + +def test_blob_explicit_is_reference_override(): + """Explicit is_reference overrides auto-detection.""" + # bytes but marked as reference (edge case) + blob = SerializedBlob(b"data", is_reference=True) + assert blob.is_reference is True + assert blob.needs_save is False + + # str but marked as not a reference (edge case) + blob = SerializedBlob("inline_data", is_reference=False) + assert blob.is_reference is False + assert blob.needs_save is True + + +def test_blob_compress_method_default(): + blob = SerializedBlob(b"data") + assert blob.compress_method == "gzip" + + +def test_blob_compress_method_custom(): + blob = SerializedBlob(b"data", 
compress_method="raw") + assert blob.compress_method == "raw" + + +def test_blob_value_preserved(): + data = b"\x00\x01\x02\x03" + blob = SerializedBlob(data) + assert blob.value is data + + key = "abc123def456" + blob = SerializedBlob(key) + assert blob.value is key + + +def test_blob_rejects_invalid_types(): + """SerializedBlob must be str or bytes — reject everything else.""" + for bad_value in [123, 3.14, None, [], {}]: + with pytest.raises(TypeError, match="must be str or bytes"): + SerializedBlob(bad_value) diff --git a/test/unit/test_pickle_serializer.py b/test/unit/test_pickle_serializer.py new file mode 100644 index 00000000000..298708b0953 --- /dev/null +++ b/test/unit/test_pickle_serializer.py @@ -0,0 +1,191 @@ +import pickle + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + SerializationMetadata, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Registration and identity +# --------------------------------------------------------------------------- + + +def test_type_is_pickle(): + assert PickleSerializer.TYPE == "pickle" + + +def test_priority_is_fallback(): + assert PickleSerializer.PRIORITY == 9999 + + +def test_registered_in_store(): + assert "pickle" in SerializerStore._all_serializers + assert SerializerStore._all_serializers["pickle"] is PickleSerializer + + +def test_last_in_ordering(): + """PickleSerializer should be last (highest PRIORITY) among registered serializers.""" + ordered = SerializerStore.get_ordered_serializers() + assert ordered[-1] is PickleSerializer + + +# --------------------------------------------------------------------------- +# can_serialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello", + 3.14, + None, + True, + [1, 2, 3], + {"key": "value"}, + (1, "a"), + set([1, 2]), + b"bytes", + object(), + ], + ids=[ + "int", + "str", + "float", + "None", + "bool", + "list", + "dict", + "tuple", + "set", + "bytes", + "object", + ], +) +def test_can_serialize_any_object(obj): + assert PickleSerializer.can_serialize(obj) is True + + +# --------------------------------------------------------------------------- +# can_deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "encoding", + ["pickle-v2", "pickle-v4", "gzip+pickle-v2", "gzip+pickle-v4"], +) +def test_can_deserialize_valid_encodings(encoding): + meta = SerializationMetadata("object", 100, encoding, {}) + assert PickleSerializer.can_deserialize(meta) is True + + +@pytest.mark.parametrize( + "encoding", + ["json", "iotype:text", "msgpack", "unknown", ""], +) +def test_cannot_deserialize_unknown_encodings(encoding): + meta = SerializationMetadata("object", 100, encoding, {}) + assert PickleSerializer.can_deserialize(meta) is False + + +# --------------------------------------------------------------------------- +# serialize +# --------------------------------------------------------------------------- + + +def test_serialize_returns_single_blob(): + blobs, meta = PickleSerializer.serialize({"key": "value"}) + assert len(blobs) == 1 + assert blobs[0].needs_save is True + assert blobs[0].is_reference is False + + +def test_serialize_metadata_encoding(): + _, meta = PickleSerializer.serialize(42) + assert meta.encoding == "pickle-v4" + + +def test_serialize_metadata_type(): + 
_, meta = PickleSerializer.serialize([1, 2, 3]) + assert "list" in meta.type + + +def test_serialize_metadata_size(): + obj = {"a": 1, "b": 2} + blobs, meta = PickleSerializer.serialize(obj) + assert meta.size == len(blobs[0].value) + assert meta.size > 0 + + +def test_serialize_metadata_serializer_info_empty(): + _, meta = PickleSerializer.serialize("hello") + assert meta.serializer_info == {} + + +def test_serialize_compress_method(): + blobs, _ = PickleSerializer.serialize(42) + assert blobs[0].compress_method == "gzip" + + +# --------------------------------------------------------------------------- +# Round-trip: serialize -> deserialize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "obj", + [ + 42, + "hello world", + 3.14, + None, + True, + False, + [1, "two", 3.0], + {"nested": {"key": [1, 2, 3]}}, + (1, 2, 3), + set([1, 2, 3]), + b"raw bytes", + ], + ids=[ + "int", + "str", + "float", + "None", + "True", + "False", + "list", + "nested_dict", + "tuple", + "set", + "bytes", + ], +) +def test_round_trip(obj): + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta, context=None) + assert result == obj + + +class _CustomObj: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return isinstance(other, _CustomObj) and self.x == other.x + + +def test_round_trip_custom_class(): + obj = _CustomObj(42) + blobs, meta = PickleSerializer.serialize(obj) + raw_blobs = [b.value for b in blobs] + result = PickleSerializer.deserialize(raw_blobs, meta, context=None) + assert result == obj + assert result.x == 42 diff --git a/test/unit/test_serializer_integration.py b/test/unit/test_serializer_integration.py new file mode 100644 index 00000000000..5da366a11a7 --- /dev/null +++ b/test/unit/test_serializer_integration.py @@ -0,0 +1,212 @@ +""" +Integration tests for the pluggable serializer framework wired into TaskDataStore. 
+ +Tests that: +- PickleSerializer handles standard Python objects through save/load_artifacts +- Custom serializers take priority over PickleSerializer +- Backward compat: old artifacts (without serializer_info) still load +- Metadata includes serializer_info when present +""" + +import json +import os +import shutil +import tempfile + +import pytest + +from metaflow.datastore.artifacts.serializer import ( + ArtifactSerializer, + SerializationMetadata, + SerializedBlob, + SerializerStore, +) +from metaflow.plugins.datastores.serializers.pickle_serializer import PickleSerializer + + +# --------------------------------------------------------------------------- +# Test PickleSerializer round-trip through save/load artifacts +# --------------------------------------------------------------------------- + + +@pytest.fixture +def task_datastore(tmp_path): + """Create a minimal TaskDataStore wired to a local storage backend.""" + from metaflow.datastore.flow_datastore import FlowDataStore + from metaflow.plugins.datastores.local_storage import LocalStorage + + storage_root = str(tmp_path / "datastore") + os.makedirs(storage_root, exist_ok=True) + + flow_ds = FlowDataStore( + flow_name="TestFlow", + environment=None, + metadata=None, + event_logger=None, + monitor=None, + storage_impl=LocalStorage, + ds_root=storage_root, + ) + + task_ds = flow_ds.get_task_datastore( + run_id="1", + step_name="start", + task_id="1", + attempt=0, + mode="w", + ) + task_ds.init_task() + # Isolate from test serializers registered by other test files. + # Only use PickleSerializer (as the plugin system would provide). + task_ds._serializers = [PickleSerializer] + return task_ds + + +def test_save_load_pickle_round_trip(task_datastore): + """Standard Python objects go through PickleSerializer and round-trip.""" + artifacts = [ + ("my_dict", {"key": "value", "nested": [1, 2, 3]}), + ("my_int", 42), + ("my_str", "hello world"), + ("my_none", None), + ] + task_datastore.save_artifacts(iter(artifacts)) + + # Verify metadata + for name, _ in artifacts: + info = task_datastore._info[name] + assert "encoding" in info + assert info["encoding"] == "pickle-v4" + assert info["size"] > 0 + assert "type" in info + + # Load and verify + loaded = dict(task_datastore.load_artifacts([name for name, _ in artifacts])) + assert loaded["my_dict"] == {"key": "value", "nested": [1, 2, 3]} + assert loaded["my_int"] == 42 + assert loaded["my_str"] == "hello world" + assert loaded["my_none"] is None + + +def test_distinct_objects_on_load(task_datastore): + """Loading the same artifact twice yields distinct object instances.""" + shared_list = [1, 2, 3] + task_datastore.save_artifacts(iter([("a", shared_list), ("b", shared_list)])) + + loaded = dict(task_datastore.load_artifacts(["a", "b"])) + assert loaded["a"] == loaded["b"] + assert loaded["a"] is not loaded["b"] # distinct instances + + +def test_metadata_has_no_serializer_info_for_pickle(task_datastore): + """PickleSerializer returns empty serializer_info, so _info should not contain it.""" + task_datastore.save_artifacts(iter([("x", 42)])) + info = task_datastore._info["x"] + # Empty serializer_info should NOT be stored (saves space in metadata) + assert "serializer_info" not in info + + +# --------------------------------------------------------------------------- +# Test custom serializer takes priority +# --------------------------------------------------------------------------- + + +def test_custom_serializer_takes_priority(task_datastore): + """A custom serializer with lower 
PRIORITY claims matching objects over pickle.""" + + # Define and register a custom serializer inside the test + class _JsonStringSerializer(ArtifactSerializer): + TYPE = "test_json_str" + PRIORITY = 50 + + @classmethod + def can_serialize(cls, obj): + return isinstance(obj, str) + + @classmethod + def can_deserialize(cls, metadata): + return metadata.encoding == "test_json_str" + + @classmethod + def serialize(cls, obj): + blob = json.dumps(obj).encode("utf-8") + return ( + [SerializedBlob(blob, is_reference=False)], + SerializationMetadata( + type="str", + size=len(blob), + encoding="test_json_str", + serializer_info={"format": "json-utf8"}, + ), + ) + + @classmethod + def deserialize(cls, blobs, metadata, context): + return json.loads(blobs[0].decode("utf-8")) + + # Explicitly set serializers: custom first, then pickle fallback. + # Don't use get_ordered_serializers() to avoid pollution from other test files. + task_datastore._serializers = [_JsonStringSerializer, PickleSerializer] + + try: + task_datastore.save_artifacts(iter([("msg", "hello"), ("num", 42)])) + + # "msg" should use our custom serializer (str → _JsonStringSerializer) + msg_info = task_datastore._info["msg"] + assert msg_info["encoding"] == "test_json_str" + assert msg_info["serializer_info"] == {"format": "json-utf8"} + + # "num" should fall through to PickleSerializer (int → not claimed by custom) + num_info = task_datastore._info["num"] + assert num_info["encoding"] == "pickle-v4" + + # Both round-trip correctly + loaded = dict(task_datastore.load_artifacts(["msg", "num"])) + assert loaded["msg"] == "hello" + assert loaded["num"] == 42 + finally: + # Clean up: remove from registry to avoid polluting other tests + SerializerStore._all_serializers.pop("test_json_str", None) + SerializerStore._registration_order = [ + t for t in SerializerStore._registration_order if t != "test_json_str" + ] + + +# --------------------------------------------------------------------------- +# Backward compat: old metadata format +# --------------------------------------------------------------------------- + + +def test_backward_compat_old_metadata(task_datastore): + """Artifacts saved with old metadata format (no serializer_info) still load.""" + # Save normally first + task_datastore.save_artifacts(iter([("old_artifact", {"a": 1})])) + + # Simulate old metadata format: no serializer_info, old encoding + task_datastore._info["old_artifact"] = { + "size": 100, + "type": "", + "encoding": "gzip+pickle-v4", + # no "serializer_info" key + } + + # Should still load via PickleSerializer (can_deserialize handles gzip+pickle-v4) + loaded = dict(task_datastore.load_artifacts(["old_artifact"])) + assert loaded["old_artifact"] == {"a": 1} + + +def test_backward_compat_no_encoding(task_datastore): + """Very old artifacts without encoding field default to gzip+pickle-v2.""" + # Save an artifact + task_datastore.save_artifacts(iter([("ancient", 99)])) + + # Simulate very old metadata: no encoding, no serializer_info + task_datastore._info["ancient"] = { + "size": 10, + "type": "", + # no "encoding" key — defaults to gzip+pickle-v2 + } + + # Should still load + loaded = dict(task_datastore.load_artifacts(["ancient"])) + assert loaded["ancient"] == 99
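# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumptions flagged below): how the pieces tested
# above compose end to end. "CubeModel" and "cube_model.py" are hypothetical
# names invented for this sketch; AlgoSpec, Parameter, Runner, _graph_info and
# Run.end_task are the APIs exercised by the fixtures and tests in this patch,
# and the sketch assumes a local metadata service and datastore.

# cube_model.py
from metaflow import Parameter
from metaflow.algospec import AlgoSpec


class CubeModel(AlgoSpec):
    """Single-step spec; the step is named 'cubemodel' after the class."""

    base = Parameter("base", type=float, default=2.0)

    def call(self):
        # The sole step is both start and end of the inferred graph.
        self.result = self.base ** 3


if __name__ == "__main__":
    CubeModel()

# Driving the flow and reading results back, the same way the session
# fixtures in conftest.py and the integration tests do:
#
#     from metaflow import Runner
#
#     with Runner("cube_model.py").run(base=3.0) as running:
#         run = running.run
#
#     graph_info = run["_parameters"].task["_graph_info"].data
#     assert graph_info["end_step"] == "cubemodel"
#
#     # end_task resolves the terminal step from the graph metadata, so it
#     # works even though no step is literally named "end".
#     print(run.end_task["result"].data)  # 27.0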