Skip to content

Commit 2a9afed

Browse files
meeseeksmachineilia-katsilan-gold
authored
Backport PR #1911: fix outer concatenation along var when varm is not empty (#1936)
Co-authored-by: ilia-kats <ilia-kats@gmx.net> Co-authored-by: Ilan Gold <ilanbassgold@gmail.com>
1 parent 632dd95 commit 2a9afed

3 files changed

Lines changed: 33 additions & 13 deletions

File tree

docs/release-notes/1911.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix concatenation of {class}`anndata.AnnData` objects along `var` using `join="outer"` when `varm` is not empty. {user}`ilia-kats`

src/anndata/_core/merge.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -853,17 +853,23 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
853853
)
854854

855855

856-
def inner_concat_aligned_mapping(mappings, *, reindexers=None, index=None, axis=0):
856+
def inner_concat_aligned_mapping(
857+
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None
858+
):
859+
if concat_axis is None:
860+
concat_axis = axis
857861
result = {}
858862

859863
for k in intersect_keys(mappings):
860864
els = [m[k] for m in mappings]
861865
if reindexers is None:
862-
cur_reindexers = gen_inner_reindexers(els, new_index=index, axis=axis)
866+
cur_reindexers = gen_inner_reindexers(
867+
els, new_index=index, axis=concat_axis
868+
)
863869
else:
864870
cur_reindexers = reindexers
865871

866-
result[k] = concat_arrays(els, cur_reindexers, index=index, axis=axis)
872+
result[k] = concat_arrays(els, cur_reindexers, index=index, axis=concat_axis)
867873
return result
868874

869875

@@ -959,15 +965,19 @@ def missing_element(
959965

960966

961967
def outer_concat_aligned_mapping(
962-
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
968+
mappings, *, reindexers=None, index=None, axis=0, concat_axis=None, fill_value=None
963969
):
970+
if concat_axis is None:
971+
concat_axis = axis
964972
result = {}
965973
ns = [m.parent.shape[axis] for m in mappings]
966974

967975
for k in union_keys(mappings):
968976
els = [m.get(k, MissingVal) for m in mappings]
969977
if reindexers is None:
970-
cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
978+
cur_reindexers = gen_outer_reindexers(
979+
els, ns, new_index=index, axis=concat_axis
980+
)
971981
else:
972982
cur_reindexers = reindexers
973983

@@ -986,15 +996,15 @@ def outer_concat_aligned_mapping(
986996
if not_missing(el)
987997
else missing_element(
988998
n,
989-
axis=axis,
999+
axis=concat_axis,
9901000
els=els,
9911001
fill_value=fill_value,
9921002
off_axis_size=off_axis_size,
9931003
)
9941004
for el, n in zip(els, ns)
9951005
],
9961006
cur_reindexers,
997-
axis=axis,
1007+
axis=concat_axis,
9981008
index=index,
9991009
fill_value=fill_value,
10001010
)
@@ -1368,7 +1378,10 @@ def concat(
13681378
[a.layers for a in adatas], axis=axis, reindexers=reindexers
13691379
)
13701380
concat_mapping = concat_aligned_mapping(
1371-
[getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices
1381+
[getattr(a, f"{axis_name}m") for a in adatas],
1382+
axis=axis,
1383+
concat_axis=0,
1384+
index=concat_indices,
13721385
)
13731386
if pairwise:
13741387
concat_pairwise = concat_pairwise_mapping(

tests/test_concatenate.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,15 +1442,21 @@ def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape):
14421442

14431443

14441444
@pytest.mark.parametrize("elem", ["sparse", "array", "df", "da"])
1445-
def test_concat_outer_aligned_mapping(elem):
1445+
@pytest.mark.parametrize("axis", ["obs", "var"])
1446+
def test_concat_outer_aligned_mapping(elem, axis):
14461447
a = gen_adata((5, 5), **GEN_ADATA_DASK_ARGS)
14471448
b = gen_adata((3, 5), **GEN_ADATA_DASK_ARGS)
1448-
del b.obsm[elem]
1449+
del getattr(b, f"{axis}m")[elem]
14491450

1450-
concated = concat({"a": a, "b": b}, join="outer", label="group")
1451-
result = concated[concated.obs["group"] == "b"].obsm[elem]
1451+
concated = concat({"a": a, "b": b}, join="outer", label="group", axis=axis)
14521452

1453-
check_filled_like(result, elem_name=f"obsm/{elem}")
1453+
mask = getattr(concated, axis)["group"] == "b"
1454+
result = getattr(
1455+
concated[(mask, slice(None)) if axis == "obs" else (slice(None), mask)],
1456+
f"{axis}m",
1457+
)[elem]
1458+
1459+
check_filled_like(result, elem_name=f"{axis}m/{elem}")
14541460

14551461

14561462
@mark_legacy_concatenate

0 commit comments

Comments
 (0)