22
33import warnings
44from collections .abc import Mapping
5+ from contextlib import contextmanager , nullcontext
56from copy import copy
6- from functools import partial
7+ from functools import partial , wraps
78from importlib .metadata import version
89from itertools import product
910from types import MappingProxyType
10- from typing import TYPE_CHECKING , Protocol
11+ from typing import TYPE_CHECKING , Protocol , TypeVar
1112from warnings import warn
1213
1314import h5py
1415import numpy as np
1516import pandas as pd
17+ import zarr
1618from packaging .version import Version
1719from scipy import sparse
1820
2325from anndata ._core .merge import intersect_keys
2426from anndata ._core .sparse_dataset import _CSCDataset , _CSRDataset , sparse_dataset
2527from anndata ._io .utils import check_key , zero_dim_array_as_scalar
28+ from anndata ._types import StorageType
2629from anndata ._warnings import OldFormatWarning
2730from anndata .compat import (
2831 AwkArray ,
4043 _read_attr ,
4144 _require_group_write_dataframe ,
4245)
46+ from anndata .typing import RWAble
4347
4448from ..._settings import settings
4549from ...compat import NULLABLE_NUMPY_STRING_TYPE , PANDAS_STRING_ARRAY_TYPES , is_zarr_v2
4650from .registry import _REGISTRY , IOSpec , read_elem , read_elem_partial
4751
4852if TYPE_CHECKING :
49- from collections .abc import Iterator
53+ from collections .abc import Generator , Iterator
5054 from os import PathLike
5155 from typing import Any , Literal
5256
5761 from anndata .compat import CSArray , CSMatrix
5862 from anndata .typing import AxisStorable , InMemoryArrayOrScalarType
5963
64+ from ...types import _WriteInternal
6065 from .registry import Reader , Writer
6166
6267####################
@@ -113,10 +118,36 @@ def zarr_v3_compressor_compat(dataset_kwargs: dict) -> dict:
113118 return dataset_kwargs
114119
115120
116- def zarr_v3_sharding (dataset_kwargs ) -> dict :
117- if "shards" not in dataset_kwargs and ad .settings .auto_shard_zarr_v3 :
121+ @contextmanager
122+ def zarr_v3_sharding (dataset_kwargs : dict , format : Literal [2 , 3 ]) -> Generator [dict ]:
123+ auto_sharding = (
124+ "shards" not in dataset_kwargs
125+ and ad .settings .auto_shard_zarr_v3
126+ and format == 3
127+ )
128+ if ad .settings .auto_shard_zarr_v3 is None and format == 3 :
129+ warnings .warn (
130+ "zarr v3 autosharding will be the default in the next minor release." ,
131+ UserWarning ,
132+ stacklevel = 2 ,
133+ )
134+ elif auto_sharding :
118135 dataset_kwargs = {** dataset_kwargs , "shards" : "auto" }
119- return dataset_kwargs
136+ # Auto shard sizes are a relatively recent feature
137+ supports_auto_shard_size = Version (version ("zarr" )) >= Version ("3.1.4" )
138+ has_auto_shard_size = supports_auto_shard_size and isinstance (
139+ zarr .config .get ("array.target_shard_size_bytes" ), int
140+ )
141+ # 1GB uncompressed shard size seems reasonable.
142+ # Shards need to generally held completely in memory before writing.
143+ # Even at a compression ration of 6x, that's still a ~20x improvement on number of files.
144+ # Users can ovetrride this nonetheless, hence the above checks.
145+ with (
146+ zarr .config .set ({"array.target_shard_size_bytes" : 1_000_000_000 })
147+ if supports_auto_shard_size and not has_auto_shard_size and auto_sharding
148+ else nullcontext ()
149+ ):
150+ yield dataset_kwargs
120151
121152
122153def _to_cpu_mem_wrapper (write_func ):
@@ -141,6 +172,36 @@ def wrapper(
141172 return wrapper
142173
143174
175+ S = TypeVar ("S" , bound = StorageType )
176+ T = TypeVar ("T" , bound = RWAble )
177+
178+
179+ def suppress_autoshard_warning (
180+ func : _WriteInternal [S , T ],
181+ ) -> _WriteInternal [S , T ]:
182+ @wraps (func )
183+ def wrapper (
184+ f : S ,
185+ k : str ,
186+ val : T ,
187+ * ,
188+ _writer : Writer ,
189+ dataset_kwargs : Mapping [str , Any ] = MappingProxyType ({}),
190+ ):
191+ with warnings .catch_warnings ():
192+ # Suppress warnings only if the user has opted into autosharding at the top level.
193+ # If someone provides `shards` explicitly, then they should get the warning.
194+ if ad .settings .auto_shard_zarr_v3 and "shards" not in dataset_kwargs :
195+ warnings .filterwarnings (
196+ "ignore" ,
197+ r"Automatic shard shape inference is experimental" ,
198+ UserWarning ,
199+ )
200+ return func (f , k , val , _writer = _writer , dataset_kwargs = dataset_kwargs )
201+
202+ return wrapper
203+
204+
144205################################
145206# Fallbacks / backwards compat #
146207################################
@@ -284,6 +345,7 @@ def _read_partial(group, *, items=None, indices=(slice(None), slice(None))):
284345
285346@_REGISTRY .register_write (ZarrGroup , AnnData , IOSpec ("anndata" , "0.1.0" ))
286347@_REGISTRY .register_write (H5Group , AnnData , IOSpec ("anndata" , "0.1.0" ))
348+ @suppress_autoshard_warning
287349def write_anndata (
288350 f : GroupStorageType ,
289351 k : str ,
@@ -333,6 +395,7 @@ def read_anndata(elem: GroupStorageType | H5File, *, _reader: Reader) -> AnnData
333395
334396@_REGISTRY .register_write (H5Group , Raw , IOSpec ("raw" , "0.1.0" ))
335397@_REGISTRY .register_write (ZarrGroup , Raw , IOSpec ("raw" , "0.1.0" ))
398+ @suppress_autoshard_warning
336399def write_raw (
337400 f : GroupStorageType ,
338401 k : str ,
@@ -365,6 +428,7 @@ def write_null_h5py(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
365428
366429
367430@_REGISTRY .register_write (ZarrGroup , type (None ), IOSpec ("null" , "0.1.0" ))
431+ @suppress_autoshard_warning
368432def write_null_zarr (f , k , _v , _writer , dataset_kwargs = MappingProxyType ({})):
369433 dataset_kwargs = _remove_scalar_compression_args (dataset_kwargs )
370434 # zarr has no first-class null dataset
@@ -392,6 +456,7 @@ def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisSt
392456
393457@_REGISTRY .register_write (H5Group , dict , IOSpec ("dict" , "0.1.0" ))
394458@_REGISTRY .register_write (ZarrGroup , dict , IOSpec ("dict" , "0.1.0" ))
459+ @suppress_autoshard_warning
395460def write_mapping (
396461 f : GroupStorageType ,
397462 k : str ,
@@ -412,6 +477,7 @@ def write_mapping(
412477
413478@_REGISTRY .register_write (H5Group , list , IOSpec ("array" , "0.2.0" ))
414479@_REGISTRY .register_write (ZarrGroup , list , IOSpec ("array" , "0.2.0" ))
480+ @suppress_autoshard_warning
415481def write_list (
416482 f : GroupStorageType ,
417483 k : str ,
@@ -433,6 +499,7 @@ def write_list(
433499@_REGISTRY .register_write (ZarrGroup , np .ma .MaskedArray , IOSpec ("array" , "0.2.0" ))
434500@_REGISTRY .register_write (ZarrGroup , ZarrArray , IOSpec ("array" , "0.2.0" ))
435501@_REGISTRY .register_write (ZarrGroup , H5Array , IOSpec ("array" , "0.2.0" ))
502+ @suppress_autoshard_warning
436503@zero_dim_array_as_scalar
437504def write_basic (
438505 f : GroupStorageType ,
@@ -449,8 +516,10 @@ def write_basic(
449516 f .create_dataset (k , data = elem , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
450517 else :
451518 dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
452- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
453- f .create_array (k , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
519+ with zarr_v3_sharding (
520+ dataset_kwargs , format = f .metadata .zarr_format
521+ ) as dataset_kwargs :
522+ f .create_array (k , shape = elem .shape , dtype = dtype , ** dataset_kwargs )
454523 # see https://github.com/zarr-developers/zarr-python/discussions/2712
455524 if isinstance (elem , ZarrArray | H5Array ):
456525 f [k ][...] = elem [...]
@@ -507,14 +576,15 @@ def write_chunked_dense_array_to_group(
507576 _to_cpu_mem_wrapper (write_basic )
508577)
509578_REGISTRY .register_write (ZarrGroup , CupyArray , IOSpec ("array" , "0.2.0" ))(
510- _to_cpu_mem_wrapper (write_basic )
579+ suppress_autoshard_warning ( _to_cpu_mem_wrapper (write_basic ) )
511580)
512581
513582
514583@_REGISTRY .register_write (ZarrGroup , views .DaskArrayView , IOSpec ("array" , "0.2.0" ))
515584@_REGISTRY .register_write (ZarrGroup , DaskArray , IOSpec ("array" , "0.2.0" ))
516585@_REGISTRY .register_write (H5Group , views .DaskArrayView , IOSpec ("array" , "0.2.0" ))
517586@_REGISTRY .register_write (H5Group , DaskArray , IOSpec ("array" , "0.2.0" ))
587+ @suppress_autoshard_warning
518588def write_basic_dask_dask_dense (
519589 f : ZarrGroup | H5Group ,
520590 k : str ,
@@ -527,13 +597,14 @@ def write_basic_dask_dask_dense(
527597
528598 dataset_kwargs = dict (dataset_kwargs )
529599 is_h5 = isinstance (f , H5Group )
530- if not is_h5 :
531- dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
532- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
533600 if is_zarr_v2 () or is_h5 :
534601 g = f .require_dataset (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
535602 else :
536- g = f .require_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
603+ dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
604+ with zarr_v3_sharding (
605+ dataset_kwargs , format = f .metadata .zarr_format
606+ ) as dataset_kwargs :
607+ g = f .require_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
537608 da .store (elem , g , scheduler = "threads" )
538609
539610
@@ -598,6 +669,7 @@ def write_vlen_string_array(
598669@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "U" ), IOSpec ("string-array" , "0.2.0" ))
599670@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "O" ), IOSpec ("string-array" , "0.2.0" ))
600671@_REGISTRY .register_write (ZarrGroup , (np .ndarray , "T" ), IOSpec ("string-array" , "0.2.0" ))
672+ @suppress_autoshard_warning
601673@zero_dim_array_as_scalar
602674def write_vlen_string_array_zarr (
603675 f : ZarrGroup ,
@@ -635,15 +707,17 @@ def write_vlen_string_array_zarr(
635707 filters , fill_value = None , None
636708 if f .metadata .zarr_format == 2 :
637709 filters , fill_value = [VLenUTF8 ()], ""
638- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
639- f .create_array (
640- k ,
641- shape = elem .shape ,
642- dtype = dtype ,
643- filters = filters ,
644- fill_value = fill_value ,
645- ** dataset_kwargs ,
646- )
710+ with zarr_v3_sharding (
711+ dataset_kwargs , format = f .metadata .zarr_format
712+ ) as dataset_kwargs :
713+ f .create_array (
714+ k ,
715+ shape = elem .shape ,
716+ dtype = dtype ,
717+ filters = filters ,
718+ fill_value = fill_value ,
719+ ** dataset_kwargs ,
720+ )
647721 f [k ][:] = elem
648722
649723
@@ -705,8 +779,9 @@ def write_recarray_zarr(
705779 dataset_kwargs = dict (dataset_kwargs )
706780 dataset_kwargs = zarr_v3_compressor_compat (dataset_kwargs )
707781 # https://github.com/zarr-developers/zarr-python/issues/3546
708- # if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
709- # dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
782+ # with zarr_v3_sharding(
783+ # dataset_kwargs, format=f.metadata.zarr_format
784+ # ) as dataset_kwargs:
710785 f .create_array (k , shape = elem .shape , dtype = elem .dtype , ** dataset_kwargs )
711786 f [k ][...] = elem
712787
@@ -761,16 +836,20 @@ def write_sparse_compressed(
761836 attr_name , data = attr , shape = attr .shape , dtype = dtype , ** dataset_kwargs
762837 )
763838 else :
764- dataset_kwargs = zarr_v3_sharding (dataset_kwargs )
765- arr = g .create_array (
766- attr_name , shape = attr .shape , dtype = dtype , ** dataset_kwargs
767- )
839+ with zarr_v3_sharding (
840+ dataset_kwargs , format = f .metadata .zarr_format
841+ ) as dataset_kwargs :
842+ arr = g .create_array (
843+ attr_name , shape = attr .shape , dtype = dtype , ** dataset_kwargs
844+ )
768845 # see https://github.com/zarr-developers/zarr-python/discussions/2712
769846 arr [...] = attr [...]
770847
771848
772- write_csr = partial (write_sparse_compressed , fmt = "csr" )
773- write_csc = partial (write_sparse_compressed , fmt = "csc" )
849+ write_csr , write_csc = (
850+ suppress_autoshard_warning (partial (write_sparse_compressed , fmt = fmt ))
851+ for fmt in ["csr" , "csc" ]
852+ )
774853
775854for store_type , (cls , spec , func ) in product (
776855 (H5Group , ZarrGroup ),
@@ -807,6 +886,7 @@ def write_sparse_compressed(
807886@_REGISTRY .register_write (H5Group , _CSCDataset , IOSpec ("csc_matrix" , "0.1.0" ))
808887@_REGISTRY .register_write (ZarrGroup , _CSRDataset , IOSpec ("csr_matrix" , "0.1.0" ))
809888@_REGISTRY .register_write (ZarrGroup , _CSCDataset , IOSpec ("csc_matrix" , "0.1.0" ))
889+ @suppress_autoshard_warning
810890def write_sparse_dataset (
811891 f : GroupStorageType ,
812892 k : str ,
@@ -931,6 +1011,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None)))
9311011@_REGISTRY .register_write (
9321012 ZarrGroup , views .AwkwardArrayView , IOSpec ("awkward-array" , "0.1.0" )
9331013)
1014+ @suppress_autoshard_warning
9341015def write_awkward (
9351016 f : GroupStorageType ,
9361017 k : str ,
@@ -974,6 +1055,7 @@ def read_awkward(elem: GroupStorageType, *, _reader: Reader) -> AwkArray:
9741055@_REGISTRY .register_write (H5Group , pd .DataFrame , IOSpec ("dataframe" , "0.2.0" ))
9751056@_REGISTRY .register_write (ZarrGroup , views .DataFrameView , IOSpec ("dataframe" , "0.2.0" ))
9761057@_REGISTRY .register_write (ZarrGroup , pd .DataFrame , IOSpec ("dataframe" , "0.2.0" ))
1058+ @suppress_autoshard_warning
9771059def write_dataframe (
9781060 f : GroupStorageType ,
9791061 key : str ,
@@ -1115,6 +1197,7 @@ def read_partial_dataframe_0_1_0(
11151197
11161198@_REGISTRY .register_write (H5Group , pd .Categorical , IOSpec ("categorical" , "0.2.0" ))
11171199@_REGISTRY .register_write (ZarrGroup , pd .Categorical , IOSpec ("categorical" , "0.2.0" ))
1200+ @suppress_autoshard_warning
11181201def write_categorical (
11191202 f : GroupStorageType ,
11201203 k : str ,
0 commit comments