|
1 | | -from ome_zarr_models.common.multiscales import Dataset, MultiscaleBase |
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections import Counter |
| 4 | +from typing import TYPE_CHECKING, Any, Literal, Self |
| 5 | + |
| 6 | +from pydantic import ( |
| 7 | + BaseModel, |
| 8 | + Field, |
| 9 | + JsonValue, |
| 10 | + SerializerFunctionWrapHandler, |
| 11 | + field_validator, |
| 12 | + model_serializer, |
| 13 | + model_validator, |
| 14 | +) |
| 15 | + |
| 16 | +from ome_zarr_models._utils import duplicates |
| 17 | +from ome_zarr_models._v06.axes import Axes |
| 18 | +from ome_zarr_models.base import BaseAttrs |
| 19 | +from ome_zarr_models.common.coordinate_transformations import ( |
| 20 | + ScaleTransform, |
| 21 | + Transform, |
| 22 | + TranslationTransform, |
| 23 | + VectorScale, |
| 24 | + VectorTransform, |
| 25 | + _build_transforms, |
| 26 | + _ndim, |
| 27 | +) |
| 28 | +from ome_zarr_models.common.validation import check_length, check_ordered_scales |
| 29 | + |
| 30 | +if TYPE_CHECKING: |
| 31 | + from collections.abc import Sequence |
| 32 | + |
2 | 33 |
|
__all__ = ["Dataset", "Multiscale"]


