Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/2416.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `aligned_axis_key_join` to {func}`anndata.concat` for controlling on-axis annotation and aligned-mapping key joins (obs/var columns, obsm/varm and obsp/varp keys, layers keys) independently of the off-axis index alignment controlled by `join` {user}`Ekin-Kahraman`
224 changes: 206 additions & 18 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,30 +948,134 @@ def concat_arrays( # noqa: PLR0911, PLR0912
)


def inner_concat_aligned_mapping(
def inner_concat_aligned_mapping( # noqa: PLR0913
mappings,
*,
reindexers=None,
index=None,
axis=0,
concat_axis=None,
force_lazy: bool = False,
keys=None,
fill_value=None,
):
"""Inner-content concat of aligned mappings.

By default iterates ``intersect_keys(mappings)``. Pass ``keys`` to override
the iterated key set (e.g. ``union_keys(mappings)`` for an outer key join
paired with inner content alignment); missing entries are then filled
with ``fill_value``.
"""
if concat_axis is None:
concat_axis = axis
if keys is None:
keys = intersect_keys(mappings)

result = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in intersect_keys(mappings):
els = [m[k] for m in mappings]
if reindexers is None:
cur_reindexers = gen_inner_reindexers(
els, new_index=index, axis=concat_axis
for k in keys:
els = [m.get(k, MissingVal) for m in mappings]
any_missing = any(is_missing(el) for el in els)

if not any_missing:
# All keys present in all mappings — the default
# ``intersect_keys`` path. No fill_value handling needed.
if reindexers is None:
cur_reindexers = gen_inner_reindexers(
els, new_index=index, axis=concat_axis
)
else:
cur_reindexers = reindexers
result[k] = concat_arrays(
els,
cur_reindexers,
index=index,
axis=concat_axis,
force_lazy=force_lazy,
)
continue

# Missing-key path: only reachable when caller passes an explicit
# ``keys`` set wider than ``intersect_keys`` (the outer-key + inner-
# content combination from ``concat(aligned_axis_key_join=...)``).
present_els = [el for el in els if not_missing(el)]

if any(isinstance(el, AwkArray) for el in present_els):
msg = (
"Combining `aligned_axis_key_join` with `join='inner'` is "
"not yet implemented for awkward arrays in `obsm`/`varm` "
"when the key is missing from at least one input. Use the "
"same value for `join` and `aligned_axis_key_join`, or "
"drop the affected awkward entries before concatenating."
)
raise NotImplementedError(msg)

if reindexers is not None:
# Caller-provided reindexers already encode the alt-axis alignment
# across *all* mappings (e.g. the gene-axis intersection for
# ``layers`` when ``join="inner"``). The present-only reindexers
# below would intersect over the present subset only, which is
# wrong for layers because their alt-axis must match X's. Honour
# the caller's reindexers and drop in an identity reindexer for
# missing entries; the filler created by ``missing_element``
# below uses the matching alt-axis size.
target_idx = reindexers[0].new_idx
cur_reindexers = [
reindexers[i] if not_missing(el) else Reindexer(target_idx, target_idx)
for i, el in enumerate(els)
]
off_axis_size = len(target_idx)
elif all(isinstance(el, pd.DataFrame) for el in present_els):
common_cols = reduce(
lambda x, y: x.intersection(y),
(el.columns for el in present_els),
)
cur_reindexers = [
Reindexer(el.columns, common_cols)
if not_missing(el)
else (
lambda _, n=n, cols=common_cols, fv=fill_value: pd.DataFrame(
np.nan if fv is None else fv,
index=range(n),
columns=cols,
)
)
for el, n in zip(els, ns, strict=True)
]
off_axis_size = 0
else:
cur_reindexers = reindexers
inner_present = gen_inner_reindexers(
present_els, new_index=index, axis=concat_axis
)
target_idx = inner_present[0].new_idx
present_iter = iter(inner_present)
cur_reindexers = [
next(present_iter)
if not_missing(el)
else Reindexer(target_idx, target_idx)
for el in els
]
off_axis_size = len(target_idx)

result[k] = concat_arrays(
els, cur_reindexers, index=index, axis=concat_axis, force_lazy=force_lazy
[
el
if not_missing(el)
else missing_element(
n,
axis=concat_axis,
els=els,
fill_value=fill_value,
off_axis_size=off_axis_size,
)
for el, n in zip(els, ns, strict=True)
],
cur_reindexers,
axis=concat_axis,
index=index,
fill_value=fill_value,
force_lazy=force_lazy,
)
return result

Expand Down Expand Up @@ -1081,7 +1185,7 @@ def missing_element(
return xp.zeros(shape, dtype=bool)


def outer_concat_aligned_mapping(
def outer_concat_aligned_mapping( # noqa: PLR0913
mappings,
*,
reindexers=None,
Expand All @@ -1090,13 +1194,23 @@ def outer_concat_aligned_mapping(
concat_axis=None,
fill_value=None,
force_lazy: bool = False,
keys=None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding another kwarg (which can have unintended side effects if the caller doesn't know the behavior), let's make this argument required by the caller so things are explicit and all the "default" logic is lifted up into one place

):
"""Outer-content concat of aligned mappings.

By default iterates ``union_keys(mappings)``. Pass ``keys`` to override
the iterated key set (e.g. ``intersect_keys(mappings)`` for an inner key
join paired with outer content alignment).
"""
if concat_axis is None:
concat_axis = axis
if keys is None:
keys = union_keys(mappings)

result = {}
ns = [m.parent.shape[axis] for m in mappings]

for k in union_keys(mappings):
for k in keys:
els = [m.get(k, MissingVal) for m in mappings]
if reindexers is None:
cur_reindexers = gen_outer_reindexers(
Expand Down Expand Up @@ -1423,6 +1537,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
*,
axis: Literal["obs", 0, "var", 1] = "obs",
join: Join_T | Default = Default("inner"), # noqa: B008
aligned_axis_key_join: Join_T | None = None,
merge: StrategiesLiteral | Callable | None = None,
uns_merge: StrategiesLiteral | Callable | None = None,
label: str | None = None,
Expand All @@ -1447,6 +1562,15 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
How to align values when concatenating. If "outer", the union of the other axis
is taken. If "inner", the intersection. See :doc:`concatenation <../concatenation>`
for more.
aligned_axis_key_join
How to join keys on the *concatenation axis* itself: columns of `obs`/`var`,
keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`),
and keys of `layers`. Use "outer" to take the union of these keys, "inner"
to take the intersection. The off-axis content of each value (e.g. the var
index of an obsm DataFrame, or the gene axis of a layer) still follows
`join`. Defaults to `None`, in which case `join` is used for both the
off-axis index alignment and the on-axis key join (the historical
behaviour).
merge
How elements not aligned to the axis being concatenated along are selected.
Currently implemented strategies include:
Expand Down Expand Up @@ -1633,6 +1757,18 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
{'c': {'c.c': 5}}
>>> dict(ad.concat([a, b, c], uns_merge="first").uns)
{'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}}

`aligned_axis_key_join` controls on-axis key joining (obs columns,
obsm/obsp keys) independently of the off-axis index `join`. The default
of `None` falls back to `join`, preserving existing behaviour. To keep
the union of `var` indices but the intersection of obs columns:

>>> ad.concat(
... [a, b], join="outer", aligned_axis_key_join="inner"
... ).obs.columns.tolist()
['group']
>>> ad.concat([a, b], join="outer").obs.columns.tolist()
['group', 'measure']
"""

from anndata._core.xarray import Dataset2D
Expand All @@ -1642,6 +1778,16 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
merge = resolve_merge_strategy(merge)
uns_merge = resolve_merge_strategy(uns_merge)

if aligned_axis_key_join is not None and aligned_axis_key_join not in (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a JoinT: Literal you can use to get the values using typing.get_args

"inner",
"outer",
):
msg = (
f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, "
f"got {aligned_axis_key_join!r}"
)
raise ValueError(msg)

if isinstance(adatas, Mapping):
if keys is not None:
msg = (
Expand All @@ -1665,6 +1811,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
)
warn(msg, UserWarning)

# Resolve the on-axis key join. When aligned_axis_key_join is None, the
# historical behaviour applies and `join` controls both off-axis index
# alignment and on-axis key joining. When set, it overrides the on-axis
# key joining for obs/obsm/obsp (or var/varm/varp when axis="var").
# Resolved after the `Default` unwrap above so `aligned_join` is always
# a plain string.
aligned_join = aligned_axis_key_join if aligned_axis_key_join is not None else join

if keys is None:
keys = np.arange(len(adatas)).astype(str)

Expand Down Expand Up @@ -1695,7 +1849,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
]

# Annotation for concatenation axis
check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join)
check_combinable_cols(
[getattr(a, axis_name).columns for a in adatas], join=aligned_join
)
annotations = [getattr(a, axis_name) for a in adatas]
are_any_annotations_dataframes = any(
isinstance(a, pd.DataFrame) for a in annotations
Expand All @@ -1706,14 +1862,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
)
concat_annot = pd.concat(
unify_dtypes(annotations_in_memory),
join=join,
join=aligned_join,
ignore_index=True,
)
concat_annot.index = concat_indices
else:
concat_annot = concat_dataset2d_on_annot_axis(
annotations,
join,
aligned_join,
force_lazy=force_lazy,
concat_indices=concat_indices,
)
Expand Down Expand Up @@ -1750,33 +1906,64 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
xr.merge(annotations_with_only_dask, join=join, compat="override")
)

# `join` controls the *off-axis* content alignment shared by layers
# and obsm-style values within each shared key. `aligned_join` controls
# which keys appear on the on-axis side (obs/var columns, obsm/obsp
# keys, layers keys). The two settings are independent; when
# `aligned_axis_key_join=None`, `aligned_join == join` and behaviour
# reduces to the historical single-knob path. Post #1707, X is a
# layer entry rather than a separate field, so the layers mapping
# implicitly carries X through the same `aligned_join` path.
if join == "inner":
concat_aligned_mapping = inner_concat_aligned_mapping
join_keys = intersect_keys
elif join == "outer":
concat_aligned_mapping = partial(
outer_concat_aligned_mapping, fill_value=fill_value
)
join_keys = union_keys
else:
msg = f"{join=} should have been validated above by pd.concat"
raise AssertionError(msg)

# Pairwise key joining (obsp/varp) is purely about which keys appear in
# the result; the per-key block-diagonal does not have a content-alignment
# axis, so a single setting is sufficient here.
if aligned_join == "inner":
aligned_join_keys = intersect_keys
elif aligned_join == "outer":
aligned_join_keys = union_keys
else:
msg = f"{aligned_join=} should have been validated"
raise AssertionError(msg)

# Layers: `aligned_join` selects which layer keys appear; `reindexers`
# carries the off-axis alignment from X so each kept layer is aligned to
# the same alt-axis as X.
layer_mappings = [a.layers for a in adatas]
layers = concat_aligned_mapping(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
layer_mappings,
axis=axis,
Comment on lines +1910 to +1913
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
layer_mappings = [a.layers for a in adatas]
layers = concat_aligned_mapping(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
layer_mappings,
axis=axis,
layers = concat_aligned_mapping(
[a.layers for a in adatas],
axis=axis,

reindexers=reindexers,
keys=aligned_join_keys(layer_mappings),
)

# obsm/varm: aligned_axis_key_join controls which keys appear; the content
# alignment within a shared key follows `join`. The pre-computed
# `aligned_join_keys` selects the on-axis key set; the inner/outer helper
# selected above performs the off-axis content alignment.
obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas]
concat_mapping = concat_aligned_mapping(
[getattr(a, f"{axis_name}m") for a in adatas],
obsm_mappings,
axis=axis,
concat_axis=0,
index=concat_indices,
force_lazy=force_lazy,
keys=aligned_join_keys(obsm_mappings),
)
if pairwise:
concat_pairwise = concat_pairwise_mapping(
mappings=[getattr(a, f"{axis_name}p") for a in adatas],
shapes=[a.shape[axis] for a in adatas],
join_keys=join_keys,
join_keys=aligned_join_keys,
)
else:
concat_pairwise = {}
Expand Down Expand Up @@ -1808,6 +1995,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
for a in adatas
],
join=join,
aligned_axis_key_join=aligned_axis_key_join,
label=label,
keys=keys,
index_unique=index_unique,
Expand Down
Loading
Loading