1 change: 1 addition & 0 deletions .gitignore
@@ -39,3 +39,4 @@ stubs/version.py
# claude code
.claude/

workflow.yaml
103 changes: 103 additions & 0 deletions metaflow/algospec.py
@@ -0,0 +1,103 @@
"""
AlgoSpec -- a single-computation unit within Metaflow.

AlgoSpec is a FlowSpec subclass with a single step: the call method,
marked as @step(start=True, end=True). FlowGraph._identify_start_end
picks it up via the is_start/is_end attributes — no special-casing
needed in graph traversal, lint, runtime, or Maestro.

Lifecycle:
init() -- called once per worker (model loading)
call() -- called per row or batch (computation)
"""

import atexit

from .flowspec import FlowSpec, FlowSpecMeta


class AlgoSpecMeta(FlowSpecMeta):
"""Metaclass for AlgoSpec.

Marks call() as @step(start=True, end=True) before FlowSpecMeta
builds the graph. The step name is the class name lowercased.
"""

_registry = []
_atexit_registered = False

def __init__(cls, name, bases, attrs):
if name == "AlgoSpec":
super().__init__(name, bases, attrs)
return

from .decorators import step

call_fn = attrs.get("call")
if call_fn is not None and callable(call_fn):
attrs["call"] = step(call_fn, start=True, end=True)
attrs["call"].name = name.lower()
attrs["call"].__name__ = name.lower()
cls.call = attrs["call"]
cls._algo_step_name = name.lower()

super().__init__(name, bases, attrs)

if not hasattr(cls.call, "is_step"):
from .exception import MetaflowException

raise MetaflowException(
"%s must implement call(). "
"AlgoSpec subclasses require a call() method." % name
)

AlgoSpecMeta._registry.append(cls)

if not AlgoSpecMeta._atexit_registered:
atexit.register(AlgoSpecMeta._on_exit)
AlgoSpecMeta._atexit_registered = True

def _init_graph(cls):
from .graph import FlowGraph

cls._graph = FlowGraph(cls)
# The method is cls.call but node.name is the class-derived name.
if cls._graph.is_algo_spec:
cls._steps = [cls.call]
else:
cls._steps = [getattr(cls, node.name) for node in cls._graph]

@staticmethod
def _on_exit():
AlgoSpecMeta._registry.clear()


class AlgoSpec(FlowSpec, metaclass=AlgoSpecMeta):
"""Base class for single-computation algo specifications."""

is_algo_spec = True

_EPHEMERAL = FlowSpec._EPHEMERAL | {"is_algo_spec"}

_NON_PARAMETERS = FlowSpec._NON_PARAMETERS | {
"init",
"call",
"is_algo_spec",
}

def init(self):
"""Called once per worker before any call() invocations."""
pass

def call(self):
"""Main computation. Must be overridden."""
raise NotImplementedError("Subclasses must implement call()")

def __call__(self, *args, **kwargs):
return self.call(*args, **kwargs)

def __getattr__(self, name):
# Resolve the class-derived step name to the call method
if name == getattr(self.__class__, "_algo_step_name", None):
return self.call
return super().__getattr__(name)
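
To ground the lifecycle described in the module docstring, here is a minimal sketch of a subclass. It is illustrative only: the class name, the load_model helper, and the attribute names are invented, and it assumes AlgoSpec is re-exported from the top-level metaflow package (the diff only adds metaflow/algospec.py).

from metaflow import AlgoSpec  # assumed top-level export

class SentimentScorer(AlgoSpec):
    def init(self):
        # Runs once per worker: load the model a single time.
        self.model = load_model("sentiment-v1")  # hypothetical loader

    def call(self):
        # Runs per row or batch. AlgoSpecMeta wraps this as the flow's
        # single step, named "sentimentscorer" (class name lowercased),
        # marked start=True and end=True.
        self.score = self.model.predict(self.text)

Since __call__ forwards to call(), an instance also behaves as a plain callable once init() has been invoked on the worker.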
39 changes: 31 additions & 8 deletions metaflow/client/core.py
@@ -2121,8 +2121,9 @@ def parent_steps(self) -> Iterator["Step"]:
Parent step
"""
graph_info = self.task["_graph_info"].data
start_step = graph_info.get("start_step", "start")

if self.id != "start":
if self.id != start_step:
flow, run, _ = self.path_components
for node_name, attributes in graph_info["steps"].items():
if self.id in attributes["next"]:
@@ -2139,8 +2140,9 @@ def child_steps(self) -> Iterator["Step"]:
Child step
"""
graph_info = self.task["_graph_info"].data
end_step = graph_info.get("end_step", "end")

if self.id != "end":
if self.id != end_step:
flow, run, _ = self.path_components
for next_step in graph_info["steps"][self.id]["next"]:
yield Step(f"{flow}/{run}/{next_step}", _namespace_check=False)
@@ -2176,6 +2178,25 @@ def _iter_filter(self, x):
# exclude _parameters step
return x.id[0] != "_"

_cached_endpoints = None

@property
def _graph_endpoints(self):
"""
Returns (start_step_name, end_step_name) from _parameters metadata.
Falls back to ("start", "end") for runs that predate this change.
"""
if self._cached_endpoints is None:
start, end = "start", "end"
try:
params_meta = self["_parameters"].task.metadata_dict
start = params_meta.get("start_step", "start")
end = params_meta.get("end_step", "end")
except Exception:
pass
self._cached_endpoints = (start, end)
Contributor comment on lines +2186 to +2197:

P2: Bare except Exception silently swallows real failures

The except Exception: pass block treats every failure — including AttributeError, missing _parameters task, network errors, and data corruption — as "just an old run" and falls back to ("start", "end"). This means a corrupted metadata store or a genuine programming mistake silently produces incorrect graph endpoints rather than a visible error.

Consider at minimum logging a warning on the exception, or narrowing the except to KeyError for the missing-step case:

except KeyError:
    pass  # Pre-AlgoSpec run: no start_step/end_step metadata
return self._cached_endpoints

def steps(self, *tags: str) -> Iterator[Step]:
"""
[Legacy function - do not use]
@@ -2298,17 +2319,18 @@ def finished_at(self) -> Optional[datetime]:
@property
def end_task(self) -> Optional[Task]:
"""
Returns the Task corresponding to the 'end' step.
Returns the Task corresponding to the terminal step.

This returns None if the end step does not yet exist.
This returns None if the terminal step does not yet exist.

Returns
-------
Task, optional
The 'end' task
The terminal step's task
"""
try:
end_step = self["end"]
_, end_step_name = self._graph_endpoints
end_step = self[end_step_name]
except KeyError:
return None

@@ -2481,8 +2503,9 @@ def trigger(self) -> Optional[Trigger]:
Trigger, optional
Container of triggering events
"""
if "start" in self and self["start"].task:
meta = self["start"].task.metadata_dict.get("execution-triggers")
start_step, _ = self._graph_endpoints
if start_step in self and self[start_step].task:
meta = self[start_step].task.metadata_dict.get("execution-triggers")
if meta:
return Trigger(json.loads(meta))
return None
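
For illustration, a sketch of what the endpoint-aware accessors mean from the client side; the flow name, run id, and resulting step name below are made up for the example, and the ("start", "end") fallback covers runs that predate this change.

from metaflow import Run

run = Run("SentimentScorer/42")  # hypothetical AlgoSpec run
end_task = run.end_task          # resolves the terminal step recorded in _parameters metadata
if end_task is not None:
    print(end_task.parent.id)    # -> "sentimentscorer", not "end"
print(run.trigger)               # looks up execution-triggers on the recorded start step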
6 changes: 6 additions & 0 deletions metaflow/datastore/artifacts/__init__.py
@@ -0,0 +1,6 @@
from .serializer import (
ArtifactSerializer,
SerializationMetadata,
SerializedBlob,
SerializerStore,
)
173 changes: 173 additions & 0 deletions metaflow/datastore/artifacts/serializer.py
@@ -0,0 +1,173 @@
from abc import ABCMeta, abstractmethod
from collections import namedtuple


SerializationMetadata = namedtuple(
"SerializationMetadata", ["type", "size", "encoding", "serializer_info"]
)


class SerializedBlob(object):
"""
Represents a single blob produced by a serializer.

A serializer may produce multiple blobs per artifact. Each blob is either:
- New bytes to be stored (is_reference=False, value is bytes)
- A reference to already-stored data (is_reference=True, value is a string key)

Parameters
----------
value : Union[str, bytes]
The blob data (bytes) or a reference key (str).
is_reference : bool, optional
If None, auto-detected from value type: str -> reference, bytes -> new data.
compress_method : str
Compression method for new blobs. Ignored for references. Default "gzip".
NOTE: Not yet wired into the save path — ContentAddressedStore currently
always applies gzip. This field is forward-looking for when per-blob
compression control is needed (e.g., multi-blob IOType support).
"""

def __init__(self, value, is_reference=None, compress_method="gzip"):
if not isinstance(value, (str, bytes)):
raise TypeError(
"SerializedBlob value must be str or bytes, got %s" % type(value).__name__
)
self.value = value
self.compress_method = compress_method
if is_reference is None:
self.is_reference = isinstance(value, str)
else:
self.is_reference = is_reference

@property
def needs_save(self):
"""True if this blob contains new bytes that need to be stored."""
return not self.is_reference


class SerializerStore(ABCMeta):
"""
Metaclass for ArtifactSerializer that auto-registers subclasses by TYPE.

Provides deterministic ordering: serializers are sorted by (PRIORITY, registration_order).
Lower PRIORITY values are tried first. Registration order breaks ties.
"""

_all_serializers = {}
_registration_order = []

def __init__(cls, name, bases, namespace):
super().__init__(name, bases, namespace)
if cls.TYPE is not None:
if cls.TYPE not in SerializerStore._all_serializers:
SerializerStore._registration_order.append(cls.TYPE)
SerializerStore._all_serializers[cls.TYPE] = cls

@staticmethod
def get_ordered_serializers():
"""
Return serializer classes sorted by (PRIORITY, registration_order).

This ordering is deterministic for a given set of loaded serializers.
"""
order = SerializerStore._registration_order
return sorted(
SerializerStore._all_serializers.values(),
key=lambda s: (s.PRIORITY, order.index(s.TYPE)),
)


class ArtifactSerializer(object, metaclass=SerializerStore):
"""
Abstract base class for artifact serializers.

Subclasses must set TYPE to a unique string identifier and implement
all four class methods. Subclasses are auto-registered by the SerializerStore
metaclass on class definition.

Attributes
----------
TYPE : str or None
Unique identifier for this serializer (e.g., "pickle", "iotype").
Set to None in the base class to prevent registration.
PRIORITY : int
Dispatch priority. Lower values are tried first. Default 100.
PickleSerializer uses 9999 as the universal fallback.
"""

TYPE = None
PRIORITY = 100

@classmethod
@abstractmethod
def can_serialize(cls, obj):
"""
Return True if this serializer can handle the given object.

Parameters
----------
obj : Any
The Python object to serialize.

Returns
-------
bool
"""
raise NotImplementedError

@classmethod
@abstractmethod
def can_deserialize(cls, metadata):
"""
Return True if this serializer can deserialize given the metadata.

Parameters
----------
metadata : SerializationMetadata
Metadata stored alongside the artifact.

Returns
-------
bool
"""
raise NotImplementedError

@classmethod
@abstractmethod
def serialize(cls, obj):
"""
Serialize obj to blobs and metadata. Must be side-effect-free.

Parameters
----------
obj : Any
The Python object to serialize.

Returns
-------
tuple
(List[SerializedBlob], SerializationMetadata)
"""
raise NotImplementedError

@classmethod
@abstractmethod
def deserialize(cls, blobs, metadata, context):
"""
Deserialize blobs back to a Python object.

Parameters
----------
blobs : List[bytes]
The raw blob data.
metadata : SerializationMetadata
Metadata stored alongside the artifact.
context : Any
Optional context for deserialization (e.g., task vs client loading).

Returns
-------
Any
"""
raise NotImplementedError
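
As a usage sketch of the registration contract, the serializer below is hypothetical: the TYPE string, PRIORITY value, and JSON-only scope are invented for the example. Only the four abstract methods and the SerializedBlob/SerializationMetadata shapes come from the module above. It shows how SerializerStore auto-registers a subclass at class definition and where it lands in dispatch order.

import json

class JsonSerializer(ArtifactSerializer):
    TYPE = "json-example"  # unique id; registration happens at class definition
    PRIORITY = 50          # tried before PickleSerializer's 9999 fallback

    @classmethod
    def can_serialize(cls, obj):
        return isinstance(obj, (dict, list, str, int, float, bool, type(None)))

    @classmethod
    def can_deserialize(cls, metadata):
        return metadata.type == cls.TYPE

    @classmethod
    def serialize(cls, obj):
        data = json.dumps(obj).encode("utf-8")
        # bytes value -> is_reference auto-detected as False, so needs_save is True
        blobs = [SerializedBlob(data)]
        meta = SerializationMetadata(
            type=cls.TYPE, size=len(data), encoding="utf-8", serializer_info={}
        )
        return blobs, meta

    @classmethod
    def deserialize(cls, blobs, metadata, context):
        return json.loads(blobs[0].decode(metadata.encoding))

# Dispatch order is (PRIORITY, registration order), lowest first:
print([s.TYPE for s in SerializerStore.get_ordered_serializers()])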