Skip to content

Commit 355cf10

Browse files
committed
add tests and make this an optional argument
1 parent f3383ac commit 355cf10

2 files changed

Lines changed: 86 additions & 16 deletions

File tree

src/anndata/experimental/merge.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,13 @@ def write_concat_dense( # noqa: PLR0917
208208
})
209209

210210

211-
def write_concat_sparse( # noqa: PLR0917
211+
def write_concat_sparse( # noqa: PLR0913, PLR0917
212212
datasets: Sequence[BaseCompressedSparseDataset],
213213
output_group: ZarrGroup | H5Group,
214214
output_path: ZarrGroup | H5Group,
215+
*,
215216
max_loaded_elems: int,
217+
virtual_concat: bool = False,
216218
axis: Literal[0, 1] = 0,
217219
reindexers: Reindexer | None = None,
218220
fill_value: Any = None,
@@ -235,6 +237,9 @@ def write_concat_sparse( # noqa: PLR0917
235237
A reindexer object that defines the reindexing operation to be applied.
236238
fill_value
237239
The fill value to use for missing elements. Defaults to None.
240+
virtual_concat
241+
Whether to use virtual concatenation for sparse arrays.
242+
238243
"""
239244
elems = None
240245
use_reindexing = True
@@ -245,7 +250,7 @@ def write_concat_sparse( # noqa: PLR0917
245250
elems = _gen_slice_to_append(
246251
datasets, reindexers, max_loaded_elems, axis, fill_value
247252
)
248-
if datasets[0].backend == "hdf5" and not use_reindexing:
253+
if datasets[0].backend == "hdf5" and not use_reindexing and virtual_concat:
249254
BaseCompressedSparseDataset.virtual_concat_hdf5(
250255
datasets, output_group, output_path
251256
)
@@ -278,6 +283,7 @@ def _write_concat_mappings( # noqa: PLR0913, PLR0917
278283
index: pd.Index = None,
279284
reindexers: list[Reindexer] | None = None,
280285
fill_value: Any = None,
286+
virtual_concat: bool = False,
281287
):
282288
"""
283289
Write a list of mappings to a zarr/h5 group.
@@ -298,6 +304,7 @@ def _write_concat_mappings( # noqa: PLR0913, PLR0917
298304
reindexers=reindexers,
299305
fill_value=fill_value,
300306
max_loaded_elems=max_loaded_elems,
307+
virtual_concat=virtual_concat,
301308
)
302309

303310

@@ -306,6 +313,7 @@ def _write_concat_arrays( # noqa: PLR0913, PLR0917
306313
output_group: ZarrGroup | H5Group,
307314
output_path: str | Path,
308315
max_loaded_elems: int,
316+
virtual_concat: bool = False,
309317
axis: Literal[0, 1] = 0,
310318
reindexers: list[Reindexer] | None = None,
311319
fill_value: Any = None,
@@ -331,17 +339,24 @@ def _write_concat_arrays( # noqa: PLR0913, PLR0917
331339
arrays,
332340
output_group,
333341
output_path,
334-
max_loaded_elems,
335-
axis,
336-
reindexers,
337-
fill_value,
342+
max_loaded_elems=max_loaded_elems,
343+
virtual_concat=virtual_concat,
344+
axis=axis,
345+
reindexers=reindexers,
346+
fill_value=fill_value,
338347
)
339348
else:
340349
msg = f"Concat of following not supported: {[a.format for a in arrays]}"
341350
raise NotImplementedError(msg)
342351
else:
343352
write_concat_dense(
344-
arrays, output_group, output_path, axis, reindexers, fill_value
353+
arrays,
354+
output_group,
355+
output_path,
356+
virtual_concat=virtual_concat,
357+
axis=axis,
358+
reindexers=reindexers,
359+
fill_value=fill_value,
345360
)
346361

347362

@@ -354,6 +369,7 @@ def _write_concat_sequence( # noqa: PLR0913, PLR0917
354369
index: pd.Index | None = None,
355370
reindexers: list[Reindexer] | None = None,
356371
fill_value: Any = None,
372+
virtual_concat: bool = False,
357373
join: Join_T = "inner",
358374
):
359375
"""
@@ -388,11 +404,12 @@ def _write_concat_sequence( # noqa: PLR0913, PLR0917
388404
arrays,
389405
output_group,
390406
output_path,
391-
max_loaded_elems,
392-
axis,
393-
reindexers,
394-
fill_value,
395-
join,
407+
max_loaded_elems=max_loaded_elems,
408+
virtual_concat=virtual_concat,
409+
axis=axis,
410+
reindexers=reindexers,
411+
fill_value=fill_value,
412+
join=join,
396413
)
397414
else:
398415
msg = f"Concatenation of these types is not yet implemented: {[type(a) for a in arrays]} with axis={axis}."
@@ -470,6 +487,7 @@ def concat_on_disk( # noqa: PLR0913
470487
out_file: PathLike[str] | str | H5Group | ZarrGroup,
471488
*,
472489
max_loaded_elems: int = 100_000_000,
490+
virtual_concat: bool = False,
473491
axis: Literal["obs", 0, "var", 1] = 0,
474492
join: Join_T = "inner",
475493
merge: StrategiesLiteral | Callable[[Collection[Mapping]], Mapping] | None = None,
@@ -499,10 +517,9 @@ def concat_on_disk( # noqa: PLR0913
499517
see the Dask documentation, as the Dask concatenation function is used
500518
to concatenate dense arrays in this function.
501519
502-
For sparse arrays, if the backend is hdf5 and there is no reindexing,
520+
For sparse arrays, if the backend is hdf5 and there is no reindexing and
521+
`virtual_concat` is True,
503522
the virtual concatenation is used using the `h5py` virtual dataset support.
504-
This will create soft links to the source files instead of copying the whole content.
505-
Be aware that this will make the output file dependent on the source files.
506523
507524
Params
508525
------
@@ -517,6 +534,11 @@ def concat_on_disk( # noqa: PLR0913
517534
sparse arrays. Note that this number also includes the empty entries.
518535
Set to 100m by default meaning roughly 400mb will be loaded
519536
to memory simultaneously.
537+
virtual_concat
538+
Whether to use virtual concatenation for sparse arrays.
539+
This will create soft links to the source files instead of copying the whole content.
540+
Be aware that this will make the output file dependent on the source files.
541+
This is False by default.
520542
axis
521543
Which axis to concatenate along.
522544
join
@@ -664,6 +686,7 @@ def concat_on_disk( # noqa: PLR0913
664686
label=label,
665687
index_unique=index_unique,
666688
fill_value=fill_value,
689+
virtual_concat=virtual_concat,
667690
merge=merge,
668691
)
669692

@@ -681,6 +704,7 @@ def _concat_on_disk_inner( # noqa: PLR0913
681704
label: str | None,
682705
index_unique: str | None,
683706
fill_value: Any | None,
707+
virtual_concat: bool = False,
684708
merge: Callable[[Collection[Mapping]], Mapping],
685709
) -> None:
686710
"""Internal helper to minimize the amount of indented code within the context manager"""
@@ -754,6 +778,7 @@ def _concat_on_disk_inner( # noqa: PLR0913
754778
reindexers=reindexers,
755779
fill_value=fill_value,
756780
max_loaded_elems=max_loaded_elems,
781+
virtual_concat=virtual_concat,
757782
)
758783

759784
# Write Layers and {axis_name}m
@@ -774,6 +799,7 @@ def _concat_on_disk_inner( # noqa: PLR0913
774799
intersect_keys(maps),
775800
m,
776801
max_loaded_elems=max_loaded_elems,
802+
virtual_concat=virtual_concat,
777803
axis=m_axis,
778804
index=m_index,
779805
reindexers=m_reindexers,

tests/test_concatenate_disk.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def assert_eq_concat_on_disk(
9898
adatas,
9999
tmp_path: Path,
100100
file_format: Literal["zarr", "h5ad"],
101+
*,
101102
max_loaded_elems: int | None = None,
103+
virtual_concat: bool = False,
102104
*args,
103105
merge_strategy: merge.StrategiesLiteral | None = None,
104106
**kwargs,
@@ -110,7 +112,7 @@ def assert_eq_concat_on_disk(
110112
out_name = tmp_path / f"out.{file_format}"
111113
if max_loaded_elems is not None:
112114
kwargs["max_loaded_elems"] = max_loaded_elems
113-
concat_on_disk(paths, out_name, *args, merge=merge_strategy, **kwargs)
115+
concat_on_disk(paths, out_name, *args, virtual_concat=virtual_concat, merge=merge_strategy, **kwargs)
114116
with as_group(out_name, mode="r") as rg:
115117
res2 = read_elem(rg)
116118
assert_equal(res1, res2, exact=False)
@@ -181,6 +183,48 @@ def test_anndatas(
181183
)
182184

183185

186+
def test_anndatas_virtual_concat(
187+
*,
188+
tmp_path: Path,
189+
):
190+
axis = 0
191+
max_loaded_elems = 1_000_000
192+
file_format = "h5ad"
193+
array_type = "sparse"
194+
join_type = "inner"
195+
_, off_axis_name = _resolve_axis(1 - axis)
196+
random_axes = {0, 1}
197+
sparse_fmt = "csr" if axis == 0 else "csc"
198+
kw = GEN_ADATA_OOC_CONCAT_ARGS
199+
200+
adatas = []
201+
for i in range(3):
202+
M, N = (np.random.randint(5, 10) if a in random_axes else 50 for a in (0, 1))
203+
a = gen_adata(
204+
(M, N),
205+
X_type=get_array_type(array_type, axis),
206+
sparse_fmt=sparse_fmt,
207+
obs_dtypes=[pd.CategoricalDtype(ordered=False)],
208+
var_dtypes=[pd.CategoricalDtype(ordered=False)],
209+
**kw,
210+
)
211+
# ensure some names overlap, others do not, for the off-axis so that inner/outer is properly tested
212+
off_names = getattr(a, f"{off_axis_name}_names").array
213+
off_names[1::2] = f"{i}-" + off_names[1::2]
214+
setattr(a, f"{off_axis_name}_names", off_names)
215+
adatas.append(a)
216+
217+
assert_eq_concat_on_disk(
218+
adatas,
219+
tmp_path,
220+
file_format,
221+
max_loaded_elems=max_loaded_elems,
222+
virtual_concat=True,
223+
axis=axis,
224+
join=join_type,
225+
)
226+
227+
184228
def test_concat_ordered_categoricals_retained(tmp_path, file_format):
185229
a = AnnData(
186230
X=np.ones((5, 1)),

0 commit comments

Comments
 (0)