# Valid number of dimensions for multiscale image data (2-5 inclusive).
VALID_NDIM = (2, 3, 4, 5)
# A dataset's transforms are either a lone scale transform, or a scale
# transform followed by a translation transform (order is enforced by
# Dataset._ensure_scale_translation).
ValidTransform = tuple[ScaleTransform] | tuple[ScaleTransform, TranslationTransform]
| 40 | + |
class Multiscale(BaseAttrs):
    """
    An element of multiscales metadata.

    On construction this model validates that the axes, the optional
    top-level coordinate transformations, and every dataset's coordinate
    transformations all agree on dimensionality, and that the axes follow
    the count/type/ordering/uniqueness rules checked by the validators
    below.
    """

    # Axis descriptions; their count defines `self.ndim`.
    axes: Axes
    # Resolution levels; at least one is required, ordered from highest
    # resolution to lowest (enforced by _ensure_ordered_scales).
    datasets: tuple[Dataset, ...] = Field(..., min_length=1)
    # Optional transforms applied on top of each dataset's own transforms.
    coordinateTransformations: ValidTransform | None = None
    # Free-form metadata fields passed through from the document.
    metadata: JsonValue = None
    name: JsonValue | None = None
    type: JsonValue = None
    # NOTE(review): this module lives under `_v06` but the accepted literal
    # here is "0.4" — confirm this is the intended version string.
    version: Literal["0.4"] | None = None

    @model_serializer(mode="wrap")
    def _serialize(
        self,
        serializer: SerializerFunctionWrapHandler,
    ) -> dict[str, Any]:
        """
        Serialize the model, omitting `coordinateTransformations` entirely
        (rather than emitting an explicit null) when it is unset.
        """
        d: dict[str, Any] = serializer(self)
        if self.coordinateTransformations is None:
            d.pop("coordinateTransformations", None)

        return d

    @property
    def ndim(self) -> int:
        """
        Dimensionality of the data described by this metadata.

        Determined by the length of the axes attribute.
        """
        return len(self.axes)

    @model_validator(mode="after")
    def _ensure_axes_top_transforms(data: Self) -> Self:
        """
        Ensure that the length of the axes matches the dimensionality of the transforms
        defined in the top-level coordinateTransformations, if present.
        """
        # mode="after" receives the validated model instance; the first
        # positional parameter is named `data` rather than `self` here.
        self_ndim = len(data.axes)
        if data.coordinateTransformations is not None:
            for tx in data.coordinateTransformations:
                # Transforms without an `ndim` attribute (presumably
                # path-based ones) are skipped.
                if hasattr(tx, "ndim") and self_ndim != tx.ndim:
                    msg = (
                        f"The length of axes does not match the dimensionality of "
                        f"the {tx.type} transform in coordinateTransformations. "
                        f"Got {self_ndim} axes, but the {tx.type} transform has "
                        f"dimensionality {tx.ndim}"
                    )
                    raise ValueError(msg)
        return data

    @model_validator(mode="after")
    def _ensure_axes_dataset_transforms(data: Self) -> Self:
        """
        Ensure that the length of the axes matches the dimensionality of the
        transforms in every dataset's coordinateTransformations.
        """
        self_ndim = len(data.axes)
        for ds_idx, ds in enumerate(data.datasets):
            for tx in ds.coordinateTransformations:
                # Transforms without an `ndim` attribute are skipped.
                if hasattr(tx, "ndim") and self_ndim != tx.ndim:
                    msg = (
                        f"The length of axes does not match the dimensionality of "
                        f"the {tx.type} transform in "
                        f"datasets[{ds_idx}].coordinateTransformations. "
                        f"Got {self_ndim} axes, but the {tx.type} transform has "
                        f"dimensionality {tx.ndim}"
                    )
                    raise ValueError(msg)
        return data

    @field_validator("datasets", mode="after")
    @classmethod
    def _ensure_ordered_scales(cls, datasets: list[Dataset]) -> list[Dataset]:
        """
        Make sure datasets are ordered from highest resolution to smallest.
        """
        # The first transform of each dataset is always a scale transform
        # (enforced by Dataset._ensure_scale_translation).
        scale_transforms = [d.coordinateTransformations[0] for d in datasets]
        # Only handle scales given in metadata, not in files
        scale_vector_transforms = [
            t for t in scale_transforms if isinstance(t, VectorScale)
        ]
        check_ordered_scales(scale_vector_transforms)
        return datasets

    @field_validator("axes", mode="after")
    @classmethod
    def _ensure_axis_length(cls, axes: Axes) -> Axes:
        """
        Ensures that there are between 2 and 5 axes (inclusive)
        """
        check_length(axes, valid_lengths=VALID_NDIM, variable_name="axes")
        return axes

    @field_validator("axes", mode="after")
    @classmethod
    def _ensure_axis_types(cls, axes: Axes) -> Axes:
        """
        Ensures that the following conditions are true:

        - there are only 2 or 3 axes with type `space`
        - the axes with type `space` are last in the list of axes
        - there is only 1 axis with type `time`
        - there is only 1 axis with type `channel`
        - there is only 1 axis with a type that is not `space`, `time`, or `channel`
        """
        check_length(
            [ax for ax in axes if ax.type == "space"],
            valid_lengths=[2, 3],
            variable_name="space axes",
        )
        check_length(
            [ax for ax in axes if ax.type == "time"],
            valid_lengths=[0, 1],
            variable_name="time axes",
        )
        check_length(
            [ax for ax in axes if ax.type == "channel"],
            valid_lengths=[0, 1],
            variable_name="channel axes",
        )
        check_length(
            [ax for ax in axes if ax.type not in ["space", "time", "channel"]],
            valid_lengths=[0, 1],
            variable_name="custom axes",
        )

        # Space axes must occupy the trailing positions of the axes list.
        axis_types = [ax.type for ax in axes]
        type_census = Counter(axis_types)
        num_spaces = type_census["space"]
        if not all(a == "space" for a in axis_types[-num_spaces:]):
            msg = (
                f"All space axes must be at the end of the axes list. "
                f"Got axes with order: {axis_types}."
            )
            raise ValueError(msg)

        # A time axis, if present, must come first.
        num_times = type_census["time"]
        if num_times == 1 and axis_types[0] != "time":
            msg = "Time axis must be at the beginning of axis list."
            raise ValueError(msg)

        return axes

    @field_validator("axes", mode="after")
    @classmethod
    def _ensure_unique_axis_names(cls, axes: Axes) -> Axes:
        """
        Ensures that the names of the axes are unique.
        """
        name_dupes = duplicates(a.name for a in axes)
        if len(name_dupes) > 0:
            msg = (
                f"Axis names must be unique. Axis names {tuple(name_dupes.keys())} are "
                "repeated."
            )
            raise ValueError(msg)
        return axes
