Skip to content

Commit e9ea783

Browse files
committed
OME-Zarr chunks
1 parent 6a3eef7 commit e9ea783

File tree

2 files changed

+117
-8
lines changed

2 files changed

+117
-8
lines changed

src/spatialdata/_io/io_raster.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from pathlib import Path
45
from typing import Any, Literal
56

@@ -38,6 +39,88 @@
3839
)
3940

4041

42+
def _is_flat_int_sequence(value: object) -> bool:
43+
if isinstance(value, str | bytes):
44+
return False
45+
if not isinstance(value, Sequence):
46+
return False
47+
return all(isinstance(v, int) for v in value)
48+
49+
50+
def _is_dask_chunk_grid(value: object) -> bool:
51+
if isinstance(value, str | bytes):
52+
return False
53+
if not isinstance(value, Sequence):
54+
return False
55+
return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value)
56+
57+
58+
def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool:
59+
# Match Dask's private _check_regular_chunks() logic without depending on its internal API.
60+
for axis_chunks in chunk_grid:
61+
if len(axis_chunks) <= 1:
62+
continue
63+
if len(set(axis_chunks[:-1])) > 1:
64+
return False
65+
if axis_chunks[-1] > axis_chunks[0]:
66+
return False
67+
return True
68+
69+
70+
def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Convert a chunk specification to a Zarr chunk shape, or None when impossible.

    Accepts a single int, a flat int sequence (already a Zarr chunk shape), or a
    regular Dask chunk grid (collapsed to its per-axis leading chunk size).
    Anything else — including an irregular Dask chunk grid — yields ``None``.
    """
    if isinstance(chunks, int):
        return chunks
    if _is_flat_int_sequence(chunks):
        return tuple(chunks)
    if not _is_dask_chunk_grid(chunks):
        return None
    grid = tuple(tuple(axis) for axis in chunks)
    if not _is_regular_dask_chunk_grid(grid):
        return None
    # For a regular grid, the first chunk on each axis is the uniform chunk size.
    return tuple(axis[0] for axis in grid)
81+
82+
83+
def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Validate a user-supplied ``storage_options['chunks']`` value.

    Returns the Zarr-compatible chunk shape (or int) produced by
    ``_chunks_to_zarr_chunks``; raises ``ValueError`` when the value is neither
    a Zarr chunk shape nor a regular Dask chunk grid.
    """
    zarr_chunks = _chunks_to_zarr_chunks(chunks)
    if zarr_chunks is not None:
        return zarr_chunks
    raise ValueError(
        "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. "
        "Irregular Dask chunk grids must be rechunked before writing or omitted."
    )
