Skip to content

Commit a892252

Browse files
authored
fix: compressor=None kwarg handling with zarr v3 (#2270)
1 parent af09beb commit a892252

3 files changed

Lines changed: 34 additions & 19 deletions

File tree

docs/release-notes/2270.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix `compressor` kwarg handling when writing to zarr v3 {user}`ilan-gold`

src/anndata/_io/specs/methods.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,21 @@
9494
# return False
9595

9696

97-
def zarr_v3_compressor_compat(dataset_kwargs) -> dict:
98-
if not is_zarr_v2() and (compressor := dataset_kwargs.pop("compressor", None)):
99-
dataset_kwargs["compressors"] = compressor
97+
def zarr_v3_compressor_compat(dataset_kwargs: dict) -> dict:
98+
"""Handle the mismatch between our `compressor` kwarg and the `compressors` argument of :func:`zarr.create_array` in zarr v3.
99+
See https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array
100+
101+
Parameters
102+
----------
103+
dataset_kwargs
104+
The kwarg dict potentially containing "compressor"
105+
106+
Returns
107+
-------
108+
The kwarg dict with "compressor" moved to "compressors" if zarr v3 is in use.
109+
"""
110+
if not is_zarr_v2() and "compressor" in dataset_kwargs:
111+
dataset_kwargs["compressors"] = dataset_kwargs.pop("compressor")
100112
return dataset_kwargs
101113

102114

tests/test_readwrite.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -377,13 +377,20 @@ def check_compressed(key, value):
377377

378378

379379
@pytest.mark.parametrize("zarr_write_format", [2, 3])
380-
def test_zarr_compression(tmp_path, zarr_write_format):
380+
@pytest.mark.parametrize(
381+
"use_compression", [True, False], ids=["compressed", "uncompressed"]
382+
)
383+
def test_zarr_compression(
384+
tmp_path: Path, zarr_write_format: Literal[2, 3], *, use_compression: bool
385+
):
381386
if zarr_write_format == 3 and is_zarr_v2():
382387
pytest.xfail("Cannot write zarr v3 format with v2 package")
383388
ad.settings.zarr_write_format = zarr_write_format
384389
pth = str(Path(tmp_path) / "adata.zarr")
385390
adata = gen_adata((10, 8), **GEN_ADATA_NO_XARRAY_ARGS)
386-
if zarr_write_format == 2 or is_zarr_v2():
391+
if not use_compression:
392+
compressor = None
393+
elif zarr_write_format == 2 or is_zarr_v2():
387394
from numcodecs import Blosc
388395

389396
compressor = Blosc(cname="zstd", clevel=3, shuffle=Blosc.BITSHUFFLE)
@@ -393,21 +400,23 @@ def test_zarr_compression(tmp_path, zarr_write_format):
393400
# Don't use Blosc since its defaults can change:
394401
# https://github.com/zarr-developers/zarr-python/pull/3545
395402
compressor = ZstdCodec(level=3, checksum=True)
396-
not_compressed = []
403+
wrongly_compressed = []
397404

398405
ad.io.write_zarr(pth, adata, compressor=compressor)
399406

400407
def check_compressed(value, key):
401408
if not isinstance(value, ZarrArray) or value.shape == ():
402409
return None
403-
(read_compressor,) = value.compressors
410+
(read_compressor,) = value.compressors or [None]
404411
if zarr_write_format == 2:
405412
if read_compressor != compressor:
406-
not_compressed.append(key)
413+
wrongly_compressed.append(key)
407414
return None
408-
if read_compressor.to_dict() != compressor.to_dict():
409-
print(read_compressor.to_dict(), compressor.to_dict())
410-
not_compressed.append(key)
415+
if (compressor is None and read_compressor is not None) or (
416+
None not in {compressor, read_compressor}
417+
and read_compressor.to_dict() != compressor.to_dict()
418+
):
419+
wrongly_compressed.append(key)
411420

412421
if is_zarr_v2():
413422
with zarr.open(pth, "r") as f:
@@ -416,14 +425,7 @@ def check_compressed(value, key):
416425
f = zarr.open(pth, mode="r")
417426
for key, value in f.members(max_depth=None):
418427
check_compressed(value, key)
419-
420-
if not_compressed:
421-
sep = "\n\t"
422-
msg = (
423-
f"These elements were not compressed correctly:{sep}"
424-
f"{sep.join(not_compressed)}"
425-
)
426-
raise AssertionError(msg)
428+
assert not wrongly_compressed, "Some elements were not (un)compressed correctly"
427429

428430
expected = ad.read_zarr(pth)
429431
assert_equal(adata, expected)

0 commit comments

Comments
 (0)