Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Iterator
from os import PathLike
from typing import Any, Literal

Expand Down Expand Up @@ -374,13 +374,12 @@ def write_list(
# It's in the `AnnData.concatenate` docstring, but should we keep it?
@_REGISTRY.register_write(H5Group, views.ArrayView, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, np.ndarray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, h5py.Dataset, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, views.ArrayView, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, np.ndarray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, h5py.Dataset, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, np.ma.MaskedArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, ZarrArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, H5Array, IOSpec("array", "0.2.0"))
@zero_dim_array_as_scalar
def write_basic(
f: GroupStorageType,
Expand All @@ -394,6 +393,51 @@ def write_basic(
f.create_dataset(k, data=elem, **dataset_kwargs)


def _iter_chunks_for_copy(
    elem: ArrayStorageType, dest: ArrayStorageType
) -> Iterator[slice | tuple[slice, ...]]:
    """
    Return an iterator of selections for copying `elem` into `dest` piece by piece.

    Each yielded selection is meant to be used as ``dest[sel] = elem[sel]``.

    * If `dest` is chunked and exposes `iter_chunks`, the selections are
      whatever `dest.iter_chunks()` yields.
    * Otherwise the copy proceeds in blocks of rows along the first axis. A block
      holds ``max(ad.settings.min_rows_for_chunked_h5_copy, elem.chunks[0])`` rows,
      with ``1`` standing in for ``elem.chunks[0]`` when `elem` is not chunked.
    * A zero-dimensional `elem` has no first axis to split, so a single
      empty-tuple selection (``()``) is yielded and the value is copied in one step.
    """
    if dest.chunks and hasattr(dest, "iter_chunks"):
        return dest.iter_chunks()

    shape = elem.shape
    if len(shape) == 0:
        # Scalars cannot be sliced by rows; `arr[()]` addresses the whole value.
        return iter(((),))

    # Never go below the configured minimum, but honour larger source chunks so
    # that each source chunk is read only once.
    source_rows = elem.chunks[0] if elem.chunks is not None else 1
    n_rows = max(ad.settings.min_rows_for_chunked_h5_copy, source_rows)
    return (slice(i, min(i + n_rows, shape[0])) for i in range(0, shape[0], n_rows))


@_REGISTRY.register_write(H5Group, H5Array, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, ZarrArray, IOSpec("array", "0.2.0"))
def write_chunked_dense_array_to_group(
    f: GroupStorageType,
    k: str,
    elem: ArrayStorageType,
    *,
    _writer: Writer,
    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
    """Copy `elem` into a new dataset `k` of group `f`, one selection at a time.

    Handing an `h5py.Dataset` to `h5py.Group.create_dataset(..., data=...)` reads
    all of it into memory before anything is written. To avoid that, the
    destination is created empty here and then filled selection by selection.
    Zarr takes care of this automatically, so it needs no such treatment.
    """
    # An explicit "dtype" in `dataset_kwargs` wins; otherwise keep the source dtype.
    create_kwargs = dict(dataset_kwargs)
    create_kwargs.setdefault("dtype", elem.dtype)
    target = f.create_dataset(k, shape=elem.shape, **create_kwargs)

    for selection in _iter_chunks_for_copy(elem, target):
        target[selection] = elem[selection]

# Register `write_basic`, wrapped by `_to_cpu_mem_wrapper`, as the writer used for
# `CupyArray` elements going into an `H5Group` under the "array" 0.2.0 spec.
# NOTE(review): the wrapper presumably moves the array to host memory first — confirm.
_REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
_to_cpu_mem_wrapper(write_basic)
)
Expand Down Expand Up @@ -604,9 +648,15 @@ def write_sparse_compressed(
if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)

g.create_dataset("data", data=value.data, **dataset_kwargs)
g.create_dataset("indices", data=value.indices, **dataset_kwargs)
g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
for attr_name in ["data", "indices", "indptr"]:
attr = getattr(value, attr_name)
dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
_writer.write_elem(
g,
attr_name,
attr,
dataset_kwargs={"dtype": dtype, **dataset_kwargs},
)


# CSR writer: `write_sparse_compressed` with `fmt` fixed to "csr".
write_csr = partial(write_sparse_compressed, fmt="csr")
Expand Down
35 changes: 31 additions & 4 deletions src/anndata/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ def check_and_get_bool(option, default_value):
)


def check_and_get_int(option, default_value):
    """Look up the integer setting `option` via `check_and_get_environ_var`.

    The lookup key is ``ANNDATA_<OPTION>`` (the option name upper-cased). The
    default is handed over as the string form of ``int(default_value)``, no set
    of allowed values is given, and the result is converted with `int`.
    """
    env_key = f"ANNDATA_{option.upper()}"
    fallback = str(int(default_value))
    return check_and_get_environ_var(env_key, fallback, None, lambda raw: int(raw))


_docstring = """
This manager allows users to customize settings for the anndata package.
Settings here will generally be for advanced use-cases and should be used with caution.
Expand Down Expand Up @@ -396,11 +405,20 @@ def __doc__(self):
# PLACE REGISTERED SETTINGS HERE SO THEY CAN BE PICKED UP FOR DOCSTRING CREATION #
##################################################################################

# Type variable tying the type passed to `gen_validator` to the value type
# accepted by the validator it returns.
V = TypeVar("V")


def validate_bool(val: Any) -> None:
    """Raise `TypeError` unless `val` is a `bool`; return `None` otherwise."""
    if isinstance(val, bool):
        return
    msg = f"{val} not valid boolean"
    raise TypeError(msg)
def gen_validator(_type: type[V]) -> Callable[[V], None]:
    """Build a validator that only accepts instances of `_type`.

    Parameters
    ----------
    _type
        The type a value has to be an instance of.

    Returns
    -------
    A one-argument function that returns `None` for an acceptable value and
    raises `TypeError` (message ``"<val> not valid <_type>"``) otherwise.

    Notes
    -----
    `bool` is a subclass of `int`, so a bare `isinstance` check would let
    `True`/`False` pass an integer validator. A validator built for `int`
    therefore rejects booleans explicitly.
    """
    reject_bool = _type is int

    def validate_type(val: V) -> None:
        stray_bool = reject_bool and isinstance(val, bool)
        if stray_bool or not isinstance(val, _type):
            msg = f"{val} not valid {_type}"
            raise TypeError(msg)

    return validate_type


# Ready-made validators, built with `gen_validator`, for settings whose values
# must be `bool` or `int` instances respectively.
validate_bool = gen_validator(bool)
validate_int = gen_validator(int)


settings.register(
Expand Down Expand Up @@ -448,5 +466,14 @@ def validate_sparse_settings(val: Any) -> None:
get_from_env=check_and_get_bool,
)

# Integer setting: the minimum number of rows copied per step when an on-disk
# array is written out in row blocks. Values are checked with `validate_int`;
# `check_and_get_int` is the hook for reading it from the environment, where it
# looks up the key `ANNDATA_MIN_ROWS_FOR_CHUNKED_H5_COPY`.
settings.register(
"min_rows_for_chunked_h5_copy",
default_value=1000,
description="Minimum number of rows at a time to copy when writing out an H5 Dataset to a new location",
validate=validate_int,
get_from_env=check_and_get_int,
)


##################################################################################
##################################################################################
6 changes: 2 additions & 4 deletions tests/test_io_dispatched.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,5 @@ def zarr_reader(func, elem_name: str, elem, iospec):
write_dispatched(f, "/", adata, callback=zarr_writer)
_ = read_dispatched(f, zarr_reader)

assert h5ad_write_keys == zarr_write_keys
assert h5ad_read_keys == zarr_read_keys

assert sorted(h5ad_write_keys) == sorted(h5ad_read_keys)
assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)
12 changes: 12 additions & 0 deletions tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ def create_sparse_store(
pytest.param(
pd.array([True, False, True, True]), "nullable-boolean", id="pd_arr_bool"
),
pytest.param(
zarr.ones((100, 100), chunks=(10, 10)),
"array",
id="zarr_dense_array",
),
pytest.param(
create_dense_store(
h5py.File("test1.h5", mode="w", driver="core", backing_store=False)
)["X"],
"array",
id="h5_dense_array",
),
# pytest.param(bytes, b"some bytes", "bytes", id="py_bytes"), # Does not work for zarr
# TODO consider how specific encodings should be. Should we be fully describing the written type?
# Currently the info we add is: "what you wouldn't be able to figure out yourself"
Expand Down