11from __future__ import annotations
22
3+ import warnings
34from collections .abc import Mapping , MutableMapping
5+ from contextlib import contextmanager , nullcontext
46from copy import copy
5- from functools import partial
7+ from functools import partial , wraps
8+ from importlib .metadata import version
69from itertools import product
710from types import MappingProxyType
811from typing import TYPE_CHECKING , Protocol
912
1013import h5py
1114import numpy as np
1215import pandas as pd
16+ import zarr
1317from numcodecs import VLenUTF8
18+ from packaging .version import Version
1419from scipy import sparse
1520from zarr .core .dtype import VariableLengthUTF8
1621
2126from anndata ._core .merge import intersect_keys
2227from anndata ._core .sparse_dataset import _CSCDataset , _CSRDataset , sparse_dataset
2328from anndata ._io .utils import check_key , zero_dim_array_as_scalar
29+ from anndata ._types import StorageType
2430from anndata ._warnings import OldFormatWarning
2531from anndata .compat import (
2632 AwkArray ,
4551from .registry import _REGISTRY , IOSpec , read_elem , read_elem_partial
4652
4753if TYPE_CHECKING :
48- from collections .abc import Iterator
54+ from collections .abc import Generator , Iterator
4955 from os import PathLike
5056 from typing import Any , Literal
5157
5258 from numpy import typing as npt
5359 from numpy .typing import NDArray
5460
55- from anndata ._types import _ArrayStorageType , _GroupStorageType
61+ from anndata ._types import _ArrayStorageType , _GroupStorageType , _WriteInternal
5662 from anndata .compat import CSArray , CSMatrix , CupyCSMatrix
57- from anndata .typing import AxisStorable , _InMemoryArrayOrScalarType
63+ from anndata .typing import AxisStorable , RWAble , _InMemoryArrayOrScalarType
5864
5965 from .registry import Reader , Writer
6066
@@ -112,10 +118,35 @@ def zarr_v3_compressor_compat(dataset_kwargs: dict) -> dict:
112118 return dataset_kwargs
113119
114120
@contextmanager
def zarr_v3_sharding(dataset_kwargs: dict, format: Literal[2, 3]) -> Generator[dict]:
    """Yield the array-creation kwargs to use, adding ``shards="auto"`` when enabled.

    Auto-sharding is applied only when all of the following hold: the caller
    did not pass ``shards`` in ``dataset_kwargs``, ``ad.settings.auto_shard_zarr_v3``
    is truthy, and ``format`` is 3. In that case a *copy* of ``dataset_kwargs``
    with ``"shards": "auto"`` is yielded; otherwise the mapping is yielded unchanged.

    If ``ad.settings.auto_shard_zarr_v3`` is ``None`` and ``format`` is 3, a
    ``UserWarning`` announcing the upcoming default is issued instead.

    While the ``with`` body runs and auto-sharding is active, zarr's
    ``array.target_shard_size_bytes`` config value is temporarily set to
    1_000_000_000, provided the installed zarr is >= 3.1.4 and that config
    value is not already an ``int``.

    Parameters
    ----------
    dataset_kwargs
        Keyword arguments destined for the zarr array-creation call.
    format
        Zarr format of the target group (2 or 3).

    Yields
    ------
    The keyword arguments to pass to the array-creation call.
    """
    setting = ad.settings.auto_shard_zarr_v3
    is_v3 = format == 3
    auto_sharding = bool("shards" not in dataset_kwargs and setting and is_v3)

    if setting is None and is_v3:
        warn(
            "zarr v3 autosharding will be the default in the next minor release.",
            UserWarning,
        )
    elif auto_sharding:
        # Copy instead of mutating: callers may hand in a shared/read-only mapping.
        dataset_kwargs = {**dataset_kwargs, "shards": "auto"}

    # Only inspect the installed zarr version and its config when auto-sharding
    # is actually in effect; otherwise the answer is never used.
    override_shard_size = (
        auto_sharding
        # Auto shard sizes are a relatively recent feature.
        and Version(version("zarr")) >= Version("3.1.4")
        # A pre-existing integer value is left alone.
        and not isinstance(zarr.config.get("array.target_shard_size_bytes"), int)
    )

    # 1GB uncompressed shard size seems reasonable.
    # Shards generally need to be held completely in memory before writing.
    # Even at a compression ratio of 6x, that's still a ~20x improvement on number of files.
    # Users can override this nonetheless, hence the checks above.
    with (
        zarr.config.set({"array.target_shard_size_bytes": 1_000_000_000})
        if override_shard_size
        else nullcontext()
    ):
        yield dataset_kwargs
119150
120151
121152def _to_cpu_mem_wrapper (write_func ):
@@ -140,6 +171,32 @@ def wrapper(
140171 return wrapper
141172
142173
174+ def suppress_autoshard_warning [S : StorageType , T : RWAble ](
175+ func : _WriteInternal [S , T ],
176+ ) -> _WriteInternal [S , T ]:
177+ @wraps (func )
178+ def wrapper (
179+ f : S ,
180+ k : str ,
181+ val : T ,
182+ * ,
183+ _writer : Writer ,
184+ dataset_kwargs : Mapping [str , Any ] = MappingProxyType ({}),
185+ ):
186+ with warnings .catch_warnings ():
187+ # Suppress warnings only if the user has opted into autosharding at the top level.
188+ # If someone provides `shards` explicitly, then they should get the warning.
189+ if ad .settings .auto_shard_zarr_v3 and "shards" not in dataset_kwargs :
190+ warnings .filterwarnings (
191+ "ignore" ,
192+ r"Automatic shard shape inference is experimental" ,
193+ UserWarning ,
194+ )
195+ return func (f , k , val , _writer = _writer , dataset_kwargs = dataset_kwargs )
196+
197+ return wrapper
198+
199+
143200################################
144201# Fallbacks / backwards compat #
145202################################
@@ -277,6 +334,7 @@ def _read_partial(group, *, items=None, indices=(slice(None), slice(None))):
277334
278335@_REGISTRY .register_write (ZarrGroup , AnnData , IOSpec ("anndata" , "0.1.0" ))
279336@_REGISTRY .register_write (H5Group , AnnData , IOSpec ("anndata" , "0.1.0" ))
337+ @suppress_autoshard_warning
280338def write_anndata (
281339 f : _GroupStorageType ,
282340 k : str ,
@@ -329,6 +387,7 @@ def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnDat
329387
330388@_REGISTRY .register_write (H5Group , Raw , IOSpec ("raw" , "0.1.0" ))
331389@_REGISTRY .register_write (ZarrGroup , Raw , IOSpec ("raw" , "0.1.0" ))
390+ @suppress_autoshard_warning
332391def write_raw (
333392 f : _GroupStorageType ,
334393 k : str ,
@@ -361,6 +420,7 @@ def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
361420
362421
363422@_REGISTRY .register_write (ZarrGroup , type (None ), IOSpec ("null" , "0.1.0" ))
423+ @suppress_autoshard_warning
364424def write_null_zarr (f , k , _v , _writer , dataset_kwargs = MappingProxyType ({})):
365425 dataset_kwargs = _remove_scalar_compression_args (dataset_kwargs )
366426 # zarr has no first-class null dataset
@@ -384,6 +444,7 @@ def read_mapping(
384444
385445@_REGISTRY .register_write (H5Group , dict , IOSpec ("dict" , "0.1.0" ))
386446@_REGISTRY .register_write (ZarrGroup , dict , IOSpec ("dict" , "0.1.0" ))
447+ @suppress_autoshard_warning
387448def write_mapping (
388449 f : _GroupStorageType ,
389450 k : str ,
@@ -404,6 +465,7 @@ def write_mapping(
404465
405466@_REGISTRY .register_write (H5Group , list , IOSpec ("array" , "0.2.0" ))
406467@_REGISTRY .register_write (ZarrGroup , list , IOSpec ("array" , "0.2.0" ))
468+ @suppress_autoshard_warning
407469def write_list (
408470 f : _GroupStorageType ,
409471 k : str ,
@@ -425,6 +487,7 @@ def write_list(
425487@_REGISTRY .register_write (ZarrGroup , np .ma .MaskedArray , IOSpec ("array" , "0.2.0" ))
426488@_REGISTRY .register_write (ZarrGroup , ZarrArray , IOSpec ("array" , "0.2.0" ))
427489@_REGISTRY .register_write (ZarrGroup , H5Array , IOSpec ("array" , "0.2.0" ))
490+ @suppress_autoshard_warning
428491@zero_dim_array_as_scalar
429492def write_basic (
430493 f : _GroupStorageType ,
@@ -441,8 +504,10 @@ def write_basic(
441504 f .create_dataset (k , data = elem , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
442505 else :
443506 dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
444- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
445- f .create_array (k , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
507+ with zarr_v3_sharding (
508+ dataset_kwargs , format = f .metadata .zarr_format
509+ ) as dataset_kwargs :
510+ f .create_array (k , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
446511 # see https://github.com/zarr-developers/zarr-python/discussions/2712
447512 if isinstance (elem , ZarrArray | H5Array ):
448513 f [k ][...] = elem [...]
@@ -499,14 +564,15 @@ def write_chunked_dense_array_to_group(
499564 _to_cpu_mem_wrapper (write_basic )
500565)
# Register the zarr writer for CuPy arrays: `write_basic` wrapped by
# `_to_cpu_mem_wrapper` (presumably moves the data to host memory first --
# confirm against its definition), with the auto-shard warning suppression
# applied as the outermost layer.
_REGISTRY.register_write(ZarrGroup, CupyArray, IOSpec("array", "0.2.0"))(
    suppress_autoshard_warning(_to_cpu_mem_wrapper(write_basic))
)
504569
505570
506571@_REGISTRY .register_write (ZarrGroup , views .DaskArrayView , IOSpec ("array" , "0.2.0" ))
507572@_REGISTRY .register_write (ZarrGroup , DaskArray , IOSpec ("array" , "0.2.0" ))
508573@_REGISTRY .register_write (H5Group , views .DaskArrayView , IOSpec ("array" , "0.2.0" ))
509574@_REGISTRY .register_write (H5Group , DaskArray , IOSpec ("array" , "0.2.0" ))
575+ @suppress_autoshard_warning
510576def write_basic_dask_dask_dense (
511577 f : ZarrGroup | H5Group ,
512578 k : str ,
@@ -521,11 +587,13 @@ def write_basic_dask_dask_dense(
521587 is_h5 = isinstance (f , H5Group )
522588 if not is_h5 :
523589 dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
524- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
525590 if is_h5 :
526591 g = f .require_dataset (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
527592 else :
528- g = f .require_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
593+ with zarr_v3_sharding (
594+ dataset_kwargs , format = f .metadata .zarr_format
595+ ) as dataset_kwargs :
596+ g = f .require_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
529597 da .store (elem , g , scheduler = "threads" )
530598
531599
@@ -590,6 +658,7 @@ def write_vlen_string_array(
590658@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "U" ), IOSpec ("string-array" , "0.2.0" ))
591659@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "O" ), IOSpec ("string-array" , "0.2.0" ))
592660@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "T" ), IOSpec ("string-array" , "0.2.0" ))
661+ @suppress_autoshard_warning
593662@zero_dim_array_as_scalar
594663def write_vlen_string_array_zarr (
595664 f : ZarrGroup ,
@@ -605,15 +674,17 @@ def write_vlen_string_array_zarr(
605674 filters , fill_value = None , None
606675 if f .metadata .zarr_format == 2 :
607676 filters , fill_value = [VLenUTF8 ()], ""
608- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
609- f .create_array (
610- k ,
611- shape = elem .shape ,
612- dtype = dtype ,
613- filters = filters ,
614- fill_value = fill_value ,
615- ** dataset_kwargs ,
616- )
677+ with zarr_v3_sharding (
678+ dataset_kwargs , format = f .metadata .zarr_format
679+ ) as dataset_kwargs :
680+ f .create_array (
681+ k ,
682+ shape = elem .shape ,
683+ dtype = dtype ,
684+ filters = filters ,
685+ fill_value = fill_value ,
686+ ** dataset_kwargs ,
687+ )
617688 f [k ][:] = elem
618689
619690
@@ -674,8 +745,9 @@ def write_recarray_zarr(
674745 dataset_kwargs = dict (dataset_kwargs )
675746 dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
676747 # https://github.com/zarr-developers/zarr-python/issues/3546
677- # if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
678- # dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
748+ # with zarr_v3_sharding(
749+ # dataset_kwargs, format=f.metadata.zarr_format
750+ # ) as dataset_kwargs:
679751 f .create_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
680752 f [k ][...] = elem
681753
@@ -730,16 +802,20 @@ def write_sparse_compressed(
730802 attr_name , data = attr , shape = attr .shape , dtype = dtype , ** dataset_kwargs
731803 )
732804 else :
733- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
734- arr = g .create_array (
735- attr_name , shape = attr .shape , dtype = dtype , ** dataset_kwargs
736- )
805+ with zarr_v3_sharding (
806+ dataset_kwargs , format = f .metadata .zarr_format
807+ ) as dataset_kwargs :
808+ arr = g .create_array (
809+ attr_name , shape = attr .shape , dtype = dtype , ** dataset_kwargs
810+ )
737811 # see https://github.com/zarr-developers/zarr-python/discussions/2712
738812 arr [...] = attr [...]
739813
740814
# Format-specific sparse writers: `write_sparse_compressed` with `fmt` bound,
# each wrapped so the auto-shard warning is suppressed on opt-in.
write_csr = suppress_autoshard_warning(partial(write_sparse_compressed, fmt="csr"))
write_csc = suppress_autoshard_warning(partial(write_sparse_compressed, fmt="csc"))
743819
744820for store_type , (cls , spec , func ) in product (
745821 (H5Group , ZarrGroup ),
@@ -776,6 +852,7 @@ def write_sparse_compressed(
776852@_REGISTRY .register_write (H5Group , _CSCDataset , IOSpec ("csc_matrix" , "0.1.0" ))
777853@_REGISTRY .register_write (ZarrGroup , _CSRDataset , IOSpec ("csr_matrix" , "0.1.0" ))
778854@_REGISTRY .register_write (ZarrGroup , _CSCDataset , IOSpec ("csc_matrix" , "0.1.0" ))
855+ @suppress_autoshard_warning
779856def write_sparse_dataset (
780857 f : _GroupStorageType ,
781858 k : str ,
@@ -902,6 +979,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None)))
902979@_REGISTRY .register_write (
903980 ZarrGroup , views .AwkwardArrayView , IOSpec ("awkward-array" , "0.1.0" )
904981)
982+ @suppress_autoshard_warning
905983def write_awkward (
906984 f : _GroupStorageType ,
907985 k : str ,
@@ -945,6 +1023,7 @@ def read_awkward(elem: _GroupStorageType, *, _reader: Reader) -> AwkArray:
9451023@_REGISTRY .register_write (H5Group , pd .DataFrame , IOSpec ("dataframe" , "0.2.0" ))
9461024@_REGISTRY .register_write (ZarrGroup , views .DataFrameView , IOSpec ("dataframe" , "0.2.0" ))
9471025@_REGISTRY .register_write (ZarrGroup , pd .DataFrame , IOSpec ("dataframe" , "0.2.0" ))
1026+ @suppress_autoshard_warning
9481027def write_dataframe (
9491028 f : _GroupStorageType ,
9501029 key : str ,
@@ -1086,6 +1165,7 @@ def read_partial_dataframe_0_1_0(
10861165
10871166@_REGISTRY .register_write (H5Group , pd .Categorical , IOSpec ("categorical" , "0.2.0" ))
10881167@_REGISTRY .register_write (ZarrGroup , pd .Categorical , IOSpec ("categorical" , "0.2.0" ))
1168+ @suppress_autoshard_warning
10891169def write_categorical (
10901170 f : _GroupStorageType ,
10911171 k : str ,
0 commit comments