Skip to content

Commit 785881b

Browse files
committed
feat: Add before-validation converts all ASDF NDArrayType to ndarray
1 parent 338323e commit 785881b

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

asdf_pydantic/model.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import textwrap
2-
from typing import ClassVar
2+
from typing import Any, ClassVar
33

4+
import numpy as np
45
import yaml
5-
from pydantic import BaseModel
6+
from asdf.tags.core import NDArrayType
7+
from numpy.typing import NDArray
8+
from pydantic import BaseModel, ValidationInfo, field_validator
69

710

811
class AsdfPydanticModel(BaseModel):
@@ -67,3 +70,14 @@ def schema_asdf(
6770
)
6871
body = yaml.dump(cls.schema())
6972
return header + body
73+
74+
@field_validator("*", mode="before")
75+
@classmethod
76+
def _allow_asdf_NDArrayType_to_be_ndarray(
77+
cls, value: Any, info: ValidationInfo
78+
) -> Any | NDArray:
79+
"""Before Pydantic validation, convert NDArrayType to ndarray."""
80+
if not isinstance(value, NDArrayType):
81+
return value
82+
83+
return np.asarray(value)

tests/patterns/numpy_type_test.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ def test_convert_ArrayContainer_to_asdf(tmp_path):
3737
"""When writing ArrayContainer to an ASDF file, the array field should be
3838
serialized to the original numpy array.
3939
"""
40-
af = asdf.AsdfFile({"data": ArrayContainer(array=np.array([1, 2, 3]))}).write_to(
41-
tmp_path / "test.asdf"
42-
)
40+
data = ArrayContainer(array=np.array([1, 2, 3]))
41+
af = asdf.AsdfFile({"data": data})
42+
af.write_to(tmp_path / "test.asdf")
4343

4444
with asdf.open(tmp_path / "test.asdf") as af:
45-
assert isinstance(af.tree["array"], np.ndarray), (
46-
f"Expected {type(np.ndarray)}, " f"got {type(af.tree['array'])}"
45+
breakpoint()
46+
assert isinstance(af.tree["data"], np.ndarray), (
47+
f"Expected {type(np.ndarray)}, " f"got {type(af.tree['data'])}"
4748
)
48-
assert np.all(af.tree["array"] == np.array([1, 2, 3]))
49+
assert np.all(af.tree["data"] == np.array([1, 2, 3]))

0 commit comments

Comments
 (0)