Skip to content

Commit 795fb34

Browse files
ilan-gold, milos7250, pre-commit-ci[bot], amalia-k510
authored
fix: merge strategies for concat_on_disk (#2122) (#2142)
Co-authored-by: Miloš Mičík <56844787+milos7250@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: amalia-k510 <kareshamal@gmail.com>
1 parent 9f90b2a commit 795fb34

4 files changed

Lines changed: 97 additions & 44 deletions

File tree

docs/release-notes/2122.fix.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Respect off-axis merge options in {func}`anndata.experimental.concat_on_disk` {user}`ilan-gold`

src/anndata/_core/merge.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -143,11 +143,16 @@ def equal_dask_array(a, b) -> bool:
143143
return False
144144
if isinstance(b, DaskArray) and tokenize(a) == tokenize(b):
145145
return True
146-
if isinstance(a._meta, CSMatrix):
146+
if isinstance(a._meta, np.ndarray):
147+
return da.equal(a, b, where=~(da.isnan(a) & da.isnan(b))).all().compute()
148+
if a.chunksize == b.chunksize and isinstance(
149+
a._meta, CupySparseMatrix | CSMatrix | CSArray
150+
):
147151
# TODO: Maybe also do this in the other case?
148152
return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
149-
else:
150-
return da.equal(a, b, where=~(da.isnan(a) == da.isnan(b))).all()
153+
msg = "Misaligned chunks detected when checking for merge equality of dask arrays. Reading full arrays into memory."
154+
warn(msg, UserWarning, stacklevel=3)
155+
return equal(a.compute(), b.compute())
151156

152157

153158
@equal.register(np.ndarray)

src/anndata/experimental/merge.py

Lines changed: 76 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from .._core.sparse_dataset import BaseCompressedSparseDataset, sparse_dataset
2828
from .._io.specs import read_elem, write_elem
2929
from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup
30-
from . import read_dispatched
30+
from . import read_dispatched, read_elem_lazy
3131

3232
if TYPE_CHECKING:
3333
from collections.abc import Callable, Collection, Iterable, Sequence
@@ -173,7 +173,7 @@ def write_concat_dense( # noqa: PLR0917
173173
output_path: ZarrGroup | H5Group,
174174
axis: Literal[0, 1] = 0,
175175
reindexers: Reindexer | None = None,
176-
fill_value=None,
176+
fill_value: Any = None,
177177
):
178178
"""
179179
Writes the concatenation of given dense arrays to disk using dask.
@@ -206,7 +206,7 @@ def write_concat_sparse( # noqa: PLR0917
206206
max_loaded_elems: int,
207207
axis: Literal[0, 1] = 0,
208208
reindexers: Reindexer | None = None,
209-
fill_value=None,
209+
fill_value: Any = None,
210210
):
211211
"""
212212
Writes and concatenates sparse datasets into a single output dataset.
@@ -246,20 +246,20 @@ def write_concat_sparse( # noqa: PLR0917
246246

247247

248248
def _write_concat_mappings( # noqa: PLR0913, PLR0917
249-
mappings,
249+
mappings: Collection[dict],
250250
output_group: ZarrGroup | H5Group,
251-
keys,
252-
path,
253-
max_loaded_elems,
254-
axis=0,
255-
index=None,
256-
reindexers=None,
257-
fill_value=None,
251+
keys: Collection[str],
252+
output_path: str | Path,
253+
max_loaded_elems: int,
254+
axis: Literal[0, 1] = 0,
255+
index: pd.Index = None,
256+
reindexers: list[Reindexer] | None = None,
257+
fill_value: Any = None,
258258
):
259259
"""
260260
Write a list of mappings to a zarr/h5 group.
261261
"""
262-
mapping_group = output_group.create_group(path)
262+
mapping_group = output_group.create_group(output_path)
263263
mapping_group.attrs.update({
264264
"encoding-type": "dict",
265265
"encoding-version": "0.1.0",
@@ -280,13 +280,13 @@ def _write_concat_mappings( # noqa: PLR0913, PLR0917
280280

281281
def _write_concat_arrays( # noqa: PLR0913, PLR0917
282282
arrays: Sequence[ZarrArray | H5Array | BaseCompressedSparseDataset],
283-
output_group,
284-
output_path,
285-
max_loaded_elems,
286-
axis=0,
287-
reindexers=None,
288-
fill_value=None,
289-
join="inner",
283+
output_group: ZarrGroup | H5Group,
284+
output_path: str | Path,
285+
max_loaded_elems: int,
286+
axis: Literal[0, 1] = 0,
287+
reindexers: list[Reindexer] | None = None,
288+
fill_value: Any = None,
289+
join: Literal["inner", "outer"] = "inner",
290290
):
291291
init_elem = arrays[0]
292292
init_type = type(init_elem)
@@ -324,14 +324,14 @@ def _write_concat_arrays( # noqa: PLR0913, PLR0917
324324

325325
def _write_concat_sequence( # noqa: PLR0913, PLR0917
326326
arrays: Sequence[pd.DataFrame | BaseCompressedSparseDataset | H5Array | ZarrArray],
327-
output_group,
328-
output_path,
329-
max_loaded_elems,
330-
axis=0,
331-
index=None,
332-
reindexers=None,
333-
fill_value=None,
334-
join="inner",
327+
output_group: ZarrGroup | H5Group,
328+
output_path: str | Path,
329+
max_loaded_elems: int,
330+
axis: Literal[0, 1] = 0,
331+
index: pd.Index = None,
332+
reindexers: list[Reindexer] | None = None,
333+
fill_value: Any = None,
334+
join: Literal["inner", "outer"] = "inner",
335335
):
336336
"""
337337
array, dataframe, csc_matrix, csc_matrix
@@ -376,17 +376,27 @@ def _write_concat_sequence( # noqa: PLR0913, PLR0917
376376
raise NotImplementedError(msg)
377377

378378

379-
def _write_alt_mapping(groups, output_group, alt_axis_name, alt_indices, merge):
380-
alt_mapping = merge([read_as_backed(g[alt_axis_name]) for g in groups])
381-
# If its empty, we need to write an empty dataframe with the correct index
382-
if not alt_mapping:
383-
alt_df = pd.DataFrame(index=alt_indices)
384-
write_elem(output_group, alt_axis_name, alt_df)
385-
else:
386-
write_elem(output_group, alt_axis_name, alt_mapping)
379+
def _write_alt_mapping(
380+
groups: Collection[H5Group, ZarrGroup],
381+
output_group: ZarrGroup | H5Group,
382+
alt_axis_name: Literal["obs", "var"],
383+
merge: Callable,
384+
reindexers: list[Reindexer],
385+
):
386+
alt_mapping = merge([
387+
{k: r(read_elem(v), axis=0) for k, v in dict(g[f"{alt_axis_name}m"]).items()}
388+
for r, g in zip(reindexers, groups, strict=True)
389+
])
390+
write_elem(output_group, f"{alt_axis_name}m", alt_mapping)
387391

388392

389-
def _write_alt_annot(groups, output_group, alt_axis_name, alt_indices, merge):
393+
def _write_alt_annot(
394+
groups: Collection[H5Group, ZarrGroup],
395+
output_group: ZarrGroup | H5Group,
396+
alt_axis_name: Literal["obs", "var"],
397+
alt_indices: pd.Index,
398+
merge: Callable,
399+
):
390400
# Annotation for other axis
391401
alt_annot = merge_dataframes(
392402
[read_elem(g[alt_axis_name]) for g in groups], alt_indices, merge
@@ -395,7 +405,13 @@ def _write_alt_annot(groups, output_group, alt_axis_name, alt_indices, merge):
395405

396406

397407
def _write_axis_annot( # noqa: PLR0917
398-
groups, output_group, axis_name, concat_indices, label, label_col, join
408+
groups: Collection[H5Group, ZarrGroup],
409+
output_group: ZarrGroup | H5Group,
410+
axis_name: Literal["obs", "var"],
411+
concat_indices: pd.Index,
412+
label: str,
413+
label_col: str,
414+
join: Literal["inner", "outer"],
399415
):
400416
concat_annot = pd.concat(
401417
unify_dtypes(read_elem(g[axis_name]) for g in groups),
@@ -408,6 +424,23 @@ def _write_axis_annot( # noqa: PLR0917
408424
write_elem(output_group, axis_name, concat_annot)
409425

410426

427+
def _write_alt_pairwise(
428+
groups: Collection[H5Group, ZarrGroup],
429+
output_group: ZarrGroup | H5Group,
430+
alt_axis_name: Literal["obs", "var"],
431+
merge: Callable,
432+
reindexers: list[Reindexer],
433+
):
434+
alt_pairwise = merge([
435+
{
436+
k: r(r(read_elem_lazy(v), axis=0), axis=1)
437+
for k, v in dict(g[f"{alt_axis_name}p"]).items()
438+
}
439+
for r, g in zip(reindexers, groups, strict=True)
440+
])
441+
write_elem(output_group, f"{alt_axis_name}p", alt_pairwise)
442+
443+
411444
def concat_on_disk( # noqa: PLR0912, PLR0913, PLR0915
412445
in_files: Collection[PathLike[str] | str] | Mapping[str, PathLike[str] | str],
413446
out_file: PathLike[str] | str,
@@ -490,7 +523,8 @@ def concat_on_disk( # noqa: PLR0912, PLR0913, PLR0915
490523
DataFrames are padded with missing values.
491524
pairwise
492525
Whether pairwise elements along the concatenated dimension should be included.
493-
This is False by default, since the resulting arrays are often not meaningful.
526+
This is False by default, since the resulting arrays are often not meaningful, and raises {class}`NotImplementedError` when True.
527+
If you are interested in this feature, please open an issue.
494528
495529
Notes
496530
-----
@@ -634,7 +668,10 @@ def concat_on_disk( # noqa: PLR0912, PLR0913, PLR0915
634668
_write_alt_annot(groups, output_group, alt_axis_name, alt_index, merge)
635669

636670
# Write {alt_axis_name}m
637-
_write_alt_mapping(groups, output_group, alt_axis_name, alt_index, merge)
671+
_write_alt_mapping(groups, output_group, alt_axis_name, merge, reindexers)
672+
673+
# Write {alt_axis_name}p
674+
_write_alt_pairwise(groups, output_group, alt_axis_name, merge, reindexers)
638675

639676
# Write X
640677

tests/test_concatenate_disk.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from scipy import sparse
1010

1111
from anndata import AnnData, concat
12+
from anndata._core import merge
1213
from anndata._core.merge import _resolve_axis
1314
from anndata.experimental.merge import as_group, concat_on_disk
1415
from anndata.io import read_elem, write_elem
@@ -31,6 +32,11 @@
3132
)
3233

3334

35+
@pytest.fixture(params=list(merge.MERGE_STRATEGIES.keys()))
36+
def merge_strategy(request):
37+
return request.param
38+
39+
3440
@pytest.fixture(params=[0, 1])
3541
def axis(request) -> Literal[0, 1]:
3642
return request.param
@@ -86,16 +92,17 @@ def assert_eq_concat_on_disk(
8692
file_format: Literal["zarr", "h5ad"],
8793
max_loaded_elems: int | None = None,
8894
*args,
95+
merge_strategy: merge.StrategiesLiteral | None = None,
8996
**kwargs,
9097
):
9198
# create one from the concat function
92-
res1 = concat(adatas, *args, **kwargs)
99+
res1 = concat(adatas, *args, merge=merge_strategy, **kwargs)
93100
# create one from the on disk concat function
94101
paths = _adatas_to_paths(adatas, tmp_path, file_format)
95102
out_name = tmp_path / f"out.{file_format}"
96103
if max_loaded_elems is not None:
97104
kwargs["max_loaded_elems"] = max_loaded_elems
98-
concat_on_disk(paths, out_name, *args, **kwargs)
105+
concat_on_disk(paths, out_name, *args, merge=merge_strategy, **kwargs)
99106
res2 = read_elem(as_group(out_name, mode="r"))
100107
assert_equal(res1, res2, exact=False)
101108

@@ -112,6 +119,7 @@ def get_array_type(array_type, axis):
112119

113120

114121
@pytest.mark.parametrize("reindex", [True, False], ids=["reindex", "no_reindex"])
122+
@pytest.mark.filterwarnings("ignore:Misaligned chunks detected")
115123
def test_anndatas(
116124
*,
117125
axis: Literal[0, 1],
@@ -121,6 +129,7 @@ def test_anndatas(
121129
max_loaded_elems: int,
122130
file_format: Literal["zarr", "h5ad"],
123131
reindex: bool,
132+
merge_strategy: merge.StrategiesLiteral,
124133
):
125134
_, off_axis_name = _resolve_axis(1 - axis)
126135
random_axes = {0, 1} if reindex else {axis}
@@ -159,6 +168,7 @@ def test_anndatas(
159168
max_loaded_elems,
160169
axis=axis,
161170
join=join_type,
171+
merge_strategy=merge_strategy,
162172
)
163173

164174

0 commit comments

Comments
 (0)