diff --git a/docs/release-notes/2416.feat.md b/docs/release-notes/2416.feat.md new file mode 100644 index 000000000..68cbf55d7 --- /dev/null +++ b/docs/release-notes/2416.feat.md @@ -0,0 +1 @@ +Add `aligned_axis_key_join` to {func}`anndata.concat` for controlling on-axis annotation and aligned-mapping key joins (obs/var columns, obsm/varm and obsp/varp keys, layers keys) independently of the off-axis index alignment controlled by `join` {user}`Ekin-Kahraman` diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index b63b7bf84..a882348f1 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -10,7 +10,7 @@ from functools import partial, reduce, singledispatch from itertools import repeat from operator import and_, or_, sub -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal, cast, get_args import numpy as np import pandas as pd @@ -18,6 +18,7 @@ from scipy import sparse from anndata._core.file_backing import to_memory +from anndata._types import Join_T from anndata._warnings import ExperimentalFeatureWarning from ..compat import ( @@ -35,6 +36,8 @@ from .index import _subset, make_slice from .xarray import Dataset2D +JOIN_OPTIONS = get_args(Join_T.__value__) + if TYPE_CHECKING: from collections.abc import Collection, Generator, Iterable, Sequence from typing import Any @@ -42,8 +45,6 @@ from numpy.typing import NDArray from pandas.api.extensions import ExtensionDtype - from anndata._types import Join_T - from ..compat import XDataArray from ..types import SupportsArrayApi @@ -948,30 +949,126 @@ def concat_arrays( # noqa: PLR0911, PLR0912 ) -def inner_concat_aligned_mapping( +def inner_concat_aligned_mapping( # noqa: PLR0913 mappings, *, + keys, reindexers=None, index=None, axis=0, concat_axis=None, force_lazy: bool = False, + fill_value=None, ): + """Inner-content concat of aligned mappings over an explicit key set.""" if concat_axis is None: concat_axis = axis + result = {} + ns = 
[m.parent.shape[axis] for m in mappings] - for k in intersect_keys(mappings): - els = [m[k] for m in mappings] - if reindexers is None: - cur_reindexers = gen_inner_reindexers( - els, new_index=index, axis=concat_axis + for k in keys: + els = [m.get(k, MissingVal) for m in mappings] + any_missing = any(is_missing(el) for el in els) + + if not any_missing: + # All keys present in all mappings — the default + # ``intersect_keys`` path. No fill_value handling needed. + if reindexers is None: + cur_reindexers = gen_inner_reindexers( + els, new_index=index, axis=concat_axis + ) + else: + cur_reindexers = reindexers + result[k] = concat_arrays( + els, + cur_reindexers, + index=index, + axis=concat_axis, + force_lazy=force_lazy, ) + continue + + # Missing-key path: only reachable when caller passes an explicit + # ``keys`` set wider than ``intersect_keys`` (the outer-key + inner- + # content combination from ``concat(aligned_axis_key_join=...)``). + present_els = [el for el in els if not_missing(el)] + + if any(isinstance(el, AwkArray) for el in present_els): + msg = ( + "Combining `aligned_axis_key_join` with `join='inner'` is " + "not yet implemented for awkward arrays in `obsm`/`varm` " + "when the key is missing from at least one input. Use the " + "same value for `join` and `aligned_axis_key_join`, or " + "drop the affected awkward entries before concatenating." + ) + raise NotImplementedError(msg) + + if reindexers is not None: + # Caller-provided reindexers already encode the alt-axis alignment + # across *all* mappings (e.g. the gene-axis intersection for + # ``layers`` when ``join="inner"``). The present-only reindexers + # below would intersect over the present subset only, which is + # wrong for layers because their alt-axis must match X's. Honour + # the caller's reindexers and drop in an identity reindexer for + # missing entries; the filler created by ``missing_element`` + # below uses the matching alt-axis size. 
+ target_idx = reindexers[0].new_idx + cur_reindexers = [ + reindexers[i] if not_missing(el) else Reindexer(target_idx, target_idx) + for i, el in enumerate(els) + ] + off_axis_size = len(target_idx) + elif all(isinstance(el, pd.DataFrame) for el in present_els): + common_cols = reduce( + lambda x, y: x.intersection(y), + (el.columns for el in present_els), + ) + cur_reindexers = [ + Reindexer(el.columns, common_cols) + if not_missing(el) + else ( + lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame( + np.nan if fv is None else fv, + index=range(n), + columns=cols, + ) + ) + for el, n in zip(els, ns, strict=True) + ] + off_axis_size = 0 else: - cur_reindexers = reindexers + inner_present = gen_inner_reindexers( + present_els, new_index=index, axis=concat_axis + ) + target_idx = inner_present[0].new_idx + present_iter = iter(inner_present) + cur_reindexers = [ + next(present_iter) + if not_missing(el) + else Reindexer(target_idx, target_idx) + for el in els + ] + off_axis_size = len(target_idx) result[k] = concat_arrays( - els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy + [ + el + if not_missing(el) + else missing_element( + n, + axis=concat_axis, + els=els, + fill_value=fill_value, + off_axis_size=off_axis_size, + ) + for el, n in zip(els, ns, strict=True) + ], + cur_reindexers, + axis=concat_axis, + index=index, + fill_value=fill_value, + force_lazy=force_lazy, ) return result @@ -1081,9 +1178,10 @@ def missing_element( return xp.zeros(shape, dtype=bool) -def outer_concat_aligned_mapping( +def outer_concat_aligned_mapping( # noqa: PLR0913 mappings, *, + keys, reindexers=None, index=None, axis=0, @@ -1091,12 +1189,14 @@ def outer_concat_aligned_mapping( fill_value=None, force_lazy: bool = False, ): + """Outer-content concat of aligned mappings over an explicit key set.""" if concat_axis is None: concat_axis = axis + result = {} ns = [m.parent.shape[axis] for m in mappings] - for k in union_keys(mappings): + for k in keys: els = 
[m.get(k, MissingVal) for m in mappings] if reindexers is None: cur_reindexers = gen_outer_reindexers( @@ -1423,6 +1523,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 *, axis: Literal["obs", 0, "var", 1] = "obs", join: Join_T | Default = Default("inner"), # noqa: B008 + aligned_axis_key_join: Join_T | None = None, merge: StrategiesLiteral | Callable | None = None, uns_merge: StrategiesLiteral | Callable | None = None, label: str | None = None, @@ -1447,6 +1548,15 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 How to align values when concatenating. If "outer", the union of the other axis is taken. If "inner", the intersection. See :doc:`concatenation <../concatenation>` for more. + aligned_axis_key_join + How to join keys on the *concatenation axis* itself: columns of `obs`/`var`, + keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`), + and keys of `layers`. Use "outer" to take the union of these keys, "inner" + to take the intersection. The off-axis content of each value (e.g. the var + index of an obsm DataFrame, or the gene axis of a layer) still follows + `join`. Defaults to `None`, in which case `join` is used for both the + off-axis index alignment and the on-axis key join (the historical + behaviour). merge How elements not aligned to the axis being concatenated along are selected. Currently implemented strategies include: @@ -1633,6 +1743,18 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 {'c': {'c.c': 5}} >>> dict(ad.concat([a, b, c], uns_merge="first").uns) {'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}} + + `aligned_axis_key_join` controls on-axis key joining (obs columns, + obsm/obsp keys) independently of the off-axis index `join`. The default + of `None` falls back to `join`, preserving existing behaviour. To keep + the union of `var` indices but the intersection of obs columns: + + >>> ad.concat( + ... [a, b], join="outer", aligned_axis_key_join="inner" + ... 
+ ).obs.columns.tolist() + ['group'] + >>> ad.concat([a, b], join="outer").obs.columns.tolist() + ['group', 'measure'] """ from anndata._core.xarray import Dataset2D @@ -1642,6 +1764,13 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 merge = resolve_merge_strategy(merge) uns_merge = resolve_merge_strategy(uns_merge) + if aligned_axis_key_join is not None and aligned_axis_key_join not in JOIN_OPTIONS: + msg = ( + f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, " + f"got {aligned_axis_key_join!r}" + ) + raise ValueError(msg) + if isinstance(adatas, Mapping): if keys is not None: msg = ( @@ -1665,6 +1794,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) warn(msg, UserWarning) + # Resolve the on-axis key join. When aligned_axis_key_join is None, the + # historical behaviour applies and `join` controls both off-axis index + # alignment and on-axis key joining. When set, it overrides the on-axis + # key joining for layers and obs/obsm/obsp (or var/varm/varp if axis="var"). + # Resolved after the `Default` unwrap above so `aligned_join` is always + # a plain string. 
+ aligned_join = aligned_axis_key_join if aligned_axis_key_join is not None else join + if keys is None: keys = np.arange(len(adatas)).astype(str) @@ -1695,7 +1832,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ] # Annotation for concatenation axis - check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join) + check_combinable_cols( + [getattr(a, axis_name).columns for a in adatas], join=aligned_join + ) annotations = [getattr(a, axis_name) for a in adatas] are_any_annotations_dataframes = any( isinstance(a, pd.DataFrame) for a in annotations @@ -1706,14 +1845,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 ) concat_annot = pd.concat( unify_dtypes(annotations_in_memory), - join=join, + join=aligned_join, ignore_index=True, ) concat_annot.index = concat_indices else: concat_annot = concat_dataset2d_on_annot_axis( annotations, - join, + aligned_join, force_lazy=force_lazy, concat_indices=concat_indices, ) @@ -1752,31 +1891,44 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 if join == "inner": concat_aligned_mapping = inner_concat_aligned_mapping - join_keys = intersect_keys elif join == "outer": concat_aligned_mapping = partial( outer_concat_aligned_mapping, fill_value=fill_value ) - join_keys = union_keys else: msg = f"{join=} should have been validated above by pd.concat" raise AssertionError(msg) + if aligned_join == "inner": + aligned_join_keys = intersect_keys + elif aligned_join == "outer": + aligned_join_keys = union_keys + else: + msg = f"{aligned_join=} should have been validated" + raise AssertionError(msg) + + layer_mappings = [a.layers for a in adatas] layers = concat_aligned_mapping( - [a.layers for a in adatas], axis=axis, reindexers=reindexers + layer_mappings, + axis=axis, + reindexers=reindexers, + keys=aligned_join_keys(layer_mappings), ) + + obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas] concat_mapping = concat_aligned_mapping( - [getattr(a, f"{axis_name}m") for a in adatas], + obsm_mappings, 
axis=axis, concat_axis=0, index=concat_indices, force_lazy=force_lazy, + keys=aligned_join_keys(obsm_mappings), ) if pairwise: concat_pairwise = concat_pairwise_mapping( mappings=[getattr(a, f"{axis_name}p") for a in adatas], shapes=[a.shape[axis] for a in adatas], - join_keys=join_keys, + join_keys=aligned_join_keys, ) else: concat_pairwise = {} @@ -1808,6 +1960,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 for a in adatas ], join=join, + aligned_axis_key_join=aligned_axis_key_join, label=label, keys=keys, index_unique=index_unique, diff --git a/src/anndata/experimental/multi_files/_anncollection.py b/src/anndata/experimental/multi_files/_anncollection.py index b1ed27b78..b76208b9d 100644 --- a/src/anndata/experimental/multi_files/_anncollection.py +++ b/src/anndata/experimental/multi_files/_anncollection.py @@ -14,7 +14,7 @@ from ..._core.aligned_mapping import AxisArrays from ..._core.anndata import AnnData from ..._core.index import _normalize_index, _normalize_indices -from ..._core.merge import concat_arrays, inner_concat_aligned_mapping +from ..._core.merge import concat_arrays, inner_concat_aligned_mapping, intersect_keys from ..._core.sparse_dataset import BaseCompressedSparseDataset from ..._core.views import _resolve_idx from ...compat import old_positionals @@ -768,8 +768,9 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915 if join_obsm == "inner": view_attrs.remove("obsm") self._attrs.append("obsm") + obsm_mappings = [a.obsm for a in adatas] self._obsm = inner_concat_aligned_mapping( - [a.obsm for a in adatas], index=self.obs_names + obsm_mappings, keys=intersect_keys(obsm_mappings), index=self.obs_names ) self._obsm = ( AxisArrays(self, axis=0, store={}) if self._obsm == {} else self._obsm diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 37a8e6b75..502c95864 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1891,3 +1891,422 @@ def test_1d_concat(): adata = AnnData(np.ones((5, 20)), 
obsm={"1d-array": np.ones(5)}) concated = concat([adata, adata]) assert concated.obsm["1d-array"].shape == (10, 1) + + +# ------------------------------------------------------------------ +# aligned_axis_key_join (issue #2374) +# ------------------------------------------------------------------ + + +def _adatas_with_partial_overlap_along_axis(axis_name): + """Two AnnData with different on-axis annotation columns and aligned-mapping keys.""" + if axis_name == "var": + obs_a = pd.DataFrame(index=["row1", "row2"]) + obs_b = pd.DataFrame(index=["row1", "row2"]) + var_a = pd.DataFrame( + {"shared": ["a", "b"], "only_a": [1, 2]}, index=["v1", "v2"] + ) + var_b = pd.DataFrame( + {"shared": ["c", "d"], "only_b": [3, 4]}, index=["v3", "v4"] + ) + else: + obs_a = pd.DataFrame( + {"shared": ["a", "b"], "only_a": [1, 2]}, index=["s1", "s2"] + ) + obs_b = pd.DataFrame( + {"shared": ["c", "d"], "only_b": [3, 4]}, index=["s3", "s4"] + ) + var_a = pd.DataFrame(index=["v1", "v2"]) + var_b = pd.DataFrame(index=["v1", "v2"]) + + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=obs_a, + var=var_a, + **{ + f"{axis_name}m": { + "shared_m": np.ones((2, 3)), + "only_a_m": np.zeros((2, 3)), + } + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=obs_b, + var=var_b, + **{ + f"{axis_name}m": { + "shared_m": 2 * np.ones((2, 3)), + "only_b_m": np.ones((2, 3)), + } + }, + ) + return a, b + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_default_falls_back_to_join(axis_name): + """When `aligned_axis_key_join=None`, the on-axis behaviour matches `join`.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + inner = concat([a, b], axis=axis) + inner_explicit = concat([a, b], axis=axis, aligned_axis_key_join=None) + assert list(getattr(inner, axis_name).columns) == list( + getattr(inner_explicit, axis_name).columns + ) + assert list(getattr(inner, 
f"{axis_name}m").keys()) == list( + getattr(inner_explicit, f"{axis_name}m").keys() + ) + # Inner is the historical default, only "shared" should remain. + assert list(getattr(inner, axis_name).columns) == ["shared"] + assert list(getattr(inner, f"{axis_name}m").keys()) == ["shared_m"] + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_outer_with_inner_join(axis_name): + """`aligned_axis_key_join="outer"` unions on-axis keys while leaving off-axis as inner.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + res = concat([a, b], axis=axis, join="inner", aligned_axis_key_join="outer") + cols = list(getattr(res, axis_name).columns) + keys = list(getattr(res, f"{axis_name}m").keys()) + assert set(cols) == {"shared", "only_a", "only_b"} + assert set(keys) == {"shared_m", "only_a_m", "only_b_m"} + + +@pytest.mark.parametrize("axis_name", ["obs", "var"]) +def test_aligned_axis_key_join_inner_with_outer_join(axis_name): + """`aligned_axis_key_join="inner"` intersects on-axis keys while leaving off-axis as outer.""" + a, b = _adatas_with_partial_overlap_along_axis(axis_name) + axis = 0 if axis_name == "obs" else 1 + + res = concat([a, b], axis=axis, join="outer", aligned_axis_key_join="inner") + cols = list(getattr(res, axis_name).columns) + keys = list(getattr(res, f"{axis_name}m").keys()) + assert cols == ["shared"] + assert keys == ["shared_m"] + + +def test_aligned_axis_key_join_layer_keys_unioned_with_inner_content(): + """`aligned_axis_key_join="outer"` + `join="inner"` unions layer keys + while aligning each layer's off-axis (var) to the inner intersection. + + Mirrors the obsm contract: which keys appear is on-axis (controlled by + `aligned_axis_key_join`); how each kept value aligns along the alt-axis + is off-axis (controlled by `join`). 
+ """ + a = AnnData( + X=np.ones((2, 3), dtype=np.float64), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2", "v3"]), + layers={ + "shared_layer": np.full((2, 3), 1.0), + "only_a_layer": np.full((2, 3), 7.0), + }, + ) + b = AnnData( + X=np.ones((2, 2), dtype=np.float64), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + layers={ + "shared_layer": np.full((2, 2), 2.0), + "only_b_layer": np.full((2, 2), 9.0), + }, + ) + + res = concat([a, b], join="inner", aligned_axis_key_join="outer") + + # Outer key join: every layer name appears. Post #1707, X is stored + # under the `None` layer key, so it shows up here too. + assert set(res.layers.keys()) == { + None, + "only_a_layer", + "only_b_layer", + "shared_layer", + } + + # Inner content join: alt-axis (var) is the intersection {v1, v2}. + n_total_cells = 4 + n_inner_genes = 2 + assert res.shape == (n_total_cells, n_inner_genes) + for k in ("shared_layer", "only_a_layer", "only_b_layer"): + assert res.layers[k].shape == (n_total_cells, n_inner_genes), ( + f"layer {k!r} should be aligned to inner alt-axis, " + f"got shape {res.layers[k].shape}" + ) + + # Spot-check content. shared_layer is present in both; values 1.0 from a + # and 2.0 from b stack into the inner gene set. + np.testing.assert_array_equal( + np.asarray(res.layers["shared_layer"]), + np.array([[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]), + ) + # only_a_layer is filled (with the missing-element default) for b's rows. + only_a = np.asarray(res.layers["only_a_layer"]) + np.testing.assert_array_equal(only_a[:2], np.array([[7.0, 7.0], [7.0, 7.0]])) + # only_b_layer is filled for a's rows; b's rows carry 9.0. 
+ only_b = np.asarray(res.layers["only_b_layer"]) + np.testing.assert_array_equal(only_b[2:], np.array([[9.0, 9.0], [9.0, 9.0]])) + + +def test_aligned_axis_key_join_layer_keys_intersected_with_outer_content(): + """`aligned_axis_key_join="inner"` + `join="outer"` intersects layer + keys while aligning each kept layer's off-axis (var) to the outer + union. + """ + a = AnnData( + X=np.ones((2, 3), dtype=np.float64), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2", "v3"]), + layers={ + "shared_layer": np.full((2, 3), 1.0), + "only_a_layer": np.full((2, 3), 7.0), + }, + ) + b = AnnData( + X=np.ones((2, 2), dtype=np.float64), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + layers={ + "shared_layer": np.full((2, 2), 2.0), + "only_b_layer": np.full((2, 2), 9.0), + }, + ) + + res = concat([a, b], join="outer", aligned_axis_key_join="inner") + + # Inner key join: only the layer keys present in every input survive. + # Post #1707, X is stored under the `None` layer key and is present in + # both inputs, so it survives the intersection alongside "shared_layer". + assert set(res.layers.keys()) == {None, "shared_layer"} + + # Outer content join: alt-axis (var) is the union {v1, v2, v3}. 
+ n_total_cells = 4 + n_outer_genes = 3 + assert res.shape == (n_total_cells, n_outer_genes) + assert res.layers["shared_layer"].shape == (n_total_cells, n_outer_genes) + + +def test_aligned_axis_key_join_does_not_affect_alt_axis_mappings(): + """When concatenating along obs, varm/varp follow `merge`, not `aligned_axis_key_join`.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + varm={ + "shared_varm": np.ones((2, 3)), + "only_a_varm": np.zeros((2, 3)), + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + varm={ + "shared_varm": 2 * np.ones((2, 3)), + "only_b_varm": np.ones((2, 3)), + }, + ) + + # The contract: aligned_axis_key_join only governs the on-axis (obs here). + # varm is alt-axis, controlled entirely by `merge`. Compare results + # produced with and without aligned_axis_key_join for several merge + # strategies — varm contents must be identical. 
+ for merge_strategy in (None, "same", "first", "unique", "only"): + kwargs = {"axis": "obs", "merge": merge_strategy} + baseline = concat([a, b], **kwargs) + with_inner = concat([a, b], aligned_axis_key_join="inner", **kwargs) + with_outer = concat([a, b], aligned_axis_key_join="outer", **kwargs) + assert sorted(baseline.varm.keys()) == sorted(with_inner.varm.keys()), ( + f"varm keys diverged under merge={merge_strategy!r}" + ) + assert sorted(baseline.varm.keys()) == sorted(with_outer.varm.keys()), ( + f"varm keys diverged under merge={merge_strategy!r}" + ) + + +def test_aligned_axis_key_join_awkward_inner_missing_key_raises(): + """Awkward arrays with inner content-join + missing keys raises NotImplementedError.""" + import awkward as ak + + a = AnnData( + X=np.eye(2, dtype=np.float64), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "shared_awk": ak.Array([[1, 2], [3]]), + "only_a_awk": ak.Array([[4], [5, 6]]), + }, + ) + b = AnnData( + X=np.eye(2, dtype=np.float64), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"shared_awk": ak.Array([[7, 8], [9]])}, + ) + # join="inner" forces inner content-join; aligned_axis_key_join="outer" + # forces outer key set, so `only_a_awk` is missing from `b`. The combo of + # awkward + missing + inner content is the unimplemented branch. 
+ with pytest.raises(NotImplementedError, match="awkward"): + concat([a, b], join="inner", aligned_axis_key_join="outer") + + +def test_aligned_axis_key_join_obsp_pairwise(): + """Pairwise on-axis (obsp) key joining responds to `aligned_axis_key_join`.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(3, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2", "s3"]), + obsp={ + "shared_obsp": sparse.csr_matrix(np.eye(3, dtype=np.float64)), + "only_a_obsp": sparse.csr_matrix(np.eye(3, dtype=np.float64)), + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s4", "s5"]), + obsp={ + "shared_obsp": sparse.csr_matrix(np.eye(2, dtype=np.float64)), + "only_b_obsp": sparse.csr_matrix(np.eye(2, dtype=np.float64)), + }, + ) + outer = concat([a, b], pairwise=True, aligned_axis_key_join="outer") + inner = concat([a, b], pairwise=True, aligned_axis_key_join="inner") + assert sorted(outer.obsp.keys()) == ["only_a_obsp", "only_b_obsp", "shared_obsp"] + assert sorted(inner.obsp.keys()) == ["shared_obsp"] + + +def test_aligned_axis_key_join_invalid_value(): + """Invalid `aligned_axis_key_join` raises a clear error.""" + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + with pytest.raises(ValueError, match="aligned_axis_key_join"): + concat([a, a], aligned_axis_key_join="banana") + + +def test_aligned_axis_key_join_obsm_dataframe_content_follows_join(): + """When `aligned_axis_key_join` differs from `join`, the on-axis key set + follows `aligned_axis_key_join` but the per-key content alignment (e.g. + DataFrame columns inside a shared `obsm[k]`) follows `join`. 
+ """ + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [1, 2], "a_only": [10, 20]}, index=["s1", "s2"]) + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [3, 4], "b_only": [30, 40]}, index=["s3", "s4"]) + }, + ) + + # join="inner" + aligned="outer": obsm key set unioned (still just "df"), + # but the columns inside df should stay intersected. + res = concat([a, b], join="inner", aligned_axis_key_join="outer") + assert list(res.obsm["df"].columns) == ["x"] + + # join="outer" + aligned="inner": obsm key set intersected (still "df"), + # df columns should be unioned per join="outer". + res = concat([a, b], join="outer", aligned_axis_key_join="inner") + assert sorted(res.obsm["df"].columns) == ["a_only", "b_only", "x"] + + +def test_aligned_axis_key_join_inner_content_with_missing_key(): + """When `join="inner"` and `aligned_axis_key_join="outer"` with 3+ + inputs and a key present in only some, content alignment intersects + among the present values rather than unioning them. 
+ """ + a = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [1, 2], "a_only": [10, 20]}, index=["s1", "s2"]) + }, + ) + b = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={ + "df": pd.DataFrame({"x": [3, 4], "b_only": [30, 40]}, index=["s3", "s4"]) + }, + ) + c = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s5", "s6"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + + # DataFrame obsm: inner intersection among present DataFrames is just "x" + res = concat([a, b, c], aligned_axis_key_join="outer") + assert list(res.obsm["df"].columns) == ["x"] + + # ndarray obsm with mismatched widths: inner = min width + a_arr = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s1", "s2"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"arr": np.ones((2, 3))}, + ) + b_arr = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s3", "s4"]), + var=pd.DataFrame(index=["v1", "v2"]), + obsm={"arr": np.ones((2, 5))}, + ) + c_empty = AnnData( + X=sparse.csr_matrix(np.eye(2, dtype=np.float64)), + obs=pd.DataFrame(index=["s5", "s6"]), + var=pd.DataFrame(index=["v1", "v2"]), + ) + res_arr = concat([a_arr, b_arr, c_empty], aligned_axis_key_join="outer") + assert res_arr.obsm["arr"].shape == (6, 3) + + +def test_aligned_axis_key_join_reproduces_issue_2374_example(): + """Direct reproduction of the example from issue #2374. Verifies that + `aligned_axis_key_join` controls on-axis annotation columns (obs/var) + independently from `join`, which keeps controlling the off-axis index + (var_names when axis=0, obs_names when axis=1). 
+ """ + adatas = [ + AnnData( + X=np.ones((1, 2)), + obs=pd.DataFrame({"1": ["a"], "2": ["b"]}), + var=pd.DataFrame(index=["1", "2"]), + ), + AnnData( + X=np.ones((1, 2)), + obs=pd.DataFrame({"1": ["a"], "3": ["b"]}), + var=pd.DataFrame(index=["1", "3"]), + ), + ] + + # Existing behaviour (unchanged): join controls both axes. + r = concat(adatas, join="inner") + assert list(r.obs.columns) == ["1"] + assert list(r.var_names) == ["1"] + r = concat(adatas, join="outer") + assert sorted(r.obs.columns) == ["1", "2", "3"] + assert sorted(r.var_names) == ["1", "2", "3"] + + # New behaviour: outer on-axis columns, inner off-axis index. + r = concat(adatas, join="inner", aligned_axis_key_join="outer") + assert sorted(r.obs.columns) == ["1", "2", "3"] + assert list(r.var_names) == ["1"] + + # Converse: inner on-axis columns, outer off-axis index. + r = concat(adatas, join="outer", aligned_axis_key_join="inner") + assert list(r.obs.columns) == ["1"] + assert sorted(r.var_names) == ["1", "2", "3"]