Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1911.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix concatenation of {class}`anndata.AnnData` objects along `var` using `join="outer"` when `varm` is not empty. {user}`ilia-kats`
29 changes: 21 additions & 8 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,17 +878,23 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
)


def inner_concat_aligned_mapping(mappings, *, reindexers=None, index=None, axis=0):
def inner_concat_aligned_mapping(
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None
):
if concat_axis is None:
concat_axis = axis
result = {}

for k in intersect_keys(mappings):
els = [m[k] for m in mappings]
if reindexers is None:
cur_reindexers = gen_inner_reindexers(els, new_index=index, axis=axis)
cur_reindexers = gen_inner_reindexers(
els, new_index=index, axis=concat_axis
)
else:
cur_reindexers = reindexers

result[k] = concat_arrays(els, cur_reindexers, index=index, axis=axis)
result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis)
return result


Expand Down Expand Up @@ -984,15 +990,19 @@ def missing_element(


def outer_concat_aligned_mapping(
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None
):
if concat_axis is None:
concat_axis = axis
result = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in union_keys(mappings):
els = [m.get(k, MissingVal) for m in mappings]
if reindexers is None:
cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
cur_reindexers = gen_outer_reindexers(
els, ns, new_index=index, axis=concat_axis
)
else:
cur_reindexers = reindexers

Expand All @@ -1011,15 +1021,15 @@ def outer_concat_aligned_mapping(
if not_missing(el)
else missing_element(
n,
axis=axis,
axis=concat_axis,
els=els,
fill_value=fill_value,
off_axis_size=off_axis_size,
)
for el, n in zip(els, ns)
],
cur_reindexers,
axis=axis,
axis=concat_axis,
index=index,
fill_value=fill_value,
)
Expand Down Expand Up @@ -1606,7 +1616,10 @@ def concat(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
)
concat_mapping = concat_aligned_mapping(
[getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices
[getattr(a, f"{axis_name}m") for a in adatas],
axis=axis,
concat_axis=0,
index=concat_indices,
)
if pairwise:
concat_pairwise = concat_pairwise_mapping(
Expand Down
16 changes: 11 additions & 5 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,15 +1440,21 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):


@pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"])
def test_concat_outer_aligned_mapping(elem):
@pytest.mark.parametrize("axis", ["obs", "var"])
def test_concat_outer_aligned_mapping(elem, axis):
a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS)
b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS)
del b.obsm[elem]
del getattr(b, f"{axis}m")[elem]

concated = concat({"a": a, "b": b}, join="outer", label="group")
result = concated[concated.obs["group"] == "b"].obsm[elem]
concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis)

check_filled_like(result, elem_name=f"obsm/{elem}")
mask = getattr(concated, axis)["group"] == "b"
result = getattr(
concated[(mask, slice(None)) if axis == "obs" else (slice(None), mask)],
f"{axis}m",
)[elem]

check_filled_like(result, elem_name=f"{axis}m/{elem}")


@mark_legacy_concatenate
Expand Down