1010from functools import partial , reduce , singledispatch
1111from itertools import repeat
1212from operator import and_ , or_ , sub
13- from typing import TYPE_CHECKING , Literal , cast
13+ from typing import TYPE_CHECKING , Literal , cast , get_args
1414
1515import numpy as np
1616import pandas as pd
1717from natsort import natsorted
1818from scipy import sparse
1919
2020from anndata ._core .file_backing import to_memory
21+ from anndata ._types import Join_T
2122from anndata ._warnings import ExperimentalFeatureWarning
2223
2324from ..compat import (
3536from .index import _subset , make_slice
3637from .xarray import Dataset2D
3738
39+ JOIN_OPTIONS = get_args (Join_T .__value__ )
40+
3841if TYPE_CHECKING :
3942 from collections .abc import Collection , Generator , Iterable , Sequence
4043 from typing import Any
4144
4245 from numpy .typing import NDArray
4346 from pandas .api .extensions import ExtensionDtype
4447
45- from anndata ._types import Join_T
46-
4748 from ..compat import XDataArray
4849 from ..types import SupportsArrayApi
4950
@@ -951,25 +952,17 @@ def concat_arrays( # noqa: PLR0911, PLR0912
951952def inner_concat_aligned_mapping ( # noqa: PLR0913
952953 mappings ,
953954 * ,
955+ keys ,
954956 reindexers = None ,
955957 index = None ,
956958 axis = 0 ,
957959 concat_axis = None ,
958960 force_lazy : bool = False ,
959- keys = None ,
960961 fill_value = None ,
961962):
962- """Inner-content concat of aligned mappings.
963-
964- By default iterates ``intersect_keys(mappings)``. Pass ``keys`` to override
965- the iterated key set (e.g. ``union_keys(mappings)`` for an outer key join
966- paired with inner content alignment); missing entries are then filled
967- with ``fill_value``.
968- """
963+ """Inner-content concat of aligned mappings over an explicit key set."""
969964 if concat_axis is None :
970965 concat_axis = axis
971- if keys is None :
972- keys = intersect_keys (mappings )
973966
974967 result = {}
975968 ns = [m .parent .shape [axis ] for m in mappings ]
@@ -1188,24 +1181,17 @@ def missing_element(
11881181def outer_concat_aligned_mapping ( # noqa: PLR0913
11891182 mappings ,
11901183 * ,
1184+ keys ,
11911185 reindexers = None ,
11921186 index = None ,
11931187 axis = 0 ,
11941188 concat_axis = None ,
11951189 fill_value = None ,
11961190 force_lazy : bool = False ,
1197- keys = None ,
11981191):
1199- """Outer-content concat of aligned mappings.
1200-
1201- By default iterates ``union_keys(mappings)``. Pass ``keys`` to override
1202- the iterated key set (e.g. ``intersect_keys(mappings)`` for an inner key
1203- join paired with outer content alignment).
1204- """
1192+ """Outer-content concat of aligned mappings over an explicit key set."""
12051193 if concat_axis is None :
12061194 concat_axis = axis
1207- if keys is None :
1208- keys = union_keys (mappings )
12091195
12101196 result = {}
12111197 ns = [m .parent .shape [axis ] for m in mappings ]
@@ -1778,10 +1764,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
17781764 merge = resolve_merge_strategy (merge )
17791765 uns_merge = resolve_merge_strategy (uns_merge )
17801766
1781- if aligned_axis_key_join is not None and aligned_axis_key_join not in (
1782- "inner" ,
1783- "outer" ,
1784- ):
1767+ if aligned_axis_key_join is not None and aligned_axis_key_join not in JOIN_OPTIONS :
17851768 msg = (
17861769 f"`aligned_axis_key_join` must be one of 'inner', 'outer', or None, "
17871770 f"got { aligned_axis_key_join !r} "
@@ -1906,14 +1889,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19061889 xr .merge (annotations_with_only_dask , join = join , compat = "override" )
19071890 )
19081891
1909- # `join` controls the *off-axis* content alignment shared by layers
1910- # and obsm-style values within each shared key. `aligned_join` controls
1911- # which keys appear on the on-axis side (obs/var columns, obsm/obsp
1912- # keys, layers keys). The two settings are independent; when
1913- # `aligned_axis_key_join=None`, `aligned_join == join` and behaviour
1914- # reduces to the historical single-knob path. Post #1707, X is a
1915- # layer entry rather than a separate field, so the layers mapping
1916- # implicitly carries X through the same `aligned_join` path.
19171892 if join == "inner" :
19181893 concat_aligned_mapping = inner_concat_aligned_mapping
19191894 elif join == "outer" :
@@ -1924,9 +1899,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19241899 msg = f"{ join = } should have been validated above by pd.concat"
19251900 raise AssertionError (msg )
19261901
1927- # Pairwise key joining (obsp/varp) is purely about which keys appear in
1928- # the result; the per-key block-diagonal does not have a content-alignment
1929- # axis, so a single setting is sufficient here.
19301902 if aligned_join == "inner" :
19311903 aligned_join_keys = intersect_keys
19321904 elif aligned_join == "outer" :
@@ -1935,9 +1907,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19351907 msg = f"{ aligned_join = } should have been validated"
19361908 raise AssertionError (msg )
19371909
1938- # Layers: `aligned_join` selects which layer keys appear; `reindexers`
1939- # carries the off-axis alignment from X so each kept layer is aligned to
1940- # the same alt-axis as X.
19411910 layer_mappings = [a .layers for a in adatas ]
19421911 layers = concat_aligned_mapping (
19431912 layer_mappings ,
@@ -1946,10 +1915,6 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
19461915 keys = aligned_join_keys (layer_mappings ),
19471916 )
19481917
1949- # obsm/varm: aligned_axis_key_join controls which keys appear; the content
1950- # alignment within a shared key follows `join`. The pre-computed
1951- # `aligned_join_keys` selects the on-axis key set; the inner/outer helper
1952- # selected above performs the off-axis content alignment.
19531918 obsm_mappings = [getattr (a , f"{ axis_name } m" ) for a in adatas ]
19541919 concat_mapping = concat_aligned_mapping (
19551920 obsm_mappings ,
0 commit comments