Skip to content

Commit a5deb2a

Browse files
fix outer concatenation along var when varm is not empty (#1911)
* fix outer concatenation along var when varm is not empty * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add release notes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent eea9002 commit a5deb2a

3 files changed

Lines changed: 33 additions & 13 deletions

File tree

docs/release-notes/1911.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix concatenation of {class}`anndata.AnnData` objects along `var` using `join="outer"` when `varm` is not empty. {user}`ilia-kats`

src/anndata/_core/merge.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -878,17 +878,23 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
878878
)
879879

880880

881-
def inner_concat_aligned_mapping(mappings, *, reindexers=None, index=None, axis=0):
881+
def inner_concat_aligned_mapping(
882+
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None
883+
):
884+
if concat_axis is None:
885+
concat_axis = axis
882886
result = {}
883887

884888
for k in intersect_keys(mappings):
885889
els = [m[k] for m in mappings]
886890
if reindexers is None:
887-
cur_reindexers = gen_inner_reindexers(els, new_index=index, axis=axis)
891+
cur_reindexers = gen_inner_reindexers(
892+
els, new_index=index, axis=concat_axis
893+
)
888894
else:
889895
cur_reindexers = reindexers
890896

891-
result[k] = concat_arrays(els, cur_reindexers, index=index, axis=axis)
897+
result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis)
892898
return result
893899

894900

@@ -984,15 +990,19 @@ def missing_element(
984990

985991

986992
def outer_concat_aligned_mapping(
987-
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
993+
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None
988994
):
995+
if concat_axis is None:
996+
concat_axis = axis
989997
result = {}
990998
ns = [m.parent.shape[axis] for m in mappings]
991999

9921000
for k in union_keys(mappings):
9931001
els = [m.get(k, MissingVal) for m in mappings]
9941002
if reindexers is None:
995-
cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
1003+
cur_reindexers = gen_outer_reindexers(
1004+
els, ns, new_index=index, axis=concat_axis
1005+
)
9961006
else:
9971007
cur_reindexers = reindexers
9981008

@@ -1011,15 +1021,15 @@ def outer_concat_aligned_mapping(
10111021
if not_missing(el)
10121022
else missing_element(
10131023
n,
1014-
axis=axis,
1024+
axis=concat_axis,
10151025
els=els,
10161026
fill_value=fill_value,
10171027
off_axis_size=off_axis_size,
10181028
)
10191029
for el, n in zip(els, ns)
10201030
],
10211031
cur_reindexers,
1022-
axis=axis,
1032+
axis=concat_axis,
10231033
index=index,
10241034
fill_value=fill_value,
10251035
)
@@ -1606,7 +1616,10 @@ def concat(
16061616
[a.layers for a in adatas], axis=axis, reindexers=reindexers
16071617
)
16081618
concat_mapping = concat_aligned_mapping(
1609-
[getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices
1619+
[getattr(a, f"{axis_name}m") for a in adatas],
1620+
axis=axis,
1621+
concat_axis=0,
1622+
index=concat_indices,
16101623
)
16111624
if pairwise:
16121625
concat_pairwise = concat_pairwise_mapping(

tests/test_concatenate.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,15 +1440,21 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):
14401440

14411441

14421442
@pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"])
1443-
def test_concat_outer_aligned_mapping(elem):
1443+
@pytest.mark.parametrize("axis", ["obs", "var"])
1444+
def test_concat_outer_aligned_mapping(elem, axis):
14441445
a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS)
14451446
b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS)
1446-
del b.obsm[elem]
1447+
del getattr(b, f"{axis}m")[elem]
14471448

1448-
concated = concat({"a": a, "b": b}, join="outer", label="group")
1449-
result = concated[concated.obs["group"] == "b"].obsm[elem]
1449+
concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis)
14501450

1451-
check_filled_like(result, elem_name=f"obsm/{elem}")
1451+
mask = getattr(concated, axis)["group"] == "b"
1452+
result = getattr(
1453+
concated[(mask, slice(None)) if axis == "obs" else (slice(None), mask)],
1454+
f"{axis}m",
1455+
)[elem]
1456+
1457+
check_filled_like(result, elem_name=f"{axis}m/{elem}")
14521458

14531459

14541460
@mark_legacy_concatenate

0 commit comments

Comments
 (0)