Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 187 additions & 13 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,115 @@ def outer_concat_aligned_mapping(
return result


def _concat_aligned_mapping_split_join(  # noqa: PLR0913
    mappings,
    *,
    key_join: Join_T,
    content_join: Join_T,
    fill_value=None,
    axis=0,
    concat_axis=None,
    index=None,
    force_lazy: bool = False,
):
    """Concatenate aligned mappings using independent key and content joins.

    Intended for ``obsm``/``varm``-style mappings when
    ``concat(aligned_axis_key_join=...)`` asks for an on-axis key join that
    differs from the off-axis ``join``.

    Parameters
    ----------
    mappings
        The mappings to combine. Each must expose ``.get`` and a
        ``.parent.shape`` that can be indexed with ``axis``.
    key_join
        Which keys end up in the result: ``"outer"`` takes the union of the
        keys, any other value takes their intersection.
    content_join
        How values stored under a shared key are aligned along the off-axis
        dimension. ``"inner"`` intersects (e.g. DataFrame columns); any other
        value goes through the outer reindexers.
    fill_value
        Value used for entries that have to be filled in. It is only handed
        to ``concat_arrays`` when a key is absent from at least one mapping.
    axis
        Axis of ``parent.shape`` giving the length of each input.
    concat_axis
        Axis the per-key values are concatenated along; defaults to ``axis``.
    index
        Index of the concatenated result, passed through to the reindexer
        generators and to ``concat_arrays``.
    force_lazy
        Passed through unchanged to ``concat_arrays``.

    Returns
    -------
    dict
        Maps every selected key to the output of ``concat_arrays``.

    Raises
    ------
    NotImplementedError
        With ``content_join="inner"``, when a key is missing from some input
        while one of the present values is an awkward array.
    ValueError
        In the outer-content path, when a dask array is present but the first
        reindexer is not a ``Reindexer``.
    """
    concat_axis = axis if concat_axis is None else concat_axis
    pick_keys = union_keys if key_join == "outer" else intersect_keys
    selected_keys = pick_keys(mappings)
    sizes = [m.parent.shape[axis] for m in mappings]

    combined = {}
    for key in selected_keys:
        entries = [m.get(key, MissingVal) for m in mappings]
        has_gap = any(is_missing(entry) for entry in entries)
        present = [entry for entry in entries if not_missing(entry)]

        if content_join != "inner":
            # Outer content alignment: delegate reindexer construction.
            reindexers = gen_outer_reindexers(
                entries, sizes, new_index=index, axis=concat_axis
            )
            filler_width = 0
            if any(isinstance(entry, DaskArray) for entry in present):
                # With a dask array involved, the filler is sized to the
                # first reindexer's index length rather than left at 0.
                if not isinstance(reindexers[0], Reindexer):  # pragma: no cover
                    msg = "Cannot re-index a dask array without a Reindexer"
                    raise ValueError(msg)
                filler_width = reindexers[0].idx.shape[0]
        else:
            # Inner content alignment: intersect the off-axis dimension over
            # the values that exist, and reindex everything onto that
            # intersection. Gaps receive a shape-compatible filler so the
            # downstream concat can stack them.
            if has_gap and any(isinstance(entry, AwkArray) for entry in present):
                msg = (
                    "Combining `aligned_axis_key_join` with `join='inner'` is "
                    "not yet implemented for awkward arrays in `obsm`/`varm` "
                    "when the key is missing from at least one input. Use the "
                    "same value for `join` and `aligned_axis_key_join`, or "
                    "drop the affected awkward entries before concatenating."
                )
                raise NotImplementedError(msg)

            if all(isinstance(entry, pd.DataFrame) for entry in present):
                shared_cols = reduce(
                    lambda acc, cols: acc.intersection(cols),
                    (entry.columns for entry in present),
                )
                reindexers = []
                for entry, size in zip(entries, sizes, strict=True):
                    if not_missing(entry):
                        reindexers.append(Reindexer(entry.columns, shared_cols))
                    else:
                        # Bind the per-input values as defaults so each
                        # callable keeps its own row count and columns.
                        reindexers.append(
                            lambda _, n=size, cols=shared_cols, fv=fill_value: pd.DataFrame(
                                np.nan if fv is None else fv,
                                index=range(n),
                                columns=cols,
                            )
                        )
                # An empty filler is enough to pass concat_arrays' DataFrame
                # check; the callables above swap it for a real DataFrame.
                filler_width = 0
            else:
                present_reindexers = gen_inner_reindexers(
                    present, new_index=index, axis=concat_axis
                )
                # NOTE(review): relies on the first returned reindexer
                # exposing `.new_idx` — confirm this holds for every element
                # type gen_inner_reindexers supports (e.g. awkward arrays).
                target_idx = present_reindexers[0].new_idx
                remaining = iter(present_reindexers)
                reindexers = []
                for entry in entries:
                    if not_missing(entry):
                        # Present entries consume the generated reindexers
                        # in order.
                        reindexers.append(next(remaining))
                    else:
                        reindexers.append(Reindexer(target_idx, target_idx))
                filler_width = len(target_idx)

        stacked_inputs = []
        for entry, size in zip(entries, sizes, strict=True):
            if not_missing(entry):
                stacked_inputs.append(entry)
            else:
                stacked_inputs.append(
                    missing_element(
                        size,
                        axis=concat_axis,
                        els=entries,
                        fill_value=fill_value,
                        off_axis_size=filler_width,
                    )
                )

        combined[key] = concat_arrays(
            stacked_inputs,
            reindexers,
            axis=concat_axis,
            index=index,
            fill_value=fill_value if has_gap else None,
            force_lazy=force_lazy,
        )
    return combined


def concat_pairwise_mapping(
mappings: Collection[Mapping], shapes: Collection[int], join_keys=intersect_keys
):
Expand Down Expand Up @@ -1447,6 +1556,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
*,
axis: Literal["obs", 0, "var", 1] = "obs",
join: Join_T = "inner",
aligned_axis_key_join: Join_T | None = None,
merge: StrategiesLiteral | Callable | None = None,
uns_merge: StrategiesLiteral | Callable | None = None,
label: str | None = None,
Expand All @@ -1471,6 +1581,12 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
How to align values when concatenating. If "outer", the union of the other axis
is taken. If "inner", the intersection. See :doc:`concatenation <../concatenation>`
for more.
aligned_axis_key_join
How to join keys on the *concatenation axis* itself: columns of `obs`/`var`,
and keys of `obsm`/`obsp` (or `varm`/`varp` when concatenating along `axis="var"`).
Use "outer" to take the union of these keys, "inner" to take the intersection.
Defaults to `None`, in which case `join` is used for both the off-axis index
alignment and the on-axis key join (the historical behaviour).
merge
How elements not aligned to the axis being concatenated along are selected.
Currently implemented strategies include:
Expand Down Expand Up @@ -1651,6 +1767,18 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
{'c': {'c.c': 5}}
>>> dict(ad.concat([a, b, c], uns_merge="first").uns)
{'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}}

`aligned_axis_key_join` controls on-axis key joining (obs columns,
obsm/obsp keys) independently of the off-axis index `join`. The default
of `None` falls back to `join`, preserving existing behaviour. To keep
the union of `var` indices but the intersection of obs columns:

>>> ad.concat(
... [a, b], join="outer", aligned_axis_key_join="inner"
... ).obs.columns.tolist()
['group']
>>> ad.concat([a, b], join="outer").obs.columns.tolist()
['group', 'measure']
"""

from anndata._core.xarray import Dataset2D
Expand All @@ -1660,6 +1788,21 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
merge = resolve_merge_strategy(merge)
uns_merge = resolve_merge_strategy(uns_merge)

# Resolve the on-axis key join. When aligned_axis_key_join is None, the
# historical behaviour applies and `join` controls both off-axis index
# alignment and on-axis key joining. When set, it overrides the on-axis
# key joining for obs/obsm/obsp (or var/varm/varp when axis="var").
if aligned_axis_key_join is not None and aligned_axis_key_join not in (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a `Join_T` `Literal` type alias; you can get its allowed values with `typing.get_args(Join_T)` instead of hard-coding them here.

"inner",
"outer",
):
msg = (
f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, "
f"got {aligned_axis_key_join!r}"
)
raise ValueError(msg)
aligned_join = aligned_axis_key_join if aligned_axis_key_join is not None else join

if isinstance(adatas, Mapping):
if keys is not None:
msg = (
Expand Down Expand Up @@ -1701,7 +1844,9 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
]

# Annotation for concatenation axis
check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join)
check_combinable_cols(
[getattr(a, axis_name).columns for a in adatas], join=aligned_join
)
annotations = [getattr(a, axis_name) for a in adatas]
are_any_annotations_dataframes = any(
isinstance(a, pd.DataFrame) for a in annotations
Expand All @@ -1712,14 +1857,14 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
)
concat_annot = pd.concat(
unify_dtypes(annotations_in_memory),
join=join,
join=aligned_join,
ignore_index=True,
)
concat_annot.index = concat_indices
else:
concat_annot = concat_dataset2d_on_annot_axis(
annotations,
join,
aligned_join,
force_lazy=force_lazy,
concat_indices=concat_indices,
)
Expand Down Expand Up @@ -1758,33 +1903,61 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915

X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value)

# Helper bindings for off-axis-shaped objects (layers): keep using `join`.
if join == "inner":
concat_aligned_mapping = inner_concat_aligned_mapping
join_keys = intersect_keys
elif join == "outer":
concat_aligned_mapping = partial(
outer_concat_aligned_mapping, fill_value=fill_value
)
join_keys = union_keys
else:
msg = f"{join=} should have been validated above by pd.concat"
raise AssertionError(msg)

# Pairwise key joining (obsp/varp) is purely about which keys appear in
# the result; the per-key block-diagonal does not have a content-alignment
# axis, so a single setting is sufficient here.
if aligned_join == "inner":
aligned_join_keys = intersect_keys
elif aligned_join == "outer":
aligned_join_keys = union_keys
else:
msg = f"{aligned_join=} should have been validated"
raise AssertionError(msg)

layers = concat_aligned_mapping(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
)
concat_mapping = concat_aligned_mapping(
[getattr(a, f"{axis_name}m") for a in adatas],
axis=axis,
concat_axis=0,
index=concat_indices,
force_lazy=force_lazy,
)

# obsm/varm: aligned_axis_key_join controls which keys appear, while the
# content alignment (e.g. DataFrame columns within a shared key) follows
# `join`. When the two settings agree we can use the existing helpers;
# when they diverge we use a split-join helper.
obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas]
if aligned_join == join:
concat_mapping = concat_aligned_mapping(
obsm_mappings,
axis=axis,
concat_axis=0,
index=concat_indices,
force_lazy=force_lazy,
)
else:
concat_mapping = _concat_aligned_mapping_split_join(
obsm_mappings,
key_join=aligned_join,
content_join=join,
fill_value=fill_value,
axis=axis,
concat_axis=0,
index=concat_indices,
force_lazy=force_lazy,
)
if pairwise:
concat_pairwise = concat_pairwise_mapping(
mappings=[getattr(a, f"{axis_name}p") for a in adatas],
shapes=[a.shape[axis] for a in adatas],
join_keys=join_keys,
join_keys=aligned_join_keys,
)
else:
concat_pairwise = {}
Expand Down Expand Up @@ -1816,6 +1989,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
for a in adatas
],
join=join,
aligned_axis_key_join=aligned_axis_key_join,
label=label,
keys=keys,
index_unique=index_unique,
Expand Down
Loading
Loading