Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/2416.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the `aligned_axis_key_join` argument to {func}`anndata.concat` to control how keys on the concatenation axis are joined (`obs`/`var` columns, `obsm`/`varm` keys, `obsp`/`varp` keys, and `layers` keys), independently of the off-axis index alignment controlled by `join` {user}`Ekin-Kahraman`
195 changes: 174 additions & 21 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from functools import partial, reduce, singledispatch
from itertools import repeat
from operator import and_, or_, sub
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal, cast, get_args

import numpy as np
import pandas as pd
from natsort import natsorted
from scipy import sparse

from anndata._core.file_backing import to_memory
from anndata._types import Join_T
from anndata._warnings import ExperimentalFeatureWarning

from ..compat import (
Expand All @@ -35,15 +36,15 @@
from .index import _subset, make_slice
from .xarray import Dataset2D

JOIN_OPTIONS = get_args(Join_T.__value__)

if TYPE_CHECKING:
from collections.abc import Collection, Generator, Iterable, Sequence
from typing import Any

from numpy.typing import NDArray
from pandas.api.extensions import ExtensionDtype

from anndata._types import Join_T

from ..compat import XDataArray
from ..types import SupportsArrayApi

Expand Down Expand Up @@ -948,30 +949,126 @@ def concat_arrays( # noqa: PLR0911, PLR0912
)


def inner_concat_aligned_mapping(
def inner_concat_aligned_mapping( # noqa: PLR0913
mappings,
*,
keys,
reindexers=None,
index=None,
axis=0,
concat_axis=None,
force_lazy: bool = False,
fill_value=None,
):
"""Inner-content concat of aligned mappings over an explicit key set."""
if concat_axis is None:
concat_axis = axis

result = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in intersect_keys(mappings):
els = [m[k] for m in mappings]
if reindexers is None:
cur_reindexers = gen_inner_reindexers(
els, new_index=index, axis=concat_axis
for k in keys:
els = [m.get(k, MissingVal) for m in mappings]
any_missing = any(is_missing(el) for el in els)

if not any_missing:
# This key is present in every mapping, so this branch matches the
# default ``intersect_keys`` path. No fill_value handling needed.
if reindexers is None:
cur_reindexers = gen_inner_reindexers(
els, new_index=index, axis=concat_axis
)
else:
cur_reindexers = reindexers
result[k] = concat_arrays(
els,
cur_reindexers,
index=index,
axis=concat_axis,
force_lazy=force_lazy,
)
continue

# Missing-key path: only reachable when caller passes an explicit
# ``keys`` set wider than ``intersect_keys`` (the outer-key + inner-
# content combination from ``concat(aligned_axis_key_join=...)``).
present_els = [el for el in els if not_missing(el)]

if any(isinstance(el, AwkArray) for el in present_els):
msg = (
"Combining `aligned_axis_key_join` with `join='inner'` is "
"not yet implemented for awkward arrays in `obsm`/`varm` "
"when the key is missing from at least one input. Use the "
"same value for `join` and `aligned_axis_key_join`, or "
"drop the affected awkward entries before concatenating."
)
raise NotImplementedError(msg)

if reindexers is not None:
# Caller-provided reindexers already encode the alt-axis alignment
# across *all* mappings (e.g. the gene-axis intersection for
# ``layers`` when ``join="inner"``). The present-only reindexers
# below would intersect over the present subset only, which is
# wrong for layers because their alt-axis must match X's. Honour
# the caller's reindexers and drop in an identity reindexer for
# missing entries; the filler created by ``missing_element``
# below uses the matching alt-axis size.
target_idx = reindexers[0].new_idx
cur_reindexers = [
reindexers[i] if not_missing(el) else Reindexer(target_idx, target_idx)
for i, el in enumerate(els)
]
off_axis_size = len(target_idx)
elif all(isinstance(el, pd.DataFrame) for el in present_els):
common_cols = reduce(
lambda x, y: x.intersection(y),
(el.columns for el in present_els),
)
cur_reindexers = [
Reindexer(el.columns, common_cols)
if not_missing(el)
else (
lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame(
np.nan if fv is None else fv,
index=range(n),
columns=cols,
)
)
for el, n in zip(els, ns, strict=True)
]
off_axis_size = 0
else:
cur_reindexers = reindexers
inner_present = gen_inner_reindexers(
present_els, new_index=index, axis=concat_axis
)
target_idx = inner_present[0].new_idx
present_iter = iter(inner_present)
cur_reindexers = [
next(present_iter)
if not_missing(el)
else Reindexer(target_idx, target_idx)
for el in els
]
off_axis_size = len(target_idx)

result[k] = concat_arrays(
els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy
[
el
if not_missing(el)
else missing_element(
n,
axis=concat_axis,
els=els,
fill_value=fill_value,
off_axis_size=off_axis_size,
)
for el, n in zip(els, ns, strict=True)
],
cur_reindexers,
axis=concat_axis,
index=index,
fill_value=fill_value,
force_lazy=force_lazy,
)
return result

Expand Down Expand Up @@ -1081,22 +1178,25 @@ def missing_element(
return xp.zeros(shape, dtype=bool)


def outer_concat_aligned_mapping(
def outer_concat_aligned_mapping( # noqa: PLR0913
mappings,
*,
keys,
reindexers=None,
index=None,
axis=0,
concat_axis=None,
fill_value=None,
force_lazy: bool = False,
):
"""Outer-content concat of aligned mappings over an explicit key set."""
if concat_axis is None:
concat_axis = axis

result = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in union_keys(mappings):
for k in keys:
els = [m.get(k, MissingVal) for m in mappings]
if reindexers is None:
cur_reindexers = gen_outer_reindexers(
Expand Down Expand Up @@ -1423,6 +1523,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
*,
axis: Literal["obs", 0, "var", 1] = "obs",
join: Join_T | Default = Default("inner"), # noqa: B008
aligned_axis_key_join: Join_T | None = None,
merge: StrategiesLiteral | Callable | None = None,
uns_merge: StrategiesLiteral | Callable | None = None,
label: str | None = None,
Expand All @@ -1447,6 +1548,15 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
How to align values when concatenating. If "outer", the union of the other axis
is taken. If "inner", the intersection. See :doc:`concatenation <../concatenation>`
for more.
aligned_axis_key_join
How to join keys on the *concatenation axis* itself: columns of `obs` and
keys of `obsm`/`obsp` (or columns of `var` and keys of `varm`/`varp` when
concatenating along `axis="var"`), and keys of `layers`. Use "outer" to take
the union of these keys, "inner" to take the intersection. The off-axis
content of each value (e.g. the columns of an `obsm` DataFrame, or the gene
axis of a layer) still follows `join`. Defaults to `None`, in which case
`join` is used for both the off-axis index alignment and the on-axis key
join (the historical behaviour).
merge
How elements not aligned to the axis being concatenated along are selected.
Currently implemented strategies include:
Expand Down Expand Up @@ -1633,6 +1743,18 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
{'c': {'c.c': 5}}
>>> dict(ad.concat([a, b, c], uns_merge="first").uns)
{'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}}

`aligned_axis_key_join` controls on-axis key joining (obs columns,
obsm/obsp keys, layers keys) independently of the off-axis index `join`.
The default of `None` falls back to `join`, preserving existing behaviour.
To keep the union of `var` indices but the intersection of obs columns:

>>> ad.concat(
... [a, b], join="outer", aligned_axis_key_join="inner"
... ).obs.columns.tolist()
['group']
>>> ad.concat([a, b], join="outer").obs.columns.tolist()
['group', 'measure']
"""

from anndata._core.xarray import Dataset2D
Expand All @@ -1642,6 +1764,13 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
merge = resolve_merge_strategy(merge)
uns_merge = resolve_merge_strategy(uns_merge)

if aligned_axis_key_join is not None and aligned_axis_key_join not in JOIN_OPTIONS:
msg = (
f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, "
f"got {aligned_axis_key_join!r}"
)
raise ValueError(msg)

if isinstance(adatas, Mapping):
if keys is not None:
msg = (
Expand All @@ -1665,6 +1794,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
)
warn(msg, UserWarning)

# Resolve the on-axis key join. When aligned_axis_key_join is None, the
# historical behaviour applies and `join` controls both off-axis index
# alignment and on-axis key joining. When set, it overrides the on-axis
# key joining for obs/obsm/obsp (or var/varm/varp when axis="var") and
# for layers keys.
# Resolved after the `Default` unwrap above so `aligned_join` is always
# a plain string.
aligned_join = aligned_axis_key_join if aligned_axis_key_join is not None else join

if keys is None:
keys = np.arange(len(adatas)).astype(str)

Expand Down Expand Up @@ -1695,7 +1832,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
]

# Annotation for concatenation axis
check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join)
check_combinable_cols(
[getattr(a, axis_name).columns for a in adatas], join=aligned_join
)
annotations = [getattr(a, axis_name) for a in adatas]
are_any_annotations_dataframes = any(
isinstance(a, pd.DataFrame) for a in annotations
Expand All @@ -1706,14 +1845,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
)
concat_annot = pd.concat(
unify_dtypes(annotations_in_memory),
join=join,
join=aligned_join,
ignore_index=True,
)
concat_annot.index = concat_indices
else:
concat_annot = concat_dataset2d_on_annot_axis(
annotations,
join,
aligned_join,
force_lazy=force_lazy,
concat_indices=concat_indices,
)
Expand Down Expand Up @@ -1752,31 +1891,44 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915

if join == "inner":
concat_aligned_mapping = inner_concat_aligned_mapping
join_keys = intersect_keys
elif join == "outer":
concat_aligned_mapping = partial(
outer_concat_aligned_mapping, fill_value=fill_value
)
join_keys = union_keys
else:
msg = f"{join=} should have been validated above by pd.concat"
raise AssertionError(msg)

if aligned_join == "inner":
aligned_join_keys = intersect_keys
elif aligned_join == "outer":
aligned_join_keys = union_keys
else:
msg = f"{aligned_join=} should have been validated"
raise AssertionError(msg)

layer_mappings = [a.layers for a in adatas]
layers = concat_aligned_mapping(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
layer_mappings,
axis=axis,
Comment on lines +1910 to +1913
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
layer_mappings = [a.layers for a in adatas]
layers = concat_aligned_mapping(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
layer_mappings,
axis=axis,
layers = concat_aligned_mapping(
[a.layers for a in adatas],
axis=axis,

reindexers=reindexers,
keys=aligned_join_keys(layer_mappings),
)

obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas]
concat_mapping = concat_aligned_mapping(
[getattr(a, f"{axis_name}m") for a in adatas],
obsm_mappings,
axis=axis,
concat_axis=0,
index=concat_indices,
force_lazy=force_lazy,
keys=aligned_join_keys(obsm_mappings),
)
if pairwise:
concat_pairwise = concat_pairwise_mapping(
mappings=[getattr(a, f"{axis_name}p") for a in adatas],
shapes=[a.shape[axis] for a in adatas],
join_keys=join_keys,
join_keys=aligned_join_keys,
)
else:
concat_pairwise = {}
Expand Down Expand Up @@ -1808,6 +1960,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
for a in adatas
],
join=join,
aligned_axis_key_join=aligned_axis_key_join,
label=label,
keys=keys,
index_unique=index_unique,
Expand Down
5 changes: 3 additions & 2 deletions src/anndata/experimental/multi_files/_anncollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..._core.aligned_mapping import AxisArrays
from ..._core.anndata import AnnData
from ..._core.index import _normalize_index, _normalize_indices
from ..._core.merge import concat_arrays, inner_concat_aligned_mapping
from ..._core.merge import concat_arrays, inner_concat_aligned_mapping, intersect_keys
from ..._core.sparse_dataset import BaseCompressedSparseDataset
from ..._core.views import _resolve_idx
from ...compat import old_positionals
Expand Down Expand Up @@ -768,8 +768,9 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915
if join_obsm == "inner":
view_attrs.remove("obsm")
self._attrs.append("obsm")
obsm_mappings = [a.obsm for a in adatas]
self._obsm = inner_concat_aligned_mapping(
[a.obsm for a in adatas], index=self.obs_names
obsm_mappings, keys=intersect_keys(obsm_mappings), index=self.obs_names
)
self._obsm = (
AxisArrays(self, axis=0, store={}) if self._obsm == {} else self._obsm
Expand Down
Loading
Loading