Skip to content

Commit fed2c78

Browse files
author
Saeid Barati
committed
Address Romain review feedback
- Add type annotations to ArtifactSerializer method signatures for better editor/IDE support. - Remove unused ``context`` parameter from ``deserialize`` (call sites never passed a meaningful value). - Promote ``STORAGE`` / ``WIRE`` to a ``SerializationFormat`` enum; subclass ``str`` so the existing string-literal comparisons keep working. - Reword the ``_serializers_override`` comment and move the property rationale next to the property definition. - Extend the multi-blob error message with "at this time. If you have a need for multi blob serializers, please reach out to the Metaflow team.". - Include ``serializer_info`` in the "No deserializer claimed artifact" error message.
1 parent 3e9a437 commit fed2c78

7 files changed

Lines changed: 84 additions & 40 deletions

File tree

metaflow/datastore/artifacts/serializer.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
11
import inspect
22
from abc import ABCMeta, abstractmethod
33
from collections import namedtuple
4+
from enum import Enum
5+
from typing import Any, List, Optional, Tuple, Type, Union
46

57

6-
# Serialization formats. STORAGE produces (blobs, metadata) for the datastore;
7-
# WIRE produces a str for CLI args, protobuf payloads, and cross-process IPC.
8-
STORAGE = "storage"
9-
WIRE = "wire"
8+
class SerializationFormat(str, Enum):
9+
"""
10+
The representation a serializer produces or consumes.
11+
12+
- ``STORAGE`` yields ``(List[SerializedBlob], SerializationMetadata)`` for
13+
the datastore save path.
14+
- ``WIRE`` yields a ``str`` for CLI args, protobuf payloads, and
15+
cross-process IPC.
16+
17+
Subclassing ``str`` keeps ``SerializationFormat.STORAGE == "storage"``
18+
True, so existing call sites that compare against the string literal
19+
keep working without a migration.
20+
"""
21+
22+
STORAGE = "storage"
23+
WIRE = "wire"
24+
25+
26+
# Module-level aliases kept so call sites can write ``format=STORAGE`` without
27+
# importing the enum itself.
28+
STORAGE = SerializationFormat.STORAGE
29+
WIRE = SerializationFormat.WIRE
1030

1131

