Skip to content

Commit ee34d7c

Browse files
committed
code edits for zarr v3; tests fail
1 parent 0b8d9e4 commit ee34d7c

2 files changed

Lines changed: 84 additions & 48 deletions

File tree

src/spatialdata/_io/io_raster.py

Lines changed: 28 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Literal
2+
from typing import Any, Literal, cast
33

44
import dask.array as da
55
import numpy as np
@@ -231,7 +231,9 @@ def _write_raster(
231231

232232

233233
def _apply_compression(
234-
storage_options: JSONDict | list[JSONDict], compressor: dict[Literal["lz4", "zstd"], int] | None
234+
storage_options: JSONDict | list[JSONDict],
235+
compressor: dict[Literal["lz4", "zstd"], int] | None,
236+
zarr_format: Literal[2, 3] = 3,
235237
) -> JSONDict | list[JSONDict]:
236238
"""Apply compression settings to storage options.
237239
@@ -241,23 +243,39 @@ def _apply_compression(
241243
Storage options for zarr arrays
242244
compressor
243245
Compression settings as a dictionary with a single key-value pair
246+
zarr_format
247+
The zarr format version (2 or 3)
244248
245249
Returns
246250
-------
247251
Updated storage options with compression settings
248252
"""
249-
from zarr.codecs import Blosc
253+
# For zarr disk format v2, use numcodecs.Blosc
254+
# For zarr disk format v3, use zarr.codecs.Blosc
255+
from numcodecs import Blosc as BloscV2
256+
from zarr.codecs import Blosc as BloscV3
250257

251258
if not compressor:
252259
return storage_options
253260

254261
((compression, compression_level),) = compressor.items()
255262

263+
assert BloscV2.SHUFFLE == 1
264+
blosc_v2 = BloscV2(cname=compression, clevel=compression_level, shuffle=1)
265+
blosc_v3 = BloscV3(cname=compression, clevel=compression_level, shuffle=1)
266+
267+
def _update_dict(d: dict[str, Any]) -> None:
268+
if zarr_format == 2:
269+
d["compressor"] = blosc_v2
270+
elif zarr_format == 3:
271+
d["zarr_array_kwargs"] = {"compressors": [blosc_v3]}
272+
256273
if isinstance(storage_options, dict):
257-
storage_options["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1)
258-
elif isinstance(storage_options, list):
274+
_update_dict(d=storage_options)
275+
else:
276+
assert isinstance(storage_options, list)
259277
for option in storage_options:
260-
option["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=1)
278+
_update_dict(d=option)
261279

262280
return storage_options
263281

@@ -311,7 +329,9 @@ def _write_raster_dataarray(
311329
storage_options = {"chunks": chunks}
312330

313331
# Apply compression if specified
314-
storage_options = _apply_compression(storage_options, compressor)
332+
storage_options = _apply_compression(
333+
storage_options, compressor, zarr_format=cast(Literal[2, 3], raster_format.zarr_format)
334+
)
315335

316336
# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
317337
# We need this because the argument of write_image_ngff is called image while the argument of
@@ -388,7 +408,7 @@ def _write_raster_datatree(
388408
else:
389409
storage_options = [{"chunks": chunk} for chunk in chunks]
390410
# Apply compression if specified
391-
storage_options = _apply_compression(storage_options, compressor)
411+
storage_options = _apply_compression(storage_options, compressor, zarr_format=raster_format.zarr_format)
392412

393413
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
394414
ome_zarr_format = get_ome_zarr_format(raster_format)

tests/io/test_readwrite.py

Lines changed: 56 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,20 @@ def test_multiple_tables(
149149
sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))})
150150
self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format)
151151

152+
def test_roundtrip(
153+
self,
154+
tmp_path: str,
155+
sdata: SpatialData,
156+
sdata_container_format: SpatialDataContainerFormatType,
157+
) -> None:
158+
tmpdir = Path(tmp_path) / "tmp.zarr"
159+
160+
sdata.write(tmpdir, sdata_formats=sdata_container_format)
161+
sdata2 = SpatialData.read(tmpdir)
162+
tmpdir2 = Path(tmp_path) / "tmp2.zarr"
163+
sdata2.write(tmpdir2, sdata_formats=sdata_container_format)
164+
_are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*")
165+
152166
def test_compression_roundtrip(
153167
self,
154168
tmp_path: str,
@@ -168,31 +182,55 @@ def test_compression_roundtrip(
168182
full_sdata.write(tmpdir, compressor={"zstd": 8}, sdata_formats=sdata_container_format)
169183

170184
# sourcery skip: no-loop-in-tests
171-
for element in ["image2d", "image2d_multiscale"]:
172-
arr = zarr.open_group(tmpdir / "images", mode="r")[element]["0"]
185+
for element in ["image2d", "image2d_multiscale", "labels2d", "labels2d_multiscale"]:
186+
element_type = "images" if element.startswith("image") else "labels"
187+
arr = zarr.open_group(tmpdir / element_type, mode="r")[element]["0"]
173188
compressor = arr.compressors[0]
174-
assert compressor.cnam == "zstd"
175-
assert compressor.clevel == 8
176189

177-
for element in ["labels2d", "labels2d_multiscale"]:
178-
arr = zarr.open_group(tmpdir / "labels", mode="r")[element]["0"]
179-
compressor = arr.compressors[0]
180-
assert compressor.cname == "zstd"
181-
assert compressor.clevel == 8
190+
# TODO: all these tests fail because the compression arguments are not passed to Dask
191+
if sdata_container_format.zarr_format == 2:
192+
assert compressor.cname == "zstd"
193+
assert compressor.clevel == 8
194+
elif sdata_container_format.zarr_format == 3:
195+
from zarr.codecs.zstd import ZstdCodec
182196

183-
def test_roundtrip(
197+
assert isinstance(compressor, ZstdCodec)
198+
assert compressor.level == 8
199+
200+
@pytest.mark.parametrize("compressor", [{"lz4": 3}, {"zstd": 7}])
201+
@pytest.mark.parametrize("element", [("images", "image2d"), ("labels", "labels2d")])
202+
def test_write_element_compression(
184203
self,
185204
tmp_path: str,
186-
sdata: SpatialData,
205+
full_sdata: SpatialData,
206+
compressor: dict[Literal["lz4", "zstd"], int],
207+
element: str,
187208
sdata_container_format: SpatialDataContainerFormatType,
188-
) -> None:
189-
tmpdir = Path(tmp_path) / "tmp.zarr"
190-
209+
):
210+
tmpdir = Path(tmp_path) / "compression.zarr"
211+
sdata = SpatialData()
191212
sdata.write(tmpdir, sdata_formats=sdata_container_format)
192-
sdata2 = SpatialData.read(tmpdir)
193-
tmpdir2 = Path(tmp_path) / "tmp2.zarr"
194-
sdata2.write(tmpdir2, sdata_formats=sdata_container_format)
195-
_are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*")
213+
214+
sdata["element"] = full_sdata[element[1]]
215+
sdata.write_element("element", compressor=compressor, sdata_formats=sdata_container_format)
216+
217+
arr = zarr.open_group(tmpdir / element[0], mode="r")["element"]["0"]
218+
compression = arr.compressors[0]
219+
220+
# TODO: all these tests fail because the compression arguments are not passed to Dask
221+
if sdata_container_format.zarr_format == 2:
222+
assert compression.cname == list(compressor.keys())[0]
223+
assert compression.clevel == list(compressor.values())[0]
224+
elif sdata_container_format.zarr_format == 3:
225+
from zarr.codecs import ZstdCodec
226+
227+
compressor_name = list(compressor.keys())[0]
228+
if compressor_name == "zstd":
229+
assert isinstance(compression, ZstdCodec)
230+
# TODO: fix
231+
# elif compressor_name == 'lz4':
232+
# assert isinstance(compression, ???)
233+
assert compression.level == list(compressor.values())[0]
196234

197235
def test_incremental_io_list_of_elements(
198236
self,
@@ -381,28 +419,6 @@ def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format
381419
assert any("from-zarr" in key for key in dask1.dask.layers)
382420
assert len(get_dask_backing_files(sdata2)) > 0
383421

384-
@pytest.mark.parametrize("compressor", [{"lz4": 3}, {"zstd": 7}])
385-
@pytest.mark.parametrize("element", [("images", "image2d"), ("labels", "labels2d")])
386-
def test_write_element_compression(
387-
self,
388-
tmp_path: str,
389-
full_sdata: SpatialData,
390-
compressor: dict[Literal["lz4", "zstd"], int],
391-
element: str,
392-
sdata_container_format: SpatialDataContainerFormatType,
393-
):
394-
tmpdir = Path(tmp_path) / "compression.zarr"
395-
sdata = SpatialData()
396-
sdata.write(tmpdir, sdata_formats=sdata_container_format)
397-
398-
sdata["element"] = full_sdata[element[1]]
399-
sdata.write_element("element", compressor=compressor, sdata_formats=sdata_container_format)
400-
401-
arr = zarr.open_group(tmpdir / element[0], mode="r")["element"]["0"]
402-
compression = arr.compressors[0]
403-
assert compression.cname == list(compressor.keys())[0]
404-
assert compression.clevel == list(compressor.values())[0]
405-
406422
def test_replace_transformation_on_disk_raster(
407423
self, images, labels, sdata_container_format: SpatialDataContainerFormatType
408424
):

0 commit comments

Comments
 (0)