Skip to content

Commit a883bff

Browse files
authored
Merge pull request #2038 from sydduckworth/fix-pickle-serialization
Fix `AsdfFile` pickling for instances without file descriptors
2 parents 8c2e988 + 6ffff87 commit a883bff

6 files changed

Lines changed: 108 additions & 44 deletions

File tree

asdf/_tests/test_asdf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import pickle
23

4+
import numpy as np
35
import pytest
46

57
from asdf import config_context
@@ -379,3 +381,12 @@ def test_fsspec_http(httpserver):
379381
with fsspec.open(fn) as f:
380382
af = open_asdf(f)
381383
assert_tree_match(tree, af.tree)
384+
385+
386+
def test_asdf_file_pickle_from_dict():
387+
"""Verify that an AsdfFile created from a dict (with no file descriptor) can be pickled"""
388+
tree = {"a": 1, "b": {"c": 2, "d": np.ones((10, 10))}}
389+
af = AsdfFile(tree)
390+
pkl = pickle.dumps(af)
391+
loaded = pickle.loads(pkl) # noqa: S301
392+
assert_tree_match(af.tree, loaded.tree)

asdf/_tests/test_extension.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
Validator,
2121
get_cached_extension_manager,
2222
)
23-
from asdf.extension._manager import _resolve_type
23+
from asdf.extension._manager import ValidatorManager, _resolve_type
24+
from asdf.tagged import TaggedList
2425
from asdf.testing.helpers import roundtrip_object
2526

2627

@@ -734,6 +735,31 @@ def test_validator():
734735
af.validate()
735736

736737

738+
class ValidatorFailOn(Validator):
739+
schema_property = "fail"
740+
tags = ["fail"]
741+
742+
def __init__(self, fail_on):
743+
self.fail_on = fail_on
744+
745+
def validate(self, schema_property_value, node, schema):
746+
if schema_property_value == self.fail_on:
747+
yield ValidationError("Node was doomed to fail")
748+
749+
750+
def test_validator_manager():
751+
validator = ValidatorManager([ValidatorFailOn("bar")])
752+
errs = list(validator.validate("fail", "foo", TaggedList([], "fail"), {}))
753+
assert len(errs) == 0
754+
755+
errs = list(validator.validate("fail", "bar", TaggedList([], "other"), {}))
756+
assert len(errs) == 0
757+
758+
errs = list(validator.validate("fail", "bar", TaggedList([], "fail"), {}))
759+
assert len(errs) == 1
760+
assert isinstance(errs[0], ValidationError)
761+
762+
737763
def test_converter_deferral():
738764
class Bar:
739765
def __init__(self, value):

asdf/extension/_manager.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1+
from __future__ import annotations
2+
13
import sys
4+
from dataclasses import dataclass
25
from functools import lru_cache
6+
from typing import TYPE_CHECKING
37

48
from asdf.tagged import Tagged
59
from asdf.util import get_class_name, uri_match
610

711
from ._extension import ExtensionProxy
812

13+
if TYPE_CHECKING:
14+
from collections.abc import Iterable, Iterator, Mapping
15+
from typing import Any
16+
17+
from asdf.exceptions import ValidationError
18+
from asdf.extension import Validator
19+
from asdf.typing import TreeKey
20+
921

1022
def _resolve_type(path):
1123
"""
@@ -317,32 +329,26 @@ def _get_cached_extension_manager(extensions):
317329

318330
class ValidatorManager:
319331
"""
320-
Wraps a list of custom validators and indexes them by schema property.
332+
Wraps a list of custom validators and binds them to their associated schemas.
321333
322334
Parameters
323335
----------
324336
validators : iterable of asdf.extension.Validator
325337
List of validators to manage.
326338
"""
327339

328-
def __init__(self, validators):
329-
self._validators = list(validators)
340+
def __init__(self, validators: Iterable[Validator]):
341+
self._validators = {}
342+
for validator in validators:
343+
if validator.schema_property not in self._validators:
344+
self._validators[validator.schema_property] = set()
330345

331-
self._validators_by_schema_property = {}
332-
for validator in self._validators:
333-
if validator.schema_property not in self._validators_by_schema_property:
334-
self._validators_by_schema_property[validator.schema_property] = set()
335-
self._validators_by_schema_property[validator.schema_property].add(validator)
346+
self._validators[validator.schema_property].add(validator)
336347

337-
self._jsonschema_validators_by_schema_property = {}
338-
for schema_property in self._validators_by_schema_property:
339-
self._jsonschema_validators_by_schema_property[schema_property] = self._get_jsonschema_validator(
340-
schema_property,
341-
)
342-
343-
def validate(self, schema_property, schema_property_value, node, schema):
344-
"""
345-
Validate an ASDF tree node against custom validators for a schema property.
348+
def validate(
349+
self, schema_property: str, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
350+
) -> Iterator[ValidationError]:
351+
"""Validate an ASDF tree node against custom validators for a schema property.
346352
347353
Parameters
348354
----------
@@ -360,27 +366,34 @@ def validate(self, schema_property, schema_property_value, node, schema):
360366
------
361367
asdf.exceptions.ValidationError
362368
"""
363-
if schema_property in self._validators_by_schema_property:
364-
for validator in self._validators_by_schema_property[schema_property]:
365-
if _validator_matches(validator, node):
366-
yield from validator.validate(schema_property_value, node, schema)
369+
for validator in self._validators[schema_property]:
370+
if _validator_matches(validator, node):
371+
yield from validator.validate(schema_property_value, node, schema)
367372

