Skip to content

Commit 9c9b4fc

Browse files
authored
Backport PR #2427 on branch 0.12.x (fix: warn for upcoming zarr autosharding + fix settings dependency) (#2437)
1 parent 5ef1dea commit 9c9b4fc

9 files changed

Lines changed: 167 additions & 94 deletions

File tree

docs/release-notes/2427.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{attr}`anndata.settings.auto_shard_zarr_v3` and {attr}`anndata.settings.zarr_write_format` are no longer dependent on each other, as stated in the docs {user}`ilan-gold`

docs/release-notes/2427.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{attr}`anndata.settings.auto_shard_zarr_v3` now utilizes zarr's support for a target shard via `zarr.config.set({"array.target_shard_size_bytes" ...})` (only `zarr` version `>=3.1.4`) to make the target shard size 1GB (uncompressed) if not otherwise set. {user}`ilan-gold`

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ filterwarnings_when_strict = [
161161
"default:Consolidated metadata is:UserWarning",
162162
"default:.*Struct:zarr.core.dtype.common.UnstableSpecificationWarning",
163163
"default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning",
164-
"default:Automatic shard shape inference is experimental",
165164
"default:Writing zarr v2:UserWarning",
166165
# TODO: Remove in conjunction with or before https://github.com/scverse/anndata/pull/1707
167166
"default:.*will obey copy-on-write semantics:FutureWarning",

src/anndata/_io/specs/methods.py

Lines changed: 113 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22

33
import warnings
44
from collections.abc import Mapping
5+
from contextlib import contextmanager, nullcontext
56
from copy import copy
6-
from functools import partial
7+
from functools import partial, wraps
78
from importlib.metadata import version
89
from itertools import product
910
from types import MappingProxyType
10-
from typing import TYPE_CHECKING, Protocol
11+
from typing import TYPE_CHECKING, Protocol, TypeVar
1112
from warnings import warn
1213

1314
import h5py
1415
import numpy as np
1516
import pandas as pd
17+
import zarr
1618
from packaging.version import Version
1719
from scipy import sparse
1820

@@ -23,6 +25,7 @@
2325
from anndata._core.merge import intersect_keys
2426
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
2527
from anndata._io.utils import check_key, zero_dim_array_as_scalar
28+
from anndata._types import StorageType
2629
from anndata._warnings import OldFormatWarning
2730
from anndata.compat import (
2831
AwkArray,
@@ -40,13 +43,14 @@
4043
_read_attr,
4144
_require_group_write_dataframe,
4245
)
46+
from anndata.typing import RWAble
4347

4448
from ..._settings import settings
4549
from ...compat import NULLABLE_NUMPY_STRING_TYPE, PANDAS_STRING_ARRAY_TYPES, is_zarr_v2
4650
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
4751

4852
if TYPE_CHECKING:
49-
from collections.abc import Iterator
53+
from collections.abc import Generator, Iterator
5054
from os import PathLike
5155
from typing import Any, Literal
5256

@@ -57,6 +61,7 @@
5761
from anndata.compat import CSArray, CSMatrix
5862
from anndata.typing import AxisStorable, InMemoryArrayOrScalarType
5963

64+
from ...types import _WriteInternal
6065
from .registry import Reader, Writer
6166

6267
####################
@@ -113,10 +118,36 @@ def zarr_v3_compressor_compat(dataset_kwargs: dict) -> dict:
113118
return dataset_kwargs
114119

115120

116-
def zarr_v3_sharding(dataset_kwargs) -> dict:
117-
if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
121+
@contextmanager
122+
def zarr_v3_sharding(dataset_kwargs: dict, format: Literal[2, 3]) -> Generator[dict]:
123+
auto_sharding = (
124+
"shards" not in dataset_kwargs
125+
and ad.settings.auto_shard_zarr_v3
126+
and format == 3
127+
)
128+
if ad.settings.auto_shard_zarr_v3 is None and format == 3:
129+
warnings.warn(
130+
"zarr v3 autosharding will be the default in the next minor release.",
131+
UserWarning,
132+
stacklevel=2,
133+
)
134+
elif auto_sharding:
118135
dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
119-
return dataset_kwargs
136+
# Auto shard sizes are a relatively recent feature
137+
supports_auto_shard_size = Version(version("zarr")) >= Version("3.1.4")
138+
has_auto_shard_size = supports_auto_shard_size and isinstance(
139+
zarr.config.get("array.target_shard_size_bytes"), int
140+
)
141+
# 1GB uncompressed shard size seems reasonable.
142+
# Shards need to generally held completely in memory before writing.
143+
# Even at a compression ration of 6x, that's still a ~20x improvement on number of files.
144+
# Users can ovetrride this nonetheless, hence the above checks.
145+
with (
146+
zarr.config.set({"array.target_shard_size_bytes": 1_000_000_000})
147+
if supports_auto_shard_size and not has_auto_shard_size and auto_sharding
148+
else nullcontext()
149+
):
150+
yield dataset_kwargs
120151

121152

122153
def _to_cpu_mem_wrapper(write_func):
@@ -141,6 +172,36 @@ def wrapper(
141172
return wrapper
142173

143174

175+
S = TypeVar("S", bound=StorageType)
176+
T = TypeVar("T", bound=RWAble)
177+
178+
179+
def suppress_autoshard_warning(
180+
func: _WriteInternal[S, T],
181+
) -> _WriteInternal[S, T]:
182+
@wraps(func)
183+
def wrapper(
184+
f: S,
185+
k: str,
186+
val: T,
187+
*,
188+
_writer: Writer,
189+
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
190+
):
191+
with warnings.catch_warnings():
192+
# Suppress warnings only if the user has opted into autosharding at the top level.
193+
# If someone provides `shards` explicitly, then they should get the warning.
194+
if ad.settings.auto_shard_zarr_v3 and "shards" not in dataset_kwargs:
195+
warnings.filterwarnings(
196+
"ignore",
197+
r"Automatic shard shape inference is experimental",
198+
UserWarning,
199+
)
200+
return func(f, k, val, _writer=_writer, dataset_kwargs=dataset_kwargs)
201+
202+
return wrapper
203+
204+
144205
################################
145206
# Fallbacks / backwards compat #
146207
################################
@@ -284,6 +345,7 @@ def _read_partial(group, *, items=None, indices=(slice(None), slice(None))):
284345

285346
@_REGISTRY.register_write(ZarrGroup, AnnData, IOSpec("anndata", "0.1.0"))
286347
@_REGISTRY.register_write(H5Group, AnnData, IOSpec("anndata", "0.1.0"))
348+
@suppress_autoshard_warning
287349
def write_anndata(
288350
f: GroupStorageType,
289351
k: str,
@@ -333,6 +395,7 @@ def read_anndata(elem: GroupStorageType | H5File, *, _reader: Reader) -> AnnData
333395

334396
@_REGISTRY.register_write(H5Group, Raw, IOSpec("raw", "0.1.0"))
335397
@_REGISTRY.register_write(ZarrGroup, Raw, IOSpec("raw", "0.1.0"))
398+
@suppress_autoshard_warning
336399
def write_raw(
337400
f: GroupStorageType,
338401
k: str,
@@ -365,6 +428,7 @@ def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
365428

366429

367430
@_REGISTRY.register_write(ZarrGroup, type(None), IOSpec("null", "0.1.0"))
431+
@suppress_autoshard_warning
368432
def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
369433
dataset_kwargs = _remove_scalar_compression_args(dataset_kwargs)
370434
# zarr has no first-class null dataset
@@ -392,6 +456,7 @@ def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisSt
392456

393457
@_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0"))
394458
@_REGISTRY.register_write(ZarrGroup, dict, IOSpec("dict", "0.1.0"))
459+
@suppress_autoshard_warning
395460
def write_mapping(
396461
f: GroupStorageType,
397462
k: str,
@@ -412,6 +477,7 @@ def write_mapping(
412477

413478
@_REGISTRY.register_write(H5Group, list, IOSpec("array", "0.2.0"))
414479
@_REGISTRY.register_write(ZarrGroup, list, IOSpec("array", "0.2.0"))
480+
@suppress_autoshard_warning
415481
def write_list(
416482
f: GroupStorageType,
417483
k: str,
@@ -433,6 +499,7 @@ def write_list(
433499
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
434500
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
435501
@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
502+
@suppress_autoshard_warning
436503
@zero_dim_array_as_scalar
437504
def write_basic(
438505
f: GroupStorageType,
@@ -449,8 +516,10 @@ def write_basic(
449516
f.create_dataset(k, data=elem, shape=elem.shape, dtype=dtype, **dataset_kwargs)
450517
else:
451518
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
452-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
453-
f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs)
519+
with zarr_v3_sharding(
520+
dataset_kwargs, format=f.metadata.zarr_format
521+
) as dataset_kwargs:
522+
f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs)
454523
# see https://github.com/zarr-developers/zarr-python/discussions/2712
455524
if isinstance(elem, ZarrArray | H5Array):
456525
f[k][...] = elem[...]
@@ -507,14 +576,15 @@ def write_chunked_dense_array_to_group(
507576
_to_cpu_mem_wrapper(write_basic)
508577
)
509578
_REGISTRY.register_write(ZarrGroup, CupyArray, IOSpec("array", "0.2.0"))(
510-
_to_cpu_mem_wrapper(write_basic)
579+
suppress_autoshard_warning(_to_cpu_mem_wrapper(write_basic))
511580
)
512581

513582

514583
@_REGISTRY.register_write(ZarrGroup, views.DaskArrayView, IOSpec("array", "0.2.0"))
515584
@_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0"))
516585
@_REGISTRY.register_write(H5Group, views.DaskArrayView, IOSpec("array", "0.2.0"))
517586
@_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0"))
587+
@suppress_autoshard_warning
518588
def write_basic_dask_dask_dense(
519589
f: ZarrGroup | H5Group,
520590
k: str,
@@ -527,13 +597,14 @@ def write_basic_dask_dask_dense(
527597

528598
dataset_kwargs = dict(dataset_kwargs)
529599
is_h5 = isinstance(f, H5Group)
530-
if not is_h5:
531-
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
532-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
533600
if is_zarr_v2() or is_h5:
534601
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
535602
else:
536-
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
603+
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
604+
with zarr_v3_sharding(
605+
dataset_kwargs, format=f.metadata.zarr_format
606+
) as dataset_kwargs:
607+
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
537608
da.store(elem, g, scheduler="threads")
538609

539610

@@ -598,6 +669,7 @@ def write_vlen_string_array(
598669
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "U"), IOSpec("string-array", "0.2.0"))
599670
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "O"), IOSpec("string-array", "0.2.0"))
600671
@_REGISTRY.register_write(ZarrGroup, (np.ndarray, "T"), IOSpec("string-array", "0.2.0"))
672+
@suppress_autoshard_warning
601673
@zero_dim_array_as_scalar
602674
def write_vlen_string_array_zarr(
603675
f: ZarrGroup,
@@ -635,15 +707,17 @@ def write_vlen_string_array_zarr(
635707
filters, fill_value = None, None
636708
if f.metadata.zarr_format == 2:
637709
filters, fill_value = [VLenUTF8()], ""
638-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
639-
f.create_array(
640-
k,
641-
shape=elem.shape,
642-
dtype=dtype,
643-
filters=filters,
644-
fill_value=fill_value,
645-
**dataset_kwargs,
646-
)
710+
with zarr_v3_sharding(
711+
dataset_kwargs, format=f.metadata.zarr_format
712+
) as dataset_kwargs:
713+
f.create_array(
714+
k,
715+
shape=elem.shape,
716+
dtype=dtype,
717+
filters=filters,
718+
fill_value=fill_value,
719+
**dataset_kwargs,
720+
)
647721
f[k][:] = elem
648722

649723

@@ -705,8 +779,9 @@ def write_recarray_zarr(
705779
dataset_kwargs = dict(dataset_kwargs)
706780
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
707781
# https://github.com/zarr-developers/zarr-python/issues/3546
708-
# if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
709-
# dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
782+
# with zarr_v3_sharding(
783+
# dataset_kwargs, format=f.metadata.zarr_format
784+
# ) as dataset_kwargs:
710785
f.create_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
711786
f[k][...] = elem
712787

@@ -761,16 +836,20 @@ def write_sparse_compressed(
761836
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs
762837
)
763838
else:
764-
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
765-
arr = g.create_array(
766-
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
767-
)
839+
with zarr_v3_sharding(
840+
dataset_kwargs, format=f.metadata.zarr_format
841+
) as dataset_kwargs:
842+
arr = g.create_array(
843+
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
844+
)
768845
# see https://github.com/zarr-developers/zarr-python/discussions/2712
769846
arr[...] = attr[...]
770847

771848

772-
write_csr = partial(write_sparse_compressed, fmt="csr")
773-
write_csc = partial(write_sparse_compressed, fmt="csc")
849+
write_csr, write_csc = (
850+
suppress_autoshard_warning(partial(write_sparse_compressed, fmt=fmt))
851+
for fmt in ["csr", "csc"]
852+
)
774853

775854
for store_type, (cls, spec, func) in product(
776855
(H5Group, ZarrGroup),
@@ -807,6 +886,7 @@ def write_sparse_compressed(
807886
@_REGISTRY.register_write(H5Group, _CSCDataset, IOSpec("csc_matrix", "0.1.0"))
808887
@_REGISTRY.register_write(ZarrGroup, _CSRDataset, IOSpec("csr_matrix", "0.1.0"))
809888
@_REGISTRY.register_write(ZarrGroup, _CSCDataset, IOSpec("csc_matrix", "0.1.0"))
889+
@suppress_autoshard_warning
810890
def write_sparse_dataset(
811891
f: GroupStorageType,
812892
k: str,
@@ -931,6 +1011,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None)))
9311011
@_REGISTRY.register_write(
9321012
ZarrGroup, views.AwkwardArrayView, IOSpec("awkward-array", "0.1.0")
9331013
)
1014+
@suppress_autoshard_warning
9341015
def write_awkward(
9351016
f: GroupStorageType,
9361017
k: str,
@@ -974,6 +1055,7 @@ def read_awkward(elem: GroupStorageType, *, _reader: Reader) -> AwkArray:
9741055
@_REGISTRY.register_write(H5Group, pd.DataFrame, IOSpec("dataframe", "0.2.0"))
9751056
@_REGISTRY.register_write(ZarrGroup, views.DataFrameView, IOSpec("dataframe", "0.2.0"))
9761057
@_REGISTRY.register_write(ZarrGroup, pd.DataFrame, IOSpec("dataframe", "0.2.0"))
1058+
@suppress_autoshard_warning
9771059
def write_dataframe(
9781060
f: GroupStorageType,
9791061
key: str,
@@ -1115,6 +1197,7 @@ def read_partial_dataframe_0_1_0(
11151197

11161198
@_REGISTRY.register_write(H5Group, pd.Categorical, IOSpec("categorical", "0.2.0"))
11171199
@_REGISTRY.register_write(ZarrGroup, pd.Categorical, IOSpec("categorical", "0.2.0"))
1200+
@suppress_autoshard_warning
11181201
def write_categorical(
11191202
f: GroupStorageType,
11201203
k: str,

src/anndata/_io/specs/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(self):
100100

101101
def register_write(
102102
self,
103-
dest_type: type,
103+
dest_type: type[S],
104104
src_type: type | tuple[type, str],
105105
spec: IOSpec | Mapping[str, str],
106106
modifiers: Iterable[str] = frozenset(),

0 commit comments

Comments
 (0)