Skip to content

Commit e1833a2

Browse files
Merge branch 'main' into feature/concat_hdf5_virtual_datasets
2 parents 5d3eab2 + 19c8e59 commit e1833a2

3 files changed

Lines changed: 55 additions & 27 deletions

File tree

docs/release-notes/2084.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow writing of views of {class}`dask.array.Array` {user}`ilan-gold`

src/anndata/_io/specs/methods.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def write_chunked_dense_array_to_group(
492492
)
493493

494494

495+
@_REGISTRY.register_write(ZarrGroup, views.DaskArrayView, IOSpec("array", "0.2.0"))
495496
@_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0"))
496497
def write_basic_dask_zarr(
497498
f: ZarrGroup,
@@ -514,6 +515,7 @@ def write_basic_dask_zarr(
514515

515516
# Adding this separately because h5py isn't serializable
516517
# https://github.com/pydata/xarray/issues/4242
518+
@_REGISTRY.register_write(H5Group, views.DaskArrayView, IOSpec("array", "0.2.0"))
517519
@_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0"))
518520
def write_basic_dask_h5(
519521
f: H5Group,
@@ -803,21 +805,7 @@ def write_sparse_dataset(
803805
f[k].attrs["encoding-version"] = "0.1.0"
804806

805807

806-
@_REGISTRY.register_write(H5Group, (DaskArray, CupyArray), IOSpec("array", "0.2.0"))
807-
@_REGISTRY.register_write(ZarrGroup, (DaskArray, CupyArray), IOSpec("array", "0.2.0"))
808-
@_REGISTRY.register_write(
809-
H5Group, (DaskArray, CupyCSRMatrix), IOSpec("csr_matrix", "0.1.0")
810-
)
811-
@_REGISTRY.register_write(
812-
H5Group, (DaskArray, CupyCSCMatrix), IOSpec("csc_matrix", "0.1.0")
813-
)
814-
@_REGISTRY.register_write(
815-
ZarrGroup, (DaskArray, CupyCSRMatrix), IOSpec("csr_matrix", "0.1.0")
816-
)
817-
@_REGISTRY.register_write(
818-
ZarrGroup, (DaskArray, CupyCSCMatrix), IOSpec("csc_matrix", "0.1.0")
819-
)
820-
def write_cupy_dask_sparse(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
808+
def write_cupy_dask(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
821809
_writer.write_elem(
822810
f,
823811
k,
@@ -826,18 +814,6 @@ def write_cupy_dask_sparse(f, k, elem, _writer, dataset_kwargs=MappingProxyType(
826814
)
827815

828816

829-
@_REGISTRY.register_write(
830-
H5Group, (DaskArray, sparse.csr_matrix), IOSpec("csr_matrix", "0.1.0")
831-
)
832-
@_REGISTRY.register_write(
833-
H5Group, (DaskArray, sparse.csc_matrix), IOSpec("csc_matrix", "0.1.0")
834-
)
835-
@_REGISTRY.register_write(
836-
ZarrGroup, (DaskArray, sparse.csr_matrix), IOSpec("csr_matrix", "0.1.0")
837-
)
838-
@_REGISTRY.register_write(
839-
ZarrGroup, (DaskArray, sparse.csc_matrix), IOSpec("csc_matrix", "0.1.0")
840-
)
841817
def write_dask_sparse(
842818
f: GroupStorageType,
843819
k: str,
@@ -886,6 +862,26 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
886862
disk_mtx.append(elem[chunk_slice(chunk_start, chunk_stop)].compute())
887863

888864

865+
# Register the dask writers for every combination of array container
# (plain dask array or its view wrapper) and storage backend (HDF5 or zarr).
# NOTE(review): the second element of the registered type tuple is presumably
# the type of the dask array's chunks - confirm against the registry.
for array_type, group_type in product(
    (DaskArray, views.DaskArrayView), (H5Group, ZarrGroup)
):
    # (chunk type, on-disk spec, writer), listed in registration order:
    # cupy-backed types go through `write_cupy_dask`, scipy sparse types
    # through `write_dask_sparse`.
    writer_table = (
        (CupyArray, IOSpec("array", "0.2.0"), write_cupy_dask),
        (CupyCSCMatrix, IOSpec("csc_matrix", "0.1.0"), write_cupy_dask),
        (CupyCSRMatrix, IOSpec("csr_matrix", "0.1.0"), write_cupy_dask),
        (sparse.csr_matrix, IOSpec("csr_matrix", "0.1.0"), write_dask_sparse),
        (sparse.csc_matrix, IOSpec("csc_matrix", "0.1.0"), write_dask_sparse),
    )
    for chunk_type, spec, write_func in writer_table:
        _REGISTRY.register_write(group_type, (array_type, chunk_type), spec)(
            write_func
        )
883+
884+
889885
@_REGISTRY.register_read(H5Group, IOSpec("csc_matrix", "0.1.0"))
890886
@_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
891887
@_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))

tests/test_dask.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from anndata.experimental.merge import as_group
1818
from anndata.tests.helpers import (
1919
GEN_ADATA_DASK_ARGS,
20+
as_cupy_sparse_dask_array,
2021
as_dense_cupy_dask_array,
2122
as_dense_dask_array,
2223
as_sparse_dask_array,
@@ -310,6 +311,36 @@ def test_dask_to_memory_unbacked(array_func, mem_type):
310311
assert isinstance(orig.uns["da"]["da"], DaskArray)
311312

312313

314+
@pytest.mark.parametrize(
    "array_func",
    [
        pytest.param(as_dense_dask_array, id="dense_dask_array"),
        pytest.param(as_sparse_dask_array, id="sparse_dask_array"),
        pytest.param(
            as_dense_cupy_dask_array,
            id="cupy_dense_dask_array",
            marks=pytest.mark.gpu,
        ),
        pytest.param(
            as_cupy_sparse_dask_array,
            id="cupy_sparse_dask_array",
            marks=pytest.mark.gpu,
        ),
    ],
)
def test_dask_to_disk_view(array_func, diskfmt, tmp_path):
    """Check that a view of a dask-backed AnnData survives a write/read round trip.

    A random 20 x 15 matrix is wrapped with ``array_func``, the AnnData built
    from it is indexed at its middle observation to obtain a view, and that view
    is written with ``write_<diskfmt>`` and read back with
    ``ad.io.read_<diskfmt>``; the result must compare equal to the view.
    """
    rng = np.random.default_rng()
    # the cast to float32 is needed for the cupy-backed variants
    counts = rng.binomial(100, 0.005, (20, 15)).astype("float32")
    full = ad.AnnData(array_func(counts))

    # Indexing with a single integer yields the view that is written below.
    middle_obs = full.shape[0] // 2
    view = full[middle_obs]

    path = tmp_path / f"test.{diskfmt}"
    write = getattr(view, f"write_{diskfmt}")
    write(path)
    read = getattr(ad.io, f"read_{diskfmt}")
    roundtrip = read(path)

    assert_equal(roundtrip, view)
342+
343+
313344
# Test if dask arrays turn into numpy arrays after to_memory is called
314345
def test_dask_to_memory_copy_unbacked():
315346
import numpy as np

0 commit comments

Comments
 (0)