Skip to content

Commit c47a97f

Browse files
committed
refactor(concat): make aligned mapping key joins explicit
1 parent dac79dd commit c47a97f

2 files changed

Lines changed: 12 additions & 46 deletions

File tree

src/anndata/_core/merge.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from functools import partial, reduce, singledispatch
1111
from itertools import repeat
1212
from operator import and_, or_, sub
13-
from typing import TYPE_CHECKING, Literal, cast
13+
from typing import TYPE_CHECKING, Literal, cast, get_args
1414

1515
import numpy as np
1616
import pandas as pd
1717
from natsort import natsorted
1818
from scipy import sparse
1919

2020
from anndata._core.file_backing import to_memory
21+
from anndata._types import Join_T
2122
from anndata._warnings import ExperimentalFeatureWarning
2223

2324
from ..compat import (
@@ -35,15 +36,15 @@
3536
from .index import _subset, make_slice
3637
from .xarray import Dataset2D
3738

39+
JOIN_OPTIONS = get_args(Join_T.__value__)
40+
3841
if TYPE_CHECKING:
3942
from collections.abc import Collection, Generator, Iterable, Sequence
4043
from typing import Any
4144

4245
from numpy.typing import NDArray
4346
from pandas.api.extensions import ExtensionDtype
4447

45-
from anndata._types import Join_T
46-
4748
from ..compat import XDataArray
4849
from ..types import SupportsArrayApi
4950

@@ -951,25 +952,17 @@ def concat_arrays( # noqa: PLR0911, PLR0912
951952
def inner_concat_aligned_mapping( # noqa: PLR0913
952953
mappings,
953954
*,
955+
keys,
954956
reindexers=None,
955957
index=None,
956958
axis=0,
957959
concat_axis=None,
958960
force_lazy: bool = False,
959-
keys=None,
960961
fill_value=None,
961962
):
962-
"""Inner-content concat of aligned mappings.
963-
964-
By default iterates ``intersect_keys(mappings)``. Pass ``keys`` to override
965-
the iterated key set (e.g. ``union_keys(mappings)`` for an outer key join
966-
paired with inner content alignment); missing entries are then filled
967-
with ``fill_value``.
968-
"""
963+
"""Inner-content concat of aligned mappings over an explicit key set."""
969964
if concat_axis is None:
970965
concat_axis = axis
971-
if keys is None:
972-
keys = intersect_keys(mappings)
973966

974967
result = {}
975968
ns = [m.parent.shape[axis] for m in mappings]
@@ -1188,24 +1181,17 @@ def missing_element(
11881181
def outer_concat_aligned_mapping( # noqa: PLR0913
11891182
mappings,
11901183
*,
1184+
keys,
11911185
reindexers=None,
11921186
index=None,
11931187
axis=0,
11941188
concat_axis=None,
11951189
fill_value=None,
11961190
force_lazy: bool = False,
1197-
keys=None,
11981191
):
1199-
"""Outer-content concat of aligned mappings.
1200-
1201-
By default iterates ``union_keys(mappings)``. Pass ``keys`` to override
1202-
the iterated key set (e.g. ``intersect_keys(mappings)`` for an inner key
1203-
join paired with outer content alignment).
1204-
"""
1192+
"""Outer-content concat of aligned mappings over an explicit key set."""
12051193
if concat_axis is None:
12061194
concat_axis = axis
1207-
if keys is None:
1208-
keys = union_keys(mappings)
12091195

12101196
result = {}
12111197
ns = [m.parent.shape[axis] for m in mappings]
@@ -1778,10 +1764,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
17781764
merge = resolve_merge_strategy(merge)
17791765
uns_merge = resolve_merge_strategy(uns_merge)
17801766

1781-
if aligned_axis_key_join is not None and aligned_axis_key_join not in (
1782-
"inner",
1783-
"outer",
1784-
):
1767+
if aligned_axis_key_join is not None and aligned_axis_key_join not in JOIN_OPTIONS:
17851768
msg = (
17861769
f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, "
17871770
f"got {aligned_axis_key_join!r}"
@@ -1906,14 +1889,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19061889
xr.merge(annotations_with_only_dask, join=join, compat="override")
19071890
)
19081891

1909-
# `join` controls the *off-axis* content alignment shared by layers
1910-
# and obsm-style values within each shared key. `aligned_join` controls
1911-
# which keys appear on the on-axis side (obs/var columns, obsm/obsp
1912-
# keys, layers keys). The two settings are independent; when
1913-
# `aligned_axis_key_join=None`, `aligned_join == join` and behaviour
1914-
# reduces to the historical single-knob path. Post #1707, X is a
1915-
# layer entry rather than a separate field, so the layers mapping
1916-
# implicitly carries X through the same `aligned_join` path.
19171892
if join == "inner":
19181893
concat_aligned_mapping = inner_concat_aligned_mapping
19191894
elif join == "outer":
@@ -1924,9 +1899,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19241899
msg = f"{join=} should have been validated above by pd.concat"
19251900
raise AssertionError(msg)
19261901

1927-
# Pairwise key joining (obsp/varp) is purely about which keys appear in
1928-
# the result; the per-key block-diagonal does not have a content-alignment
1929-
# axis, so a single setting is sufficient here.
19301902
if aligned_join == "inner":
19311903
aligned_join_keys = intersect_keys
19321904
elif aligned_join == "outer":
@@ -1935,9 +1907,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19351907
msg = f"{aligned_join=} should have been validated"
19361908
raise AssertionError(msg)
19371909

1938-
# Layers: `aligned_join` selects which layer keys appear; `reindexers`
1939-
# carries the off-axis alignment from X so each kept layer is aligned to
1940-
# the same alt-axis as X.
19411910
layer_mappings = [a.layers for a in adatas]
19421911
layers = concat_aligned_mapping(
19431912
layer_mappings,
@@ -1946,10 +1915,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19461915
keys=aligned_join_keys(layer_mappings),
19471916
)
19481917

1949-
# obsm/varm: aligned_axis_key_join controls which keys appear; the content
1950-
# alignment within a shared key follows `join`. The pre-computed
1951-
# `aligned_join_keys` selects the on-axis key set; the inner/outer helper
1952-
# selected above performs the off-axis content alignment.
19531918
obsm_mappings = [getattr(a, f"{axis_name}m") for a in adatas]
19541919
concat_mapping = concat_aligned_mapping(
19551920
obsm_mappings,

src/anndata/experimental/multi_files/_anncollection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..._core.aligned_mapping import AxisArrays
1515
from ..._core.anndata import AnnData
1616
from ..._core.index import _normalize_index, _normalize_indices
17-
from ..._core.merge import concat_arrays, inner_concat_aligned_mapping
17+
from ..._core.merge import concat_arrays, inner_concat_aligned_mapping, intersect_keys
1818
from ..._core.sparse_dataset import BaseCompressedSparseDataset
1919
from ..._core.views import _resolve_idx
2020
from ...compat import old_positionals
@@ -768,8 +768,9 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915
768768
if join_obsm == "inner":
769769
view_attrs.remove("obsm")
770770
self._attrs.append("obsm")
771+
obsm_mappings = [a.obsm for a in adatas]
771772
self._obsm = inner_concat_aligned_mapping(
772-
[a.obsm for a in adatas], index=self.obs_names
773+
obsm_mappings, keys=intersect_keys(obsm_mappings), index=self.obs_names
773774
)
774775
self._obsm = (
775776
AxisArrays(self, axis=0, store={}) if self._obsm == {} else self._obsm

0 commit comments

Comments
 (0)