# Patch summary (text before the first "diff --git" header is skipped by patch tools).
#
# docs/release-notes/1911.bugfix.md
#     New release-note entry for the fix.
# src/anndata/_core/merge.py
#     inner_concat_aligned_mapping / outer_concat_aligned_mapping gain a
#     keyword-only `concat_axis` that falls back to `axis` when None.
#     `concat_axis` is what gets passed to the reindexer generators,
#     missing_element() and concat_arrays(); in the outer variant `axis` is
#     still used to read `m.parent.shape[axis]`.
#     concat() now calls concat_aligned_mapping(..., axis=axis, concat_axis=0)
#     for the obsm/varm mappings.
# tests/test_concatenate.py
#     test_concat_outer_aligned_mapping is parametrized over "obs" and "var".
#
# NOTE(review): this file had lost its line breaks; indentation and blank
# context lines were restored from the hunk line counts and Python syntax.
# Verify against the target tree before applying.
diff --git a/docs/release-notes/1911.bugfix.md b/docs/release-notes/1911.bugfix.md
new file mode 100755
index 000000000..b0df5cd97
--- /dev/null
+++ b/docs/release-notes/1911.bugfix.md
@@ -0,0 +1 @@
+Fix concatenation of {class}`anndata.AnnData` objects along `var` using `join="outer"` when `varm` is not empty. {user}`ilia-kats`
diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index c3daa9bdb..37308b31f 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -853,17 +853,23 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
         )
 
 
-def inner_concat_aligned_mapping(mappings, *, reindexers=None, index=None, axis=0):
+def inner_concat_aligned_mapping(
+    mappings, *, reindexers=None, index=None, axis=0, concat_axis=None
+):
+    if concat_axis is None:
+        concat_axis = axis
     result = {}
 
     for k in intersect_keys(mappings):
         els = [m[k] for m in mappings]
         if reindexers is None:
-            cur_reindexers = gen_inner_reindexers(els, new_index=index, axis=axis)
+            cur_reindexers = gen_inner_reindexers(
+                els, new_index=index, axis=concat_axis
+            )
         else:
             cur_reindexers = reindexers
 
-        result[k] = concat_arrays(els, cur_reindexers, index=index, axis=axis)
+        result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis)
     return result
 
 
@@ -959,15 +965,19 @@ def missing_element(
 
 
 def outer_concat_aligned_mapping(
-    mappings, *, reindexers=None, index=None, axis=0, fill_value=None
+    mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None
 ):
+    if concat_axis is None:
+        concat_axis = axis
     result = {}
     ns = [m.parent.shape[axis] for m in mappings]
 
     for k in union_keys(mappings):
         els = [m.get(k, MissingVal) for m in mappings]
         if reindexers is None:
-            cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
+            cur_reindexers = gen_outer_reindexers(
+                els, ns, new_index=index, axis=concat_axis
+            )
         else:
             cur_reindexers = reindexers
 
@@ -986,7 +996,7 @@ def outer_concat_aligned_mapping(
                 if not_missing(el)
                 else missing_element(
                     n,
-                    axis=axis,
+                    axis=concat_axis,
                     els=els,
                     fill_value=fill_value,
                     off_axis_size=off_axis_size,
@@ -994,7 +1004,7 @@ def outer_concat_aligned_mapping(
                 for el, n in zip(els, ns)
             ],
             cur_reindexers,
-            axis=axis,
+            axis=concat_axis,
             index=index,
             fill_value=fill_value,
         )
@@ -1368,7 +1378,10 @@ def concat(
         [a.layers for a in adatas], axis=axis, reindexers=reindexers
     )
     concat_mapping = concat_aligned_mapping(
-        [getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices
+        [getattr(a, f"{axis_name}m") for a in adatas],
+        axis=axis,
+        concat_axis=0,
+        index=concat_indices,
     )
     if pairwise:
         concat_pairwise = concat_pairwise_mapping(
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index 622397d88..425217755 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -1442,15 +1442,21 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):
 
 
 @pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"])
-def test_concat_outer_aligned_mapping(elem):
+@pytest.mark.parametrize("axis", ["obs", "var"])
+def test_concat_outer_aligned_mapping(elem, axis):
     a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS)
     b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS)
-    del b.obsm[elem]
+    del getattr(b, f"{axis}m")[elem]
 
-    concated = concat({"a": a, "b": b}, join="outer", label="group")
-    result = concated[concated.obs["group"] == "b"].obsm[elem]
+    concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis)
 
-    check_filled_like(result, elem_name=f"obsm/{elem}")
+    mask = getattr(concated, axis)["group"] == "b"
+    result = getattr(
+        concated[(mask, slice(None)) if axis == "obs" else (slice(None), mask)],
+        f"{axis}m",
+    )[elem]
+
+    check_filled_like(result, elem_name=f"{axis}m/{elem}")
 
 
 @mark_legacy_concatenate