368-
def get_jsonschema_validators(self):
369-
"""
370-
Get a dictionary of validator methods suitable for use
371-
with the jsonschema library.
373+
def get_jsonschema_validators(self) -> dict[str, JsonSchemaValidators]:
374+
"""Get a dictionary mapping schema names to ``jsonschema``-compatible validator functions."""
375+
return {
376+
schema_property: JsonSchemaValidators(schema_property, frozenset(validators))
377+
for schema_property, validators in self._validators.items()
378+
}
372379

373-
Returns
374-
-------
375-
dict of str: callable
376-
"""
377-
return dict(self._jsonschema_validators_by_schema_property)
378380

379-
def _get_jsonschema_validator(self, schema_property):
380-
def _validator(_, schema_property_value, node, schema):
381-
return self.validate(schema_property, schema_property_value, node, schema)
381+
@dataclass(frozen=True, slots=True)
382+
class JsonSchemaValidators:
383+
"""Callable that wraps a set of `Validator` objects to make them compatible with `jsonschema`.
384+
385+
Each validator is always passed `schema_property` as its first argument regardless of the actual input schema.
386+
"""
387+
388+
schema_property: str
389+
validators: frozenset[Validator]
382390

383-
return _validator
391+
def __call__(
392+
self, _schema_property: Any, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
393+
) -> Iterator[ValidationError]:
394+
for validator in self.validators:
395+
if _validator_matches(validator, node):
396+
yield from validator.validate(schema_property_value, node, schema)
384397

385398

386399
def _validator_matches(validator, node):

asdf/extension/_validator.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1+
from __future__ import annotations
2+
13
import abc
4+
from typing import TYPE_CHECKING, Any
5+
6+
if TYPE_CHECKING:
7+
from collections.abc import Iterable, Iterator, Mapping
8+
9+
from asdf.exceptions import ValidationError
10+
from asdf.tagged import Tagged
11+
from asdf.typing import TreeKey
212

313

414
class Validator(abc.ABC):
@@ -8,13 +18,14 @@ class Validator(abc.ABC):
818
"""
919

1020
@abc.abstractproperty
11-
def schema_property(self):
21+
def schema_property(self) -> str:
1222
"""
1323
Name of the schema property used to invoke this validator.
1424
"""
25+
...
1526

1627
@abc.abstractproperty
17-
def tags(self):
28+
def tags(self) -> Iterable[str]:
1829
"""
1930
Get the YAML tags that are appropriate to this validator.
2031
URI patterns are permitted, see `asdf.util.uri_match` for details.
@@ -24,9 +35,12 @@ def tags(self):
2435
iterable of str
2536
Tag URIs or URI patterns.
2637
"""
38+
...
2739

2840
@abc.abstractmethod
29-
def validate(self, schema_property_value, node, schema):
41+
def validate(
42+
self, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
43+
) -> Iterator[ValidationError]:
3044
"""
3145
Validate the given node from the ASDF tree.
3246
@@ -54,3 +68,4 @@ def validate(self, schema_property_value, node, schema):
5468
asdf.exceptions.ValidationError
5569
Yield an instance of ValidationError for each error present in the node.
5670
"""
71+
...

asdf/tags/core/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
# to pass an isinstance(..., dict) check and to allow it to be "lazy"
2525
# loaded when "lazy_tree=True".
2626
class AsdfObject(collections.UserDict, dict):
27-
pass
27+
def __reduce__(self):
28+
# Necessary for correct pickling/unpickling
29+
# Otherwise pickle will use dict's reduce method which causes UserDict to fail to unpickle
30+
return super(collections.UserDict, self).__reduce__()
2831

2932

3033
class Software(dict):

pyproject.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,7 @@ omit = [
140140
]
141141

142142
[tool.coverage.report]
143-
exclude_lines = [
144-
# Have to re-enable the standard pragma
145-
"pragma: no cover",
143+
exclude_also = [
146144
# Don't complain about packages we have installed
147145
"except ImportError",
148146
# Don't complain if tests don't hit assertions
@@ -152,8 +150,6 @@ exclude_lines = [
152150
'def main\(.*\):',
153151
# Ignore branches that don't pertain to this version of Python
154152
"pragma: py{ ignore_python_version }",
155-
# Ignore type-checking imports
156-
"if TYPE_CHECKING:",
157153
]
158154

159155
[tool.ruff]

0 commit comments

Comments
 (0)