| 199 | + |
| 200 | + |
class Dataset(BaseAttrs):
    """
    An element of Multiscale.datasets.

    Pairs a path to an array with the coordinate transformations (a scale,
    optionally followed by a translation) that place that array in the
    multiscale's coordinate space.
    """

    # TODO: validate that path resolves to an actual zarr array
    # TODO: can we validate that the paths must be ordered from highest resolution to
    # smallest using scale metadata?
    path: str
    coordinateTransformations: ValidTransform

    @classmethod
    def build(
        cls, *, path: str, scale: Sequence[float], translation: Sequence[float] | None
    ) -> Self:
        """
        Construct a `Dataset` from a path, a scale, and a translation.

        Parameters
        ----------
        path :
            Path to the array for this resolution level.
        scale :
            Scale vector, one value per dimension.
        translation :
            Translation vector, one value per dimension, or None
            (handling of None is delegated to `_build_transforms`).
        """
        return cls(
            path=path,
            coordinateTransformations=_build_transforms(
                scale=scale, translation=translation
            ),
        )

    @field_validator("coordinateTransformations", mode="before")
    def _ensure_scale_translation(
        transforms_obj: object,
    ) -> object:
        """
        Ensures that
        - there are only 1 or 2 transforms.
        - the first element is a scale transformation
        - the second element, if present, is a translation transform
        """
        # This is used as a before validator - to help us, we use pydantic to first
        # cast the input (which can in general be anything) into a set of
        # transformations. Then we check the transformations are valid.
        #
        # This is a bit convoluted, but we do it because the default pydantic error
        # messages are a mess otherwise

        # Throwaway inner model used purely for its parsing/coercion of
        # the raw input into Transform objects.
        class Transforms(BaseModel):
            transforms: list[Transform]

        transforms = Transforms(transforms=transforms_obj).transforms
        check_length(transforms, valid_lengths=[1, 2], variable_name="transforms")

        maybe_scale = transforms[0]
        if maybe_scale.type != "scale":
            msg = (
                "The first element of `coordinateTransformations` must be a scale "
                f"transform. Got {maybe_scale} instead."
            )
            raise ValueError(msg)
        if len(transforms) == 2:
            maybe_trans = transforms[1]
            if (maybe_trans.type) != "translation":
                msg = (
                    "The second element of `coordinateTransformations` must be a "
                    f"translation transform. Got {maybe_trans} instead."
                )
                raise ValueError(msg)

        # Return the unmodified input so pydantic performs its own parse
        # into the declared ValidTransform type.
        return transforms_obj

    @field_validator("coordinateTransformations", mode="after")
    @classmethod
    def _ensure_transform_dimensionality(
        cls,
        transforms: ValidTransform,
    ) -> ValidTransform:
        """
        Ensures that the elements in the input sequence define transformations with
        identical dimensionality. If any of the transforms are defined with a path
        instead of concrete values, then no validation will be performed and the
        transforms will be returned as-is.
        """
        # Only vector transforms carry concrete values, and therefore a
        # dimensionality that can be compared.
        vector_transforms = filter(lambda v: isinstance(v, VectorTransform), transforms)
        ndims = tuple(map(_ndim, vector_transforms))  # type: ignore[arg-type]
        ndims_set = set(ndims)
        if len(ndims_set) > 1:
            msg = (
                "The transforms have inconsistent dimensionality. "
                f"Got transforms with dimensionality = {ndims}."
            )
            raise ValueError(msg)
        return transforms
0 commit comments