Skip to content

Commit 20696b5

Browse files
committed
deduplicate storage option util; use chunks from data when not specified in storage options
1 parent 930922c commit 20696b5

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

src/spatialdata/_io/io_raster.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,34 +90,27 @@ def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
9090
return normalized
9191

9292

93-
def _prepare_single_scale_storage_options(
93+
def _prepare_storage_options(
9494
storage_options: JSONDict | list[JSONDict] | None,
95-
) -> JSONDict | list[JSONDict] | None:
95+
data: list[da.Array],
96+
) -> list[JSONDict]:
9697
if storage_options is None:
97-
return None
98-
if isinstance(storage_options, dict):
99-
prepared = dict(storage_options)
100-
if "chunks" in prepared:
101-
prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
102-
return prepared
103-
return [dict(options) for options in storage_options]
104-
105-
106-
def _prepare_multiscale_storage_options(
107-
storage_options: JSONDict | list[JSONDict] | None,
108-
) -> JSONDict | list[JSONDict] | None:
109-
if storage_options is None:
110-
return None
98+
return [{"chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data]
11199
if isinstance(storage_options, dict):
100+
if "chunks" not in storage_options:
101+
return [{**storage_options, "chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data]
112102
prepared = dict(storage_options)
113-
if "chunks" in prepared:
114-
prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
115-
return prepared
116-
117-
prepared_options = [dict(options) for options in storage_options]
118-
for options in prepared_options:
119-
if "chunks" in options:
120-
options["chunks"] = _normalize_explicit_chunks(options["chunks"])
103+
prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
104+
return prepared # type: ignore[return-value]
105+
106+
prepared_options = []
107+
for i, options in enumerate(storage_options):
108+
opts = dict(options)
109+
if "chunks" not in opts:
110+
opts["chunks"] = _normalize_explicit_chunks(data[i].chunks)
111+
else:
112+
opts["chunks"] = _normalize_explicit_chunks(opts["chunks"])
113+
prepared_options.append(opts)
121114
return prepared_options
122115

123116

@@ -335,7 +328,7 @@ def _write_raster_dataarray(
335328
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
336329
input_axes: tuple[str, ...] = tuple(raster_data.dims)
337330
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
338-
storage_options = _prepare_single_scale_storage_options(storage_options)
331+
storage_options = _prepare_storage_options(storage_options, [data])
339332
# Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default
340333
# write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ...
341334
# even when the input is a plain DataArray.
@@ -405,7 +398,7 @@ def _write_raster_datatree(
405398
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
406399

407400
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
408-
storage_options = _prepare_multiscale_storage_options(storage_options)
401+
storage_options = _prepare_storage_options(storage_options, data)
409402
ome_zarr_format = get_ome_zarr_format(raster_format)
410403
dask_delayed = write_multi_scale_ngff(
411404
pyramid=data,

tests/io/test_readwrite.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,11 @@ def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path:
631631
image = Image2DModel.parse(data, dims=("c", "y", "x"))
632632
sdata = SpatialData(images={"image": image})
633633

634-
sdata.write(tmp_path / "data.zarr")
634+
with pytest.raises(
635+
ValueError,
636+
match="storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid",
637+
):
638+
sdata.write(tmp_path / "data.zarr")
635639

636640

637641
def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None:

0 commit comments

Comments
 (0)