1232
SerializationMetadata = namedtuple(
@@ -30,7 +50,11 @@ class SerializedBlob(object):
3050
If None, auto-detected from value type: str -> reference, bytes -> new data.
3151
"""
3252

33-
def __init__(self, value, is_reference=None):
53+
def __init__(
54+
self,
55+
value: Union[str, bytes],
56+
is_reference: Optional[bool] = None,
57+
):
3458
if not isinstance(value, (str, bytes)):
3559
raise TypeError(
3660
"SerializedBlob value must be str or bytes, got %s" % type(value).__name__
@@ -42,7 +66,7 @@ def __init__(self, value, is_reference=None):
4266
self.is_reference = is_reference
4367

4468
@property
45-
def needs_save(self):
69+
def needs_save(self) -> bool:
4670
"""True if this blob contains new bytes that need to be stored."""
4771
return not self.is_reference
4872

@@ -69,7 +93,7 @@ def __init__(cls, name, bases, namespace):
6993
SerializerStore._ordered_cache = None
7094

7195
@staticmethod
72-
def get_ordered_serializers():
96+
def get_ordered_serializers() -> List[Type["ArtifactSerializer"]]:
7397
"""
7498
Return serializer classes sorted by (PRIORITY, registration_order).
7599
@@ -127,12 +151,12 @@ class ArtifactSerializer(object, metaclass=SerializerStore):
127151
PickleSerializer uses 9999 as the universal fallback.
128152
"""
129153

130-
TYPE = None
131-
PRIORITY = 100
154+
TYPE: Optional[str] = None
155+
PRIORITY: int = 100
132156

133157
@classmethod
134158
@abstractmethod
135-
def can_serialize(cls, obj):
159+
def can_serialize(cls, obj: Any) -> bool:
136160
"""
137161
Return True if this serializer can handle the given object.
138162
@@ -149,7 +173,7 @@ def can_serialize(cls, obj):
149173

150174
@classmethod
151175
@abstractmethod
152-
def can_deserialize(cls, metadata):
176+
def can_deserialize(cls, metadata: SerializationMetadata) -> bool:
153177
"""
154178
Return True if this serializer can deserialize given the metadata.
155179
@@ -166,7 +190,11 @@ def can_deserialize(cls, metadata):
166190

167191
@classmethod
168192
@abstractmethod
169-
def serialize(cls, obj, format=STORAGE):
193+
def serialize(
194+
cls,
195+
obj: Any,
196+
format: SerializationFormat = STORAGE,
197+
) -> Union[Tuple[List[SerializedBlob], SerializationMetadata], str]:
170198
"""
171199
Serialize obj. Must be side-effect-free: this method may be invoked
172200
multiple times (caching, retries, parallel dispatch) and must not
@@ -178,7 +206,7 @@ def serialize(cls, obj, format=STORAGE):
178206
----------
179207
obj : Any
180208
The Python object to serialize.
181-
format : str
209+
format : SerializationFormat
182210
Either ``STORAGE`` (default) or ``WIRE``.
183211
- ``STORAGE`` returns a tuple ``(List[SerializedBlob], SerializationMetadata)``
184212
for persisting through the datastore.
@@ -194,7 +222,12 @@ def serialize(cls, obj, format=STORAGE):
194222

195223
@classmethod
196224
@abstractmethod
197-
def deserialize(cls, data, metadata=None, context=None, format=STORAGE):
225+
def deserialize(
226+
cls,
227+
data: Union[List[bytes], str],
228+
metadata: Optional[SerializationMetadata] = None,
229+
format: SerializationFormat = STORAGE,
230+
) -> Any:
198231
"""
199232
Deserialize back to a Python object.
200233
@@ -205,9 +238,7 @@ def deserialize(cls, data, metadata=None, context=None, format=STORAGE):
205238
metadata : SerializationMetadata, optional
206239
Metadata stored alongside the artifact. Required for STORAGE,
207240
ignored for WIRE.
208-
context : Any, optional
209-
Optional context for deserialization (e.g., task vs client loading).
210-
format : str
241+
format : SerializationFormat
211242
Either ``STORAGE`` (default) or ``WIRE``.
212243
213244
Returns

metaflow/datastore/task_datastore.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,9 @@ def __init__(
117117
self._parent = flow_datastore
118118
self._persist = persist
119119

120-
# ``_serializers`` is a property that dispatches through
121-
# ``SerializerStore.get_ordered_serializers()`` on each access. The
122-
# lookup is cheap (cached inside the store) and picks up serializers
123-
# registered via the lazy import hook after this instance was
124-
# constructed — otherwise long-lived datastores (notebooks, client
125-
# sessions) would silently miss any extension registered after init.
120+
# Tests assign ``self._serializers = [...]`` to pin the dispatch list
121+
# for isolation. When set, the ``_serializers`` property returns this
122+
# override instead of consulting the global registry.
126123
self._serializers_override = None
127124

128125
self._is_done_set = False
@@ -203,6 +200,12 @@ def __init__(
203200

204201
@property
205202
def _serializers(self):
203+
# Dispatch through ``SerializerStore.get_ordered_serializers()`` on
204+
# each access. The lookup is cheap (cached inside the store) and
205+
# picks up serializers registered via the lazy import hook after
206+
# this instance was constructed — otherwise long-lived datastores
207+
# (notebooks, client sessions) would silently miss any extension
208+
# registered after init.
206209
if self._serializers_override is not None:
207210
return self._serializers_override
208211
return SerializerStore.get_ordered_serializers()
@@ -402,7 +405,9 @@ def serialize_iter():
402405
# load. Fail loudly until multi-blob support lands.
403406
raise DataException(
404407
"Serializer %s returned %d blobs for artifact '%s'; "
405-
"only single-blob serializers are supported."
408+
"only single-blob serializers are supported at this "
409+
"time. If you have a need for multi blob "
410+
"serializers, please reach out to the Metaflow team."
406411
% (serializer.__name__, len(blobs), name)
407412
)
408413
artifact_names.append(name)
@@ -472,8 +477,14 @@ def load_artifacts(self, names):
472477
% serializer_source
473478
)
474479
raise DataException(
475-
"No deserializer claimed artifact '%s' (encoding: %s)."
476-
"%s" % (name, metadata.encoding, source_hint)
480+
"No deserializer claimed artifact '%s' (encoding: %s, "
481+
"serializer_info: %r).%s"
482+
% (
483+
name,
484+
metadata.encoding,
485+
metadata.serializer_info,
486+
source_hint,
487+
)
477488
)
478489
deserializers[name] = (deserializer, metadata)
479490
to_load[self._objects[name]].append(name)
@@ -486,9 +497,7 @@ def load_artifacts(self, names):
486497
# Deserialize each time to have fully distinct objects (the user
487498
# would not expect two artifacts with different names to actually
488499
# be aliases of one another)
489-
yield name, deserializer.deserialize(
490-
[blob], metadata, context=None
491-
)
500+
yield name, deserializer.deserialize([blob], metadata)
492501

493502
@require_mode("r")
494503
def get_artifact_sizes(self, names):

metaflow/plugins/datastores/serializers/pickle_serializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def serialize(cls, obj, format=STORAGE):
5555
)
5656

5757
@classmethod
58-
def deserialize(cls, data, metadata=None, context=None, format=STORAGE):
58+
def deserialize(cls, data, metadata=None, format=STORAGE):
5959
if format == WIRE:
6060
raise NotImplementedError(
6161
"PickleSerializer does not support the WIRE format."

test/unit/test_artifact_serializer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def serialize(cls, obj):
5151
)
5252

5353
@classmethod
54-
def deserialize(cls, blobs, metadata, context):
54+
def deserialize(cls, blobs, metadata):
5555
return blobs[0].decode("utf-8")
5656

5757

@@ -76,7 +76,7 @@ def serialize(cls, obj):
7676
)
7777

7878
@classmethod
79-
def deserialize(cls, blobs, metadata, context):
79+
def deserialize(cls, blobs, metadata):
8080
return int(blobs[0].decode("utf-8"))
8181

8282

@@ -99,7 +99,7 @@ def serialize(cls, obj):
9999
raise NotImplementedError
100100

101101
@classmethod
102-
def deserialize(cls, blobs, metadata, context):
102+
def deserialize(cls, blobs, metadata):
103103
raise NotImplementedError
104104

105105

@@ -143,7 +143,7 @@ def serialize(cls, obj):
143143
raise NotImplementedError
144144

145145
@classmethod
146-
def deserialize(cls, blobs, metadata, context):
146+
def deserialize(cls, blobs, metadata):
147147
raise NotImplementedError
148148

149149
assert SerializerStore._all_serializers["test_high"] is _ReplacementSerializer
@@ -286,7 +286,7 @@ def serialize(cls, obj, format=STORAGE):
286286
)
287287

288288
@classmethod
289-
def deserialize(cls, data, metadata=None, context=None, format=STORAGE):
289+
def deserialize(cls, data, metadata=None, format=STORAGE):
290290
if format == WIRE:
291291
return data
292292
return data[0].decode("utf-8")
@@ -295,6 +295,10 @@ def deserialize(cls, data, metadata=None, context=None, format=STORAGE):
295295
def test_format_constants():
296296
assert STORAGE == "storage"
297297
assert WIRE == "wire"
298+
# Enum members compare both by identity and against the underlying string.
299+
assert STORAGE is STORAGE
300+
assert STORAGE.value == "storage"
301+
assert WIRE.value == "wire"
298302

299303

300304
def test_dual_format_storage_roundtrip():

test/unit/test_lazy_serializer_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def serialize(cls, obj, format="storage"):
200200
return [SerializedBlob(blob)], SerializationMetadata("x", 0, "x", {})
201201

202202
@classmethod
203-
def deserialize(cls, data, metadata=None, context=None, format="storage"):
203+
def deserialize(cls, data, metadata=None, format="storage"):
204204
return None
205205

206206
# Remove it so lazy-registry has to pull it in.

test/unit/test_pickle_serializer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_serialize_metadata_serializer_info_empty():
165165
def test_round_trip(obj):
166166
blobs, meta = PickleSerializer.serialize(obj)
167167
raw_blobs = [b.value for b in blobs]
168-
result = PickleSerializer.deserialize(raw_blobs, meta, context=None)
168+
result = PickleSerializer.deserialize(raw_blobs, meta)
169169
assert result == obj
170170

171171

@@ -181,6 +181,6 @@ def test_round_trip_custom_class():
181181
obj = _CustomObj(42)
182182
blobs, meta = PickleSerializer.serialize(obj)
183183
raw_blobs = [b.value for b in blobs]
184-
result = PickleSerializer.deserialize(raw_blobs, meta, context=None)
184+
result = PickleSerializer.deserialize(raw_blobs, meta)
185185
assert result == obj
186186
assert result.x == 42

test/unit/test_serializer_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def serialize(cls, obj):
141141
)
142142

143143
@classmethod
144-
def deserialize(cls, blobs, metadata, context):
144+
def deserialize(cls, blobs, metadata):
145145
return json.loads(blobs[0].decode("utf-8"))
146146

147147
# Explicitly set serializers: custom first, then pickle fallback.
@@ -240,7 +240,7 @@ def serialize(cls, obj, format="storage"):
240240
raise NotImplementedError
241241

242242
@classmethod
243-
def deserialize(cls, data, metadata=None, context=None, format="storage"):
243+
def deserialize(cls, data, metadata=None, format="storage"):
244244
raise NotImplementedError
245245

246246
try:

0 commit comments

Comments
 (0)