diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py
index 327e5fbbc..2399461bc 100644
--- a/src/anndata/_io/specs/methods.py
+++ b/src/anndata/_io/specs/methods.py
@@ -44,7 +44,7 @@
 from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Callable, Iterator
     from os import PathLike
     from typing import Any, Literal
 
@@ -374,13 +374,12 @@ def write_list(
 # It's in the `AnnData.concatenate` docstring, but should we keep it?
 @_REGISTRY.register_write(H5Group, views.ArrayView, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(H5Group, np.ndarray, IOSpec("array", "0.2.0"))
-@_REGISTRY.register_write(H5Group, h5py.Dataset, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(H5Group, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, views.ArrayView, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, np.ndarray, IOSpec("array", "0.2.0"))
-@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
 @_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
+@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
 @zero_dim_array_as_scalar
 def write_basic(
     f: GroupStorageType,
@@ -394,6 +393,57 @@ def write_basic(
     f.create_dataset(k, data=elem, **dataset_kwargs)
 
 
+def _iter_chunks_for_copy(
+    elem: ArrayStorageType, dest: ArrayStorageType
+) -> Iterator[slice | tuple[slice, ...]]:
+    """
+    Returns an iterator of selections for copying `elem` into `dest` piece by piece.
+
+    * If `dest` has chunks and provides `iter_chunks`, it will return the chunks of `dest`.
+    * Otherwise, it returns slices along the first axis. Each slice covers
+      `ad.settings.min_rows_for_chunked_h5_copy` rows or the first-axis chunk size
+      of `elem`, whichever is larger (the last slice may be shorter).
+    """
+    if dest.chunks and hasattr(dest, "iter_chunks"):
+        return dest.iter_chunks()
+    else:
+        shape = elem.shape
+        # Step by the configured minimum, but never by less than one chunk of `elem`.
+        n_rows = max(
+            ad.settings.min_rows_for_chunked_h5_copy,
+            elem.chunks[0] if elem.chunks is not None else 1,
+        )
+        return (slice(i, min(i + n_rows, shape[0])) for i in range(0, shape[0], n_rows))
+
+
+@_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0"))
+@_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0"))
+@zero_dim_array_as_scalar
+def write_chunked_dense_array_to_group(
+    f: GroupStorageType,
+    k: str,
+    elem: ArrayStorageType,
+    *,
+    _writer: Writer,
+    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
+):
+    """Write to a h5py.Dataset in chunks.
+
+    `h5py.Group.create_dataset(..., data: h5py.Dataset)` will load all of `data` into memory
+    before writing. Instead, we will write in chunks to avoid this. We don't need to do this for
+    zarr since zarr handles this automatically.
+
+    Zero-dimensional arrays have no first axis to slice, so they are left to
+    `zero_dim_array_as_scalar`, the same decorator `write_basic` uses.
+    """
+    dtype = dataset_kwargs.get("dtype", elem.dtype)
+    kwargs = {**dataset_kwargs, "dtype": dtype}
+    dest = f.create_dataset(k, shape=elem.shape, **kwargs)
+
+    for chunk in _iter_chunks_for_copy(elem, dest):
+        dest[chunk] = elem[chunk]
+
+
 _REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
     _to_cpu_mem_wrapper(write_basic)
 )
@@ -604,9 +654,15 @@ def write_sparse_compressed(
     if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
         dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)
 
-    g.create_dataset("data", data=value.data, **dataset_kwargs)
-    g.create_dataset("indices", data=value.indices, **dataset_kwargs)
-    g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
+    for attr_name in ["data", "indices", "indptr"]:
+        attr = getattr(value, attr_name)
+        dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
+        _writer.write_elem(
+            g,
+            attr_name,
+            attr,
+            dataset_kwargs={"dtype": dtype, **dataset_kwargs},
+        )
 
 
 write_csr = partial(write_sparse_compressed, fmt="csr")
diff --git a/src/anndata/_settings.py b/src/anndata/_settings.py
index ae066f9e5..f2a809173 100644
--- a/src/anndata/_settings.py
+++ b/src/anndata/_settings.py
@@ -126,6 +126,16 @@ def check_and_get_bool(option, default_value):
     )
 
 
+def check_and_get_int(option, default_value):
+    """Look up the integer option `ANNDATA_<OPTION>` via `check_and_get_environ_var`."""
+    return check_and_get_environ_var(
+        f"ANNDATA_{option.upper()}",
+        str(int(default_value)),
+        None,
+        int,
+    )
+
+
 _docstring = """
 This manager allows users to customize settings for the anndata package.
 Settings here will generally be for advanced use-cases and should be used with caution.
@@ -396,11 +406,24 @@ def __doc__(self):
 # PLACE REGISTERED SETTINGS HERE SO THEY CAN BE PICKED UP FOR DOCSTRING CREATION #
 ##################################################################################
 
+V = TypeVar("V")
+
 
-def validate_bool(val: Any) -> None:
-    if not isinstance(val, bool):
-        msg = f"{val} not valid boolean"
-        raise TypeError(msg)
+def gen_validator(_type: type[V]) -> Callable[[V], None]:
+    """Create a validator raising `TypeError` for values that are not a `_type`."""
+
+    def validate_type(val: V) -> None:
+        # `bool` subclasses `int`, so `isinstance(True, int)` holds; reject bools as ints.
+        is_bool_as_int = _type is int and isinstance(val, bool)
+        if not isinstance(val, _type) or is_bool_as_int:
+            msg = f"{val} not valid {_type}"
+            raise TypeError(msg)
+
+    return validate_type
+
+
+validate_bool = gen_validator(bool)
+validate_int = gen_validator(int)
 
 
 settings.register(
@@ -448,5 +471,14 @@ def validate_sparse_settings(val: Any) -> None:
     get_from_env=check_and_get_bool,
 )
 
+settings.register(
+    "min_rows_for_chunked_h5_copy",
+    default_value=1000,
+    description="Minimum number of rows at a time to copy when writing out an H5 Dataset to a new location",
+    validate=validate_int,
+    get_from_env=check_and_get_int,
+)
+
+
 ##################################################################################
 ##################################################################################
diff --git a/tests/test_io_dispatched.py b/tests/test_io_dispatched.py
index 3246d64d2..ee5fb63f5 100644
--- a/tests/test_io_dispatched.py
+++ b/tests/test_io_dispatched.py
@@ -174,7 +174,5 @@ def zarr_reader(func, elem_name: str, elem, iospec):
         write_dispatched(f, "/", adata, callback=zarr_writer)
         _ = read_dispatched(f, zarr_reader)
 
-    assert h5ad_write_keys == zarr_write_keys
-    assert h5ad_read_keys == zarr_read_keys
-
-    assert sorted(h5ad_write_keys) == sorted(h5ad_read_keys)
+    assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
+    assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)
diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py
index 31188b173..c9000c964 100644
--- a/tests/test_io_elementwise.py
+++ b/tests/test_io_elementwise.py
@@ -194,6 +194,18 @@ def create_sparse_store(
         pytest.param(
             pd.array([True, False, True, True]), "nullable-boolean", id="pd_arr_bool"
         ),
+        pytest.param(
+            zarr.ones((100, 100), chunks=(10, 10)),
+            "array",
+            id="zarr_dense_array",
+        ),
+        pytest.param(
+            create_dense_store(
+                h5py.File("test1.h5", mode="w", driver="core", backing_store=False)
+            )["X"],
+            "array",
+            id="h5_dense_array",
+        ),
         # pytest.param(bytes, b"some bytes", "bytes", id="py_bytes"),  # Does not work for zarr
         # TODO consider how specific encodings should be. Should we be fully describing the written type?
         # Currently the info we add is: "what you wouldn't be able to figure out yourself"