Skip to content

Commit 9ebc55b

Browse files
ilan-goldpre-commit-ci[bot]flying-sheep
authored
fix: warn for upcoming zarr autosharding + fix settings dependency (#2427)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 10c4f09 commit 9ebc55b

9 files changed

Lines changed: 147 additions & 68 deletions

File tree

docs/release-notes/2427.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{attr}`anndata.settings.auto_shard_zarr_v3` and {attr}`anndata.settings.zarr_write_format` are no longer dependent on each other, as stated in the docs {user}`ilan-gold`

docs/release-notes/2427.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{attr}`anndata.settings.auto_shard_zarr_v3` now uses zarr's support for a target shard size via `zarr.config.set({"array.target_shard_size_bytes" ...})` (only for `zarr` versions `>=3.1.4`) to make the target shard size 1GB (uncompressed) if not otherwise set. {user}`ilan-gold`

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ filterwarnings_when_strict = [
176176
"default:Consolidated metadata is:UserWarning",
177177
"default:.*Struct:zarr.core.dtype.common.UnstableSpecificationWarning",
178178
"default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning",
179-
"default:Automatic shard shape inference is experimental",
180179
"default:Writing zarr v2:UserWarning",
181180
# TODO: Remove in conjunction with or before https://github.com/scverse/anndata/pull/1707
182181
"default:.*will obey copy-on-write semantics:FutureWarning",

src/anndata/_io/specs/methods.py

Lines changed: 109 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections.abc import Mapping, MutableMapping
5+
from contextlib import contextmanager, nullcontext
46
from copy import copy
5-
from functools import partial
7+
from functools import partial, wraps
8+
from importlib.metadata import version
69
from itertools import product
710
from types import MappingProxyType
811
from typing import TYPE_CHECKING, Protocol
912

1013
import h5py
1114
import numpy as np
1215
import pandas as pd
16+
import zarr
1317
from numcodecs import VLenUTF8
18+
from packaging.version import Version
1419
from scipy import sparse
1520
from zarr.core.dtype import VariableLengthUTF8
1621

@@ -21,6 +26,7 @@
2126
from anndata._core.merge import intersect_keys
2227
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
2328
from anndata._io.utils import check_key, zero_dim_array_as_scalar
29+
from anndata._types import StorageType
2430
from anndata._warnings import OldFormatWarning
2531
from anndata.compat import (
2632
AwkArray,
@@ -45,16 +51,16 @@
4551
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
4652

4753
if TYPE_CHECKING:
48-
from collections.abc import Iterator
54+
from collections.abc import Generator, Iterator
4955
from os import PathLike
5056
from typing import Any, Literal
5157

5258
from numpy import typing as npt
5359
from numpy.typing import NDArray
5460

55-
from anndata._types import _ArrayStorageType, _GroupStorageType
61+
from anndata._types import _ArrayStorageType, _GroupStorageType, _WriteInternal
5662
from anndata.compat import CSArray, CSMatrix, CupyCSMatrix
57-
from anndata.typing import AxisStorable, _InMemoryArrayOrScalarType
63+
from anndata.typing import AxisStorable, RWAble, _InMemoryArrayOrScalarType
5864

5965
from .registry import Reader, Writer
6066

@@ -112,10 +118,35 @@ def zarr_v3_compressor_compat(dataset_kwargs: dict) -> dict:
112118
return dataset_kwargs
113119

114120

115-
def zarr_v3_sharding(dataset_kwargs) -> dict:
116-
if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
121+
@contextmanager
122+
def zarr_v3_sharding(dataset_kwargs: dict, format: Literal[2, 3]) -> Generator[dict]:
123+
auto_sharding = (
124+
"shards" not in dataset_kwargs
125+
and ad.settings.auto_shard_zarr_v3
126+
and format == 3
127+
)
128+
if ad.settings.auto_shard_zarr_v3 is None and format == 3:
129+
warn(
130+
"zarr v3 autosharding will be the default in the next minor release.",
131+
UserWarning,
132+
)
133+
elif auto_sharding:
117134
dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
118-
return dataset_kwargs
135+
# Auto shard sizes are a relatively recent feature
136+
supports_auto_shard_size = Version(version("zarr")) >= Version("3.1.4")
137+
has_auto_shard_size = supports_auto_shard_size and isinstance(
138+
zarr.config.get("array.target_shard_size_bytes"), int
139+
)
140+
# 1GB uncompressed shard size seems reasonable.
141+
# Shards generally need to be held completely in memory before writing.
142+
# Even at a compression ratio of 6x, that's still a ~20x improvement on number of files.
143+
# Users can override this nonetheless, hence the above checks.
144+
with (
145+
zarr.config.set({"array.target_shard_size_bytes": 1_000_000_000})
146+
if supports_auto_shard_size and not has_auto_shard_size and auto_sharding
147+
else nullcontext()
148+
):
149+
yield dataset_kwargs
119150

120151

121152
def _to_cpu_mem_wrapper(write_func):
@@ -140,6 +171,32 @@ def wrapper(
140171
return wrapper
141172

142173

174+
def suppress_autoshard_warning[S: StorageType, T: RWAble](
175+
func: _WriteInternal[S, T],
176+
) -> _WriteInternal[S, T]:
177+
@wraps(func)
178+
def wrapper(
179+
f: S,
180+
k: str,
181+
val: T,
182+
*,
183+
_writer: Writer,
184+
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
185+
):
186+
with warnings.catch_warnings():
187+
# Suppress warnings only if the user has opted into autosharding at the top level.
188+
# If someone provides `shards` explicitly, then they should get the warning.
189+
if ad.settings.auto_shard_zarr_v3 and "shards" not in dataset_kwargs:
190+
warnings.filterwarnings(
191+
"ignore",
192+
r"Automatic shard shape inference is experimental",
193+
UserWarning,
194+
)
195+
return func(f, k, val, _writer=_writer, dataset_kwargs=dataset_kwargs)
196+
197+
return wrapper
198+
199+
143200
################################
144201
# Fallbacks / backwards compat #
145202
################################
@@ -277,6 +334,7 @@ def _read_partial(group, *, items=None, indices=(slice(None), slice(None))):
277334

278335
@_REGISTRY.register_write(ZarrGroup, AnnData, IOSpec("anndata", "0.1.0"))
279336
@_REGISTRY.register_write(H5Group, AnnData, IOSpec("anndata", "0.1.0"))
337+
@suppress_autoshard_warning
280338
def write_anndata(
281339
f: _GroupStorageType,
282340
k: str,
@@ -329,6 +387,7 @@ def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnDat
329387

330388
@_REGISTRY.register_write(H5Group, Raw, IOSpec("raw", "0.1.0"))
331389
@_REGISTRY.register_write(ZarrGroup, Raw, IOSpec("raw", "0.1.0"))
390+
@suppress_autoshard_warning
332391
def write_raw(
333392
f: _GroupStorageType,
334393
k: str,
@@ -361,6 +420,7 @@ def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
361420

362421

363422
@_REGISTRY.register_write(ZarrGroup, type(None), IOSpec("null", "0.1.0"))
423+
@suppress_autoshard_warning
364424
def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
365425
dataset_kwargs = _remove_scalar_compression_args(dataset_kwargs)
366426
# zarr has no first-class null dataset
@@ -384,6 +444,7 @@ def read_mapping(
384444

385445
@_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0"))
386446
@_REGISTRY.register_write(ZarrGroup, dict, IOSpec("dict", "0.1.0"))
447+
@suppress_autoshard_warning
387448
def write_mapping(
388449
f: _GroupStorageType,
389450
k: str,
@@ -404,6 +465,7 @@ def write_mapping(
404465

405466
@_REGISTRY.register_write(H5Group, list, IOSpec("array", "0.2.0"))
406467
@_REGISTRY.register_write(ZarrGroup, list, IOSpec("array", "0.2.0"))
468+
@suppress_autoshard_warning
407469
def write_list(
408470
f: _GroupStorageType,
409471
k: str,
@@ -425,6 +487,7 @@ def write_list(
425487
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
426488
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
427489
@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
490+
@suppress_autoshard_warning
428491
@zero_dim_array_as_scalar
429492
def write_basic(
430493
f: _GroupStorageType,
@@ -441,8 +504,10 @@ def write_basic(
441504
f.create_dataset(k, data=elem, shape=elem.shape, dtype=dtype, **dataset_kwargs)
442505
else:
443506
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
444-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
445-
f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs)
507+
with zarr_v3_sharding(
508+
dataset_kwargs, format=f.metadata.zarr_format
509+
) as dataset_kwargs:
510+
f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs)
446511
# see https://github.com/zarr-developers/zarr-python/discussions/2712
447512
if isinstance(elem, ZarrArray | H5Array):
448513
f[k][...] = elem[...]
@@ -499,14 +564,15 @@ def write_chunked_dense_array_to_group(
499564
_to_cpu_mem_wrapper(write_basic)
500565
)
501566
_REGISTRY.register_write(ZarrGroup, CupyArray, IOSpec("array", "0.2.0"))(
502-
_to_cpu_mem_wrapper(write_basic)
567+
suppress_autoshard_warning(_to_cpu_mem_wrapper(write_basic))
503568
)
504569

505570

506571
@_REGISTRY.register_write(ZarrGroup, views.DaskArrayView, IOSpec("array", "0.2.0"))
507572
@_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0"))
508573
@_REGISTRY.register_write(H5Group, views.DaskArrayView, IOSpec("array", "0.2.0"))
509574
@_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0"))
575+
@suppress_autoshard_warning
510576
def write_basic_dask_dask_dense(
511577
f: ZarrGroup | H5Group,
512578
k: str,
@@ -521,11 +587,13 @@ def write_basic_dask_dask_dense(
521587
is_h5 = isinstance(f, H5Group)
522588
if not is_h5:
523589
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
524-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
525590
if is_h5:
526591
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
527592
else:
528-
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
593+
with zarr_v3_sharding(
594+
dataset_kwargs, format=f.metadata.zarr_format
595+
) as dataset_kwargs:
596+
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
529597
da.store(elem, g, scheduler="threads")
530598

531599

@@ -590,6 +658,7 @@ def write_vlen_string_array(
590658
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
591659
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
592660
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "T"), IOSpec("string-array", "0.2.0"))
661+
@suppress_autoshard_warning
593662
@zero_dim_array_as_scalar
594663
def write_vlen_string_array_zarr(
595664
f: ZarrGroup,
@@ -605,15 +674,17 @@ def write_vlen_string_array_zarr(
605674
filters, fill_value = None, None
606675
if f.metadata.zarr_format == 2:
607676
filters, fill_value = [VLenUTF8()], ""
608-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
609-
f.create_array(
610-
k,
611-
shape=elem.shape,
612-
dtype=dtype,
613-
filters=filters,
614-
fill_value=fill_value,
615-
**dataset_kwargs,
616-
)
677+
with zarr_v3_sharding(
678+
dataset_kwargs, format=f.metadata.zarr_format
679+
) as dataset_kwargs:
680+
f.create_array(
681+
k,
682+
shape=elem.shape,
683+
dtype=dtype,
684+
filters=filters,
685+
fill_value=fill_value,
686+
**dataset_kwargs,
687+
)
617688
f[k][:] = elem
618689

619690

@@ -674,8 +745,9 @@ def write_recarray_zarr(
674745
dataset_kwargs = dict(dataset_kwargs)
675746
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
676747
# https://github.com/zarr-developers/zarr-python/issues/3546
677-
# if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
678-
# dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
748+
# with zarr_v3_sharding(
749+
# dataset_kwargs, format=f.metadata.zarr_format
750+
# ) as dataset_kwargs:
679751
f.create_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
680752
f[k][...] = elem
681753

@@ -730,16 +802,20 @@ def write_sparse_compressed(
730802
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs
731803
)
732804
else:
733-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
734-
arr = g.create_array(
735-
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
736-
)
805+
with zarr_v3_sharding(
806+
dataset_kwargs, format=f.metadata.zarr_format
807+
) as dataset_kwargs:
808+
arr = g.create_array(
809+
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
810+
)
737811
# see https://github.com/zarr-developers/zarr-python/discussions/2712
738812
arr[...] = attr[...]
739813

740814

741-
write_csr = partial(write_sparse_compressed, fmt="csr")
742-
write_csc = partial(write_sparse_compressed, fmt="csc")
815+
write_csr, write_csc = (
816+
suppress_autoshard_warning(partial(write_sparse_compressed, fmt=fmt))
817+
for fmt in ["csr", "csc"]
818+
)
743819

744820
for store_type, (cls, spec, func) in product(
745821
(H5Group, ZarrGroup),
@@ -776,6 +852,7 @@ def write_sparse_compressed(
776852
@_REGISTRY.register_write(H5Group, _CSCDataset, IOSpec("csc_matrix", "0.1.0"))
777853
@_REGISTRY.register_write(ZarrGroup, _CSRDataset, IOSpec("csr_matrix", "0.1.0"))
778854
@_REGISTRY.register_write(ZarrGroup, _CSCDataset, IOSpec("csc_matrix", "0.1.0"))
855+
@suppress_autoshard_warning
779856
def write_sparse_dataset(
780857
f: _GroupStorageType,
781858
k: str,
@@ -902,6 +979,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None)))
902979
@_REGISTRY.register_write(
903980
ZarrGroup, views.AwkwardArrayView, IOSpec("awkward-array", "0.1.0")
904981
)
982+
@suppress_autoshard_warning
905983
def write_awkward(
906984
f: _GroupStorageType,
907985
k: str,
@@ -945,6 +1023,7 @@ def read_awkward(elem: _GroupStorageType, *, _reader: Reader) -> AwkArray:
9451023
@_REGISTRY.register_write(H5Group, pd.DataFrame, IOSpec("dataframe", "0.2.0"))
9461024
@_REGISTRY.register_write(ZarrGroup, views.DataFrameView, IOSpec("dataframe", "0.2.0"))
9471025
@_REGISTRY.register_write(ZarrGroup, pd.DataFrame, IOSpec("dataframe", "0.2.0"))
1026+
@suppress_autoshard_warning
9481027
def write_dataframe(
9491028
f: _GroupStorageType,
9501029
key: str,
@@ -1086,6 +1165,7 @@ def read_partial_dataframe_0_1_0(
10861165

10871166
@_REGISTRY.register_write(H5Group, pd.Categorical, IOSpec("categorical", "0.2.0"))
10881167
@_REGISTRY.register_write(ZarrGroup, pd.Categorical, IOSpec("categorical", "0.2.0"))
1168+
@suppress_autoshard_warning
10891169
def write_categorical(
10901170
f: _GroupStorageType,
10911171
k: str,

src/anndata/_io/specs/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(self) -> None:
114114

115115
def register_write[S: StorageType, T: RWAble](
116116
self,
117-
dest_type: type,
117+
dest_type: type[S],
118118
src_type: type | tuple[type, str],
119119
spec: IOSpec | Mapping[str, str],
120120
modifiers: Iterable[str] = frozenset(),

src/anndata/_settings.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -456,16 +456,6 @@ def validate_zarr_write_format(format: int, settings: SettingsManager):
456456
if format not in {2, 3}:
457457
msg = "non-v2 zarr on-disk format not supported"
458458
raise ValueError(msg)
459-
if format == 2 and getattr(settings, "auto_shard_zarr_v3", False):
460-
msg = "Cannot set `zarr_write_format` to 2 with autosharding on. Please set to `False` `anndata.settings.auto_shard_zarr_v3`"
461-
raise ValueError(msg)
462-
463-
464-
def validate_zarr_sharding(auto_shard: bool, settings: SettingsManager): # noqa: FBT001
465-
validate_bool(auto_shard, settings)
466-
if auto_shard and settings.zarr_write_format == 2:
467-
msg = "Cannot shard v2 format data. Please set `anndata.settings.zarr_write_format` to 3."
468-
raise ValueError(msg)
469459

470460

471461
settings.register(
@@ -517,10 +507,11 @@ def validate_sparse_settings(val: Any, settings: SettingsManager) -> None:
517507

518508
settings.register(
519509
"auto_shard_zarr_v3",
520-
default_value=False,
510+
default_value=None,
521511
description="Whether or not to use zarr's auto computation of sharding for v3. For v2 this setting will be ignored. The setting will apply to all calls to anndata's writing mechanism (write_zarr / write_elem) and will **not** override any user-defined kwargs for shards.",
522-
validate=validate_zarr_sharding,
523-
get_from_env=check_and_get_bool,
512+
validate=gen_validator((bool, NoneType)),
513+
option_type=bool | None,
514+
get_from_env=check_and_get_bool_or_none,
524515
)
525516

526517

src/anndata/_settings.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ class _AnnDataSettingsManager(SettingsManager):
4646
min_rows_for_chunked_h5_copy: int = 1000
4747
disallow_forward_slash_in_h5ad: bool = False
4848
write_csr_csc_indices_with_min_possible_dtype: bool = False
49-
auto_shard_zarr_v3: bool = False
49+
auto_shard_zarr_v3: bool | None = None
5050

5151
settings: _AnnDataSettingsManager

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def backing_h5ad(tmp_path: Path) -> Path:
3333
return tmp_path / "test.h5ad"
3434

3535

36+
@pytest.fixture(autouse=True)
37+
def zarr_shard(request: pytest.FixtureRequest):
38+
with ad.settings.override(auto_shard_zarr_v3=True):
39+
yield
40+
41+
3642
@pytest.fixture(
3743
params=[("h5ad", None), ("zarr", 2), ("zarr", 3)],
3844
ids=["h5ad", "zarr2", "zarr3"],

0 commit comments

Comments
 (0)