Skip to content

Commit 4fcaea9

Browse files
authored
Backport PR #1624: Chunked writing of h5py.Dataset and zarr.Array (#1893)
1 parent 3f36ba2 commit 4fcaea9

4 files changed

Lines changed: 101 additions & 14 deletions

File tree

src/anndata/_io/specs/methods.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
4545

4646
if TYPE_CHECKING:
47-
from collections.abc import Callable
47+
from collections.abc import Callable, Iterator
4848
from os import PathLike
4949
from typing import Any, Literal
5050

@@ -374,13 +374,12 @@ def write_list(
374374
# It's in the `AnnData.concatenate` docstring, but should we keep it?
375375
@_REGISTRY.register_write(H5Group, views.ArrayView, IOSpec("array", "0.2.0"))
376376
@_REGISTRY.register_write(H5Group, np.ndarray, IOSpec("array", "0.2.0"))
377-
@_REGISTRY.register_write(H5Group, h5py.Dataset, IOSpec("array", "0.2.0"))
378377
@_REGISTRY.register_write(H5Group, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
379378
@_REGISTRY.register_write(ZarrGroup, views.ArrayView, IOSpec("array", "0.2.0"))
380379
@_REGISTRY.register_write(ZarrGroup, np.ndarray, IOSpec("array", "0.2.0"))
381-
@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
382380
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
383381
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
382+
@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
384383
@zero_dim_array_as_scalar
385384
def write_basic(
386385
f: GroupStorageType,
@@ -394,6 +393,51 @@ def write_basic(
394393
f.create_dataset(k, data=elem, **dataset_kwargs)
395394

396395

396+
def _iter_chunks_for_copy(
    elem: ArrayStorageType, dest: ArrayStorageType
) -> Iterator[slice | tuple[slice, ...]]:
    """
    Returns an iterator of selections for copying `elem` into `dest` piece by piece.

    Every yielded selection can be used to index both arrays, i.e.
    `dest[selection] = elem[selection]`.

    * If `dest` is chunked and provides an `iter_chunks` method, the chunks of `dest`
      are returned as-is.
    * Otherwise the copy is split into row slices along the first axis. The number of
      rows per slice is `ad.settings.min_rows_for_chunked_h5_copy` or the first-axis
      chunk length of `elem`, whichever is larger.
    * If `elem` is zero-dimensional there is no first axis to slice, so a single empty
      tuple (selecting the whole scalar) is returned.
    """
    if dest.chunks and hasattr(dest, "iter_chunks"):
        return dest.iter_chunks()

    shape = elem.shape
    if len(shape) == 0:
        # A 0-d array has no rows; `x[()]` addresses its single value.
        return iter([()])

    # Never go below the configured minimum, but keep reads aligned with the
    # source's own chunking when those chunks are taller than the minimum.
    n_rows = max(
        ad.settings.min_rows_for_chunked_h5_copy,
        elem.chunks[0] if elem.chunks is not None else 1,
    )
    return (slice(i, min(i + n_rows, shape[0])) for i in range(0, shape[0], n_rows))
415+
416+
417+
@_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0"))
def write_chunked_dense_array_to_group(
    f: GroupStorageType,
    k: str,
    elem: ArrayStorageType,
    *,
    _writer: Writer,
    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
    """Copy a stored dense array into an HDF5 group one piece at a time.

    `h5py.Group.create_dataset(..., data: h5py.Dataset)` will load all of `data` into
    memory before writing. To avoid that, the destination dataset is created empty
    (shape and dtype only) and then filled selection by selection. This is not needed
    when writing to zarr, since zarr handles this automatically.

    Parameters
    ----------
    f
        Group in which the new dataset is created.
    k
        Name of the new dataset inside `f`.
    elem
        Source array to copy from.
    dataset_kwargs
        Extra keyword arguments forwarded to `create_dataset`. A `dtype` entry, if
        present, takes precedence over the dtype of `elem`.
    """
    # Only fall back to the source dtype when the caller did not request one.
    create_kwargs = dict(dataset_kwargs)
    create_kwargs.setdefault("dtype", elem.dtype)
    target = f.create_dataset(k, shape=elem.shape, **create_kwargs)

    for selection in _iter_chunks_for_copy(elem, target):
        target[selection] = elem[selection]
439+
440+
397441
_REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
398442
_to_cpu_mem_wrapper(write_basic)
399443
)
@@ -604,9 +648,15 @@ def write_sparse_compressed(
604648
if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
605649
dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)
606650

607-
g.create_dataset("data", data=value.data, **dataset_kwargs)
608-
g.create_dataset("indices", data=value.indices, **dataset_kwargs)
609-
g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
651+
for attr_name in ["data", "indices", "indptr"]:
652+
attr = getattr(value, attr_name)
653+
dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
654+
_writer.write_elem(
655+
g,
656+
attr_name,
657+
attr,
658+
dataset_kwargs={"dtype": dtype, **dataset_kwargs},
659+
)
610660

611661

612662
write_csr = partial(write_sparse_compressed, fmt="csr")

src/anndata/_settings.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def check_and_get_bool(option, default_value):
126126
)
127127

128128

129+
def check_and_get_int(option, default_value):
    """Resolve the integer setting `option` through `check_and_get_environ_var`.

    The lookup key is `ANNDATA_` followed by the upper-cased option name. The default
    is normalised to the string form of an integer, and the value handed back by the
    helper is converted with `int`.
    """
    env_key = f"ANNDATA_{option.upper()}"
    fallback = str(int(default_value))
    # Third argument is passed as None, exactly as before.
    # NOTE(review): presumably "no restriction on allowed values" — confirm against
    # check_and_get_environ_var's signature.
    return check_and_get_environ_var(env_key, fallback, None, lambda raw: int(raw))
136+
137+
129138
_docstring = """
130139
This manager allows users to customize settings for the anndata package.
131140
Settings here will generally be for advanced use-cases and should be used with caution.
@@ -396,11 +405,20 @@ def __doc__(self):
396405
# PLACE REGISTERED SETTINGS HERE SO THEY CAN BE PICKED UP FOR DOCSTRING CREATION #
397406
##################################################################################
398407

408+
V = TypeVar("V")
409+
399410

400-
def validate_bool(val: Any) -> None:
401-
if not isinstance(val, bool):
402-
msg = f"{val} not valid boolean"
403-
raise TypeError(msg)
411+
def gen_validator(_type: type[V]) -> Callable[[V], None]:
    """Build a validator that accepts only instances of `_type`.

    Parameters
    ----------
    _type
        The type that values must be an instance of.

    Returns
    -------
    A function taking a single value. It returns `None` when the value is acceptable
    and raises `TypeError` otherwise.

    Notes
    -----
    `bool` is a subclass of `int`, so a plain `isinstance` check would let `True` and
    `False` through an `int` validator. When `_type` is exactly `int`, booleans are
    rejected explicitly.
    """
    # Decided once here rather than on every validation call.
    reject_bool = _type is int

    def validate_type(val: V) -> None:
        if not isinstance(val, _type) or (reject_bool and isinstance(val, bool)):
            msg = f"{val} not valid {_type}"
            raise TypeError(msg)

    return validate_type
418+
419+
420+
validate_bool = gen_validator(bool)
421+
validate_int = gen_validator(int)
404422

405423

406424
settings.register(
@@ -448,5 +466,14 @@ def validate_sparse_settings(val: Any) -> None:
448466
get_from_env=check_and_get_bool,
449467
)
450468

469+
settings.register(
470+
"min_rows_for_chunked_h5_copy",
471+
default_value=1000,
472+
description="Minimum number of rows at a time to copy when writing out an H5 Dataset to a new location",
473+
validate=validate_int,
474+
get_from_env=check_and_get_int,
475+
)
476+
477+
451478
##################################################################################
452479
##################################################################################

tests/test_io_dispatched.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,5 @@ def zarr_reader(func, elem_name: str, elem, iospec):
174174
write_dispatched(f, "/", adata, callback=zarr_writer)
175175
_ = read_dispatched(f, zarr_reader)
176176

177-
assert h5ad_write_keys == zarr_write_keys
178-
assert h5ad_read_keys == zarr_read_keys
179-
180-
assert sorted(h5ad_write_keys) == sorted(h5ad_read_keys)
177+
assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
178+
assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)

tests/test_io_elementwise.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ def create_sparse_store(
194194
pytest.param(
195195
pd.array([True, False, True, True]), "nullable-boolean", id="pd_arr_bool"
196196
),
197+
pytest.param(
198+
zarr.ones((100, 100), chunks=(10, 10)),
199+
"array",
200+
id="zarr_dense_array",
201+
),
202+
pytest.param(
203+
create_dense_store(
204+
h5py.File("test1.h5", mode="w", driver="core", backing_store=False)
205+
)["X"],
206+
"array",
207+
id="h5_dense_array",
208+
),
197209
# pytest.param(bytes, b"some bytes", "bytes", id="py_bytes"), # Does not work for zarr
198210
# TODO consider how specific encodings should be. Should we be fully describing the written type?
199211
# Currently the info we add is: "what you wouldn't be able to figure out yourself"

0 commit comments

Comments
 (0)