91+
92+
93+
def _prepare_single_scale_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Copy single-scale storage options, normalizing any explicit ``'chunks'``.

    Parameters
    ----------
    storage_options
        ``None``, a single options dict, or a list of options dicts.

    Returns
    -------
    A shallow copy of the input (never the caller's objects), with every
    ``'chunks'`` entry converted to a Zarr chunk shape.

    Raises
    ------
    ValueError
        If a ``'chunks'`` entry is neither a Zarr chunk shape nor a regular
        Dask chunk grid (via ``_normalize_explicit_chunks``).
    """
    if storage_options is None:
        return None
    if isinstance(storage_options, dict):
        prepared = dict(storage_options)
        if "chunks" in prepared:
            prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
        return prepared
    # Normalize the 'chunks' entry of each dict in the list form as well, matching
    # the multiscale variant; previously irregular chunk grids slipped through here
    # unvalidated and could crash the Zarr writer downstream.
    prepared_options = [dict(options) for options in storage_options]
    for options in prepared_options:
        if "chunks" in options:
            options["chunks"] = _normalize_explicit_chunks(options["chunks"])
    return prepared_options
104+
105+
106+
def _prepare_multiscale_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Copy multiscale storage options, normalizing any explicit ``'chunks'``.

    ``None`` passes through; a single dict or each dict in a list is shallow-copied
    and its ``'chunks'`` entry (if present) converted to a Zarr chunk shape via
    ``_normalize_explicit_chunks`` (which raises ``ValueError`` on irregular grids).
    """
    if storage_options is None:
        return None

    def _copy_and_normalize(options: JSONDict) -> JSONDict:
        # Shallow-copy so the caller's dict is never mutated.
        entry = {**options}
        if "chunks" in entry:
            entry["chunks"] = _normalize_explicit_chunks(entry["chunks"])
        return entry

    if isinstance(storage_options, dict):
        return _copy_and_normalize(storage_options)
    return [_copy_and_normalize(options) for options in storage_options]
122+
123+
41124
def _read_multiscale(
42125
store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
43126
) -> DataArray | DataTree:
@@ -251,13 +334,8 @@ def _write_raster_dataarray(
251334
if transformations is None:
252335
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
253336
input_axes: tuple[str, ...] = tuple(raster_data.dims)
254-
chunks = raster_data.chunks
255337
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
256-
if storage_options is not None:
257-
if "chunks" not in storage_options and isinstance(storage_options, dict):
258-
storage_options["chunks"] = chunks
259-
else:
260-
storage_options = {"chunks": chunks}
338+
storage_options = _prepare_single_scale_storage_options(storage_options)
261339
# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
262340
# We need this because the argument of write_image_ngff is called image while the argument of
263341
# write_labels_ngff is called label.
@@ -322,10 +400,9 @@ def _write_raster_datatree(
322400
transformations = _get_transformations_xarray(xdata)
323401
if transformations is None:
324402
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
325-
chunks = get_pyramid_levels(raster_data, "chunks")
326403

327404
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
328-
storage_options = [{"chunks": chunk} for chunk in chunks]
405+
storage_options = _prepare_multiscale_storage_options(storage_options)
329406
ome_zarr_format = get_ome_zarr_format(raster_format)
330407
dask_delayed = write_multi_scale_ngff(
331408
pyramid=data,

tests/io/test_readwrite.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import Any, Literal
99

10+
import dask.array as da
1011
import dask.dataframe as dd
1112
import numpy as np
1213
import pandas as pd
@@ -30,6 +31,7 @@
3031
SpatialDataContainerFormatType,
3132
SpatialDataContainerFormatV01,
3233
)
34+
from spatialdata._io.io_raster import write_image
3335
from spatialdata.datasets import blobs
3436
from spatialdata.models import Image2DModel
3537
from spatialdata.models._utils import get_channel_names
@@ -623,6 +625,36 @@ def test_bug_rechunking_after_queried_raster():
623625
queried.write(f)
624626

625627

628+
def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None:
    """Writing succeeds even when the backing Dask array has an irregular chunk grid."""
    irregular = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488)))
    parsed = Image2DModel.parse(irregular, dims=("c", "y", "x"))
    sdata = SpatialData(images={"image": parsed})
    sdata.write(tmp_path / "data.zarr")
634+
635+
636+
def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None:
    """An explicit regular Dask chunk grid is collapsed to a Zarr chunk shape on write."""
    arr = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488)))
    img = Image2DModel.parse(arr, dims=("c", "y", "x"))
    root = zarr.open_group(tmp_path / "image.zarr", mode="w")

    write_image(img, root, "image", storage_options={"chunks": img.data.chunks})

    # The per-axis leading chunk size becomes the written Zarr chunk shape.
    assert root["s0"].chunks == (3, 300, 512)
644+
645+
646+
def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) -> None:
    """Passing an irregular Dask chunk grid via storage_options raises ValueError."""
    arr = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488)))
    img = Image2DModel.parse(arr, dims=("c", "y", "x"))
    root = zarr.open_group(tmp_path / "image.zarr", mode="w")

    expected = "storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid"
    with pytest.raises(ValueError, match=expected):
        write_image(img, root, "image", storage_options={"chunks": img.data.chunks})
656+
657+
626658
@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
627659
def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None:
628660
# data only in-memory, so the SpatialData object and all its elements are self-contained

0 commit comments

Comments
 (0)