Skip to content

Commit 2ace308

Browse files
authored
Make name required in axes (#295)
* Make name required in axes
* De-duplicate code for checking ordered scales
1 parent 6c51997 commit 2ace308

File tree

11 files changed

+946
-339
lines changed

11 files changed

+946
-339
lines changed

docs/changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
- Various optimisations have been made to reduce the number of file requests when creating a model class from an existing [zarr.Group][].
1212
- [ome_zarr_models.open_ome_zarr][] now includes the name of the group it was trying to validate alongside the validation error message.
1313

14+
## Bug fixes
15+
16+
- Axes metadata for version 0.4 and 0.5 has been fixed to require the `name` field.
17+
1418
## 1.1
1519

1620
### New Features

src/ome_zarr_models/_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121

2222
if TYPE_CHECKING:
23-
from collections.abc import Hashable, Iterable
23+
from collections.abc import Iterable
2424

2525
import zarr
2626
from zarr.abc.store import Store
@@ -184,7 +184,10 @@ def get_store_path(store: Store) -> str:
184184
return ""
185185

186186

187-
def duplicates(values: Iterable[Hashable]) -> dict[Hashable, int]:
187+
T = TypeVar("T")
188+
189+
190+
def duplicates(values: Iterable[T]) -> dict[T, int]:
188191
"""
189192
Takes a sequence of hashable elements and returns a dict where the keys are the
190193
elements of the input that occurred at least once, and the values are the

src/ome_zarr_models/_v06/axes.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,28 @@
1-
from ome_zarr_models.common.axes import Axes, Axis, AxisType
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
4+
from pydantic import JsonValue
5+
6+
from ome_zarr_models.base import BaseAttrs
27

38
__all__ = ["Axes", "Axis", "AxisType"]
9+
10+
11+
AxisType = Literal["space", "time", "channel"]
12+
13+
14+
class Axis(BaseAttrs):
15+
"""
16+
Model for an element of `Multiscale.axes`.
17+
"""
18+
19+
# Explicitly name could be any JsonValue, but implicitly it must match Zarr array
20+
# dimension_names which limits it to str | None
21+
22+
name: str | None
23+
type: str | None = None
24+
# Unit probably intended to be str, but the spec doesn't explicitly specify
25+
unit: str | JsonValue | None = None
26+
27+
28+
Axes = Sequence[Axis]
Lines changed: 281 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,288 @@
1-
from ome_zarr_models.common.multiscales import Dataset, MultiscaleBase
1+
from __future__ import annotations
2+
3+
from collections import Counter
4+
from typing import TYPE_CHECKING, Any, Literal, Self
5+
6+
from pydantic import (
7+
BaseModel,
8+
Field,
9+
JsonValue,
10+
SerializerFunctionWrapHandler,
11+
field_validator,
12+
model_serializer,
13+
model_validator,
14+
)
15+
16+
from ome_zarr_models._utils import duplicates
17+
from ome_zarr_models._v06.axes import Axes
18+
from ome_zarr_models.base import BaseAttrs
19+
from ome_zarr_models.common.coordinate_transformations import (
20+
ScaleTransform,
21+
Transform,
22+
TranslationTransform,
23+
VectorScale,
24+
VectorTransform,
25+
_build_transforms,
26+
_ndim,
27+
)
28+
from ome_zarr_models.common.validation import check_length, check_ordered_scales
29+
30+
if TYPE_CHECKING:
31+
from collections.abc import Sequence
32+
233

334
__all__ = ["Dataset", "Multiscale"]
435

536

6-
class Multiscale(MultiscaleBase):
37+
VALID_NDIM = (2, 3, 4, 5)
38+
ValidTransform = tuple[ScaleTransform] | tuple[ScaleTransform, TranslationTransform]
39+
40+
41+
class Multiscale(BaseAttrs):
742
"""
843
An element of multiscales metadata.
944
"""
45+
46+
axes: Axes
47+
datasets: tuple[Dataset, ...] = Field(..., min_length=1)
48+
coordinateTransformations: ValidTransform | None = None
49+
metadata: JsonValue = None
50+
name: JsonValue | None = None
51+
type: JsonValue = None
52+
version: Literal["0.4"] | None = None
53+
54+
@model_serializer(mode="wrap")
55+
def _serialize(
56+
self,
57+
serializer: SerializerFunctionWrapHandler,
58+
) -> dict[str, Any]:
59+
d: dict[str, Any] = serializer(self)
60+
if self.coordinateTransformations is None:
61+
d.pop("coordinateTransformations", None)
62+
63+
return d
64+
65+
@property
66+
def ndim(self) -> int:
67+
"""
68+
Dimensionality of the data described by this metadata.
69+
70+
Determined by the length of the axes attribute.
71+
"""
72+
return len(self.axes)
73+
74+
@model_validator(mode="after")
75+
def _ensure_axes_top_transforms(data: Self) -> Self:
76+
"""
77+
Ensure that the length of the axes matches the dimensionality of the transforms
78+
defined in the top-level coordinateTransformations, if present.
79+
"""
80+
self_ndim = len(data.axes)
81+
if data.coordinateTransformations is not None:
82+
for tx in data.coordinateTransformations:
83+
if hasattr(tx, "ndim") and self_ndim != tx.ndim:
84+
msg = (
85+
f"The length of axes does not match the dimensionality of "
86+
f"the {tx.type} transform in coordinateTransformations. "
87+
f"Got {self_ndim} axes, but the {tx.type} transform has "
88+
f"dimensionality {tx.ndim}"
89+
)
90+
raise ValueError(msg)
91+
return data
92+
93+
@model_validator(mode="after")
94+
def _ensure_axes_dataset_transforms(data: Self) -> Self:
95+
"""
96+
Ensure that the length of the axes matches the dimensionality of the transforms
97+
"""
98+
self_ndim = len(data.axes)
99+
for ds_idx, ds in enumerate(data.datasets):
100+
for tx in ds.coordinateTransformations:
101+
if hasattr(tx, "ndim") and self_ndim != tx.ndim:
102+
msg = (
103+
f"The length of axes does not match the dimensionality of "
104+
f"the {tx.type} transform in "
105+
f"datasets[{ds_idx}].coordinateTransformations. "
106+
f"Got {self_ndim} axes, but the {tx.type} transform has "
107+
f"dimensionality {tx.ndim}"
108+
)
109+
raise ValueError(msg)
110+
return data
111+
112+
@field_validator("datasets", mode="after")
113+
@classmethod
114+
def _ensure_ordered_scales(cls, datasets: list[Dataset]) -> list[Dataset]:
115+
"""
116+
Make sure datasets are ordered from highest resolution to smallest.
117+
"""
118+
scale_transforms = [d.coordinateTransformations[0] for d in datasets]
119+
# Only handle scales given in metadata, not in files
120+
scale_vector_transforms = [
121+
t for t in scale_transforms if isinstance(t, VectorScale)
122+
]
123+
check_ordered_scales(scale_vector_transforms)
124+
return datasets
125+
126+
@field_validator("axes", mode="after")
127+
@classmethod
128+
def _ensure_axis_length(cls, axes: Axes) -> Axes:
129+
"""
130+
Ensures that there are between 2 and 5 axes (inclusive)
131+
"""
132+
check_length(axes, valid_lengths=VALID_NDIM, variable_name="axes")
133+
return axes
134+
135+
@field_validator("axes", mode="after")
136+
@classmethod
137+
def _ensure_axis_types(cls, axes: Axes) -> Axes:
138+
"""
139+
Ensures that the following conditions are true:
140+
141+
- there are only 2 or 3 axes with type `space`
142+
- the axes with type `space` are last in the list of axes
143+
- there is only 1 axis with type `time`
144+
- there is only 1 axis with type `channel`
145+
- there is only 1 axis with a type that is not `space`, `time`, or `channel`
146+
"""
147+
check_length(
148+
[ax for ax in axes if ax.type == "space"],
149+
valid_lengths=[2, 3],
150+
variable_name="space axes",
151+
)
152+
check_length(
153+
[ax for ax in axes if ax.type == "time"],
154+
valid_lengths=[0, 1],
155+
variable_name="time axes",
156+
)
157+
check_length(
158+
[ax for ax in axes if ax.type == "channel"],
159+
valid_lengths=[0, 1],
160+
variable_name="channel axes",
161+
)
162+
check_length(
163+
[ax for ax in axes if ax.type not in ["space", "time", "channel"]],
164+
valid_lengths=[0, 1],
165+
variable_name="custom axes",
166+
)
167+
168+
axis_types = [ax.type for ax in axes]
169+
type_census = Counter(axis_types)
170+
num_spaces = type_census["space"]
171+
if not all(a == "space" for a in axis_types[-num_spaces:]):
172+
msg = (
173+
f"All space axes must be at the end of the axes list. "
174+
f"Got axes with order: {axis_types}."
175+
)
176+
raise ValueError(msg)
177+
178+
num_times = type_census["time"]
179+
if num_times == 1 and axis_types[0] != "time":
180+
msg = "Time axis must be at the beginning of axis list."
181+
raise ValueError(msg)
182+
183+
return axes
184+
185+
@field_validator("axes", mode="after")
186+
@classmethod
187+
def _ensure_unique_axis_names(cls, axes: Axes) -> Axes:
188+
"""
189+
Ensures that the names of the axes are unique.
190+
"""
191+
name_dupes = duplicates(a.name for a in axes)
192+
if len(name_dupes) > 0:
193+
msg = (
194+
f"Axis names must be unique. Axis names {tuple(name_dupes.keys())} are "
195+
"repeated."
196+
)
197+
raise ValueError(msg)
198+
return axes
199+
200+
201+
class Dataset(BaseAttrs):
202+
"""
203+
An element of Multiscale.datasets.
204+
"""
205+
206+
# TODO: validate that path resolves to an actual zarr array
207+
# TODO: can we validate that the paths must be ordered from highest resolution to
208+
# smallest using scale metadata?
209+
path: str
210+
coordinateTransformations: ValidTransform
211+
212+
@classmethod
213+
def build(
214+
cls, *, path: str, scale: Sequence[float], translation: Sequence[float] | None
215+
) -> Self:
216+
"""
217+
Construct a `Dataset` from a path, a scale, and a translation.
218+
"""
219+
return cls(
220+
path=path,
221+
coordinateTransformations=_build_transforms(
222+
scale=scale, translation=translation
223+
),
224+
)
225+
226+
@field_validator("coordinateTransformations", mode="before")
227+
def _ensure_scale_translation(
228+
transforms_obj: object,
229+
) -> object:
230+
"""
231+
Ensures that
232+
- there are only 1 or 2 transforms.
233+
- the first element is a scale transformation
234+
- the second element, if present, is a translation transform
235+
"""
236+
# This is used as a before validator - to help use, we use pydantic to first
237+
# cast the input (which can in general anything) into a set of transformations.
238+
# Then we check the transformations are valid.
239+
#
240+
# This is a bit convoluted, but we do it because the default pydantic error
241+
# messages are a mess otherwise
242+
243+
class Transforms(BaseModel):
244+
transforms: list[Transform]
245+
246+
transforms = Transforms(transforms=transforms_obj).transforms
247+
check_length(transforms, valid_lengths=[1, 2], variable_name="transforms")
248+
249+
maybe_scale = transforms[0]
250+
if maybe_scale.type != "scale":
251+
msg = (
252+
"The first element of `coordinateTransformations` must be a scale "
253+
f"transform. Got {maybe_scale} instead."
254+
)
255+
raise ValueError(msg)
256+
if len(transforms) == 2:
257+
maybe_trans = transforms[1]
258+
if (maybe_trans.type) != "translation":
259+
msg = (
260+
"The second element of `coordinateTransformations` must be a "
261+
f"translation transform. Got {maybe_trans} instead."
262+
)
263+
raise ValueError(msg)
264+
265+
return transforms_obj
266+
267+
@field_validator("coordinateTransformations", mode="after")
268+
@classmethod
269+
def _ensure_transform_dimensionality(
270+
cls,
271+
transforms: ValidTransform,
272+
) -> ValidTransform:
273+
"""
274+
Ensures that the elements in the input sequence define transformations with
275+
identical dimensionality. If any of the transforms are defined with a path
276+
instead of concrete values, then no validation will be performed and the
277+
transforms will be returned as-is.
278+
"""
279+
vector_transforms = filter(lambda v: isinstance(v, VectorTransform), transforms)
280+
ndims = tuple(map(_ndim, vector_transforms)) # type: ignore[arg-type]
281+
ndims_set = set(ndims)
282+
if len(ndims_set) > 1:
283+
msg = (
284+
"The transforms have inconsistent dimensionality. "
285+
f"Got transforms with dimensionality = {ndims}."
286+
)
287+
raise ValueError(msg)
288+
return transforms

src/ome_zarr_models/common/axes.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments (0)