34 commits
7dc2af2  fixes #2064 (sjfleming, Jul 30, 2025)
4c2d47d  undo accidental message changes (sjfleming, Jul 30, 2025)
d451c74  remove comment (sjfleming, Jul 30, 2025)
080050e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 30, 2025)
77956e9  remove redundant code (sjfleming, Jul 30, 2025)
dae91ab  test the new function explicitly (sjfleming, Jul 30, 2025)
c99ec17  additional test case (sjfleming, Jul 30, 2025)
989cca5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 30, 2025)
887641e  simplify code (sjfleming, Jul 30, 2025)
eefb639  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 30, 2025)
eb55618  ruff check (sjfleming, Jul 30, 2025)
07c4ab9  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (flying-sheep, Jul 31, 2025)
be9cec0  address PR comments (sjfleming, Jul 31, 2025)
4368f0a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2025)
58b990b  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (sjfleming, Aug 8, 2025)
5ee3658  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (sjfleming, Aug 11, 2025)
7c23963  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (sjfleming, Aug 13, 2025)
28b874c  fix subset type (flying-sheep, Aug 29, 2025)
b669f71  use correct type (flying-sheep, Aug 29, 2025)
9cc31e2  compact (flying-sheep, Aug 29, 2025)
78c6141  some more compacting (flying-sheep, Aug 29, 2025)
b2b68c3  early return (flying-sheep, Aug 29, 2025)
8c4ce4e  fix types (flying-sheep, Aug 29, 2025)
b4d70d7  fix type hint and provide comment (sjfleming, Sep 5, 2025)
8fde064  simplify (flying-sheep, Sep 8, 2025)
df7a9ea  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (flying-sheep, Sep 8, 2025)
ee30fa6  slice with most selective index first (sjfleming, Sep 19, 2025)
e3c16b7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 19, 2025)
31f5bc5  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (sjfleming, Sep 19, 2025)
08f2cef  move out (flying-sheep, Sep 22, 2025)
64af43e  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (flying-sheep, Sep 22, 2025)
e2b9bdc  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (flying-sheep, Sep 25, 2025)
bf8f2a2  implement suggestion to remove extra unique call (sjfleming, Oct 8, 2025)
bdab900  Merge branch 'main' into sf-backed-hdf5-fancy-indexing (sjfleming, Oct 8, 2025)
159 changes: 136 additions & 23 deletions src/anndata/_core/index.py
@@ -3,7 +3,7 @@
from collections.abc import Iterable, Sequence
from functools import singledispatch
from itertools import repeat
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast, overload

import h5py
import numpy as np
@@ -14,6 +14,8 @@
from .xarray import Dataset2D

if TYPE_CHECKING:
from numpy.typing import NDArray

from ..compat import Index, Index1D, Index1DNorm


@@ -161,7 +163,10 @@ def unpack_index(index: Index) -> tuple[Index1D, Index1D]:


@singledispatch
def _subset(a: np.ndarray | pd.DataFrame, subset_idx: Index):
def _subset(
a: np.ndarray | pd.DataFrame,
subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
):
# Select as combination of indexes, not coordinates
# Correcting for indexing behaviour of np.ndarray
if all(isinstance(x, Iterable) for x in subset_idx):
@@ -170,7 +175,9 @@ def _subset(a: np.ndarray | pd.DataFrame, subset_idx: Index):


@_subset.register(DaskArray)
def _subset_dask(a: DaskArray, subset_idx: Index):
def _subset_dask(
a: DaskArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
):
if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
if issparse(a._meta) and a._meta.format == "csc":
return a[:, subset_idx[1]][subset_idx[0], :]
@@ -180,48 +187,154 @@ def _subset_dask(a: DaskArray, subset_idx: Index):

@_subset.register(CSMatrix)
@_subset.register(CSArray)
def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index):
def _subset_sparse(
a: CSMatrix | CSArray,
subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
):
# Correcting for indexing behaviour of sparse.spmatrix
if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
first_idx = subset_idx[0]
if issubclass(first_idx.dtype.type, np.bool_):
first_idx = np.where(first_idx)[0]
first_idx = np.flatnonzero(first_idx)
subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:])
return a[subset_idx]
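
For context, a minimal sketch (not part of the diff) of the outer-indexing pattern `_subset_sparse` relies on: a boolean first index is converted to integer positions and reshaped to a column vector so that the row and column indices broadcast into a sub-matrix rather than being paired as coordinates. Shapes and values here are illustrative.

```python
import numpy as np
from scipy import sparse

# Illustrative data; shapes and values are arbitrary.
a = sparse.random(6, 5, density=0.4, format="csr", random_state=0)
row_mask = np.array([True, False, True, False, True, False])
col_idx = np.array([0, 2, 4])

# Boolean mask -> integer positions, reshaped so the two index arrays
# broadcast (outer indexing), selecting a (3, 3) sub-matrix.
rows = np.flatnonzero(row_mask).reshape(-1, 1)
sub = a[rows, col_idx]

# Same selection on the dense equivalent, for comparison.
expected = a.toarray()[np.ix_(np.flatnonzero(row_mask), col_idx)]
assert np.array_equal(sub.toarray(), expected)
```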


@_subset.register(pd.DataFrame)
@_subset.register(Dataset2D)
def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index):
def _subset_df(
df: pd.DataFrame | Dataset2D,
subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
):
return df.iloc[subset_idx]


@_subset.register(AwkArray)
def _subset_awkarray(a: AwkArray, subset_idx: Index):
def _subset_awkarray(
a: AwkArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
):
if all(isinstance(x, Iterable) for x in subset_idx):
subset_idx = np.ix_(*subset_idx)
return a[subset_idx]


# Registration for SparseDataset occurs in sparse_dataset.py
@_subset.register(h5py.Dataset)
def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
if not isinstance(subset_idx, tuple):
subset_idx = (subset_idx,)
ordered = list(subset_idx)
rev_order = [slice(None) for _ in range(len(subset_idx))]
for axis, axis_idx in enumerate(ordered.copy()):
if isinstance(axis_idx, np.ndarray):
if axis_idx.dtype == bool:
axis_idx = np.where(axis_idx)[0]
order = np.argsort(axis_idx)
ordered[axis] = axis_idx[order]
rev_order[axis] = np.argsort(order)
def _subset_dataset(
d: h5py.Dataset, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
):
order: tuple[NDArray[np.integer] | slice, ...]
inv_order: tuple[NDArray[np.integer] | slice, ...]
order, inv_order = zip(*map(_index_order_and_inverse, subset_idx), strict=True)
# check for duplicates or multi-dimensional fancy indexing
array_dims = [i for i in order if isinstance(i, np.ndarray)]
has_duplicates = any(len(np.unique(i)) != len(i) for i in array_dims)
# Use safe indexing if there are duplicates OR multiple array dimensions
# (h5py doesn't support multi-dimensional fancy indexing natively)
if has_duplicates or len(array_dims) > 1:
# For multi-dimensional indexing, bypass the sorting logic and use original indices
return _safe_fancy_index_h5py(d, subset_idx)
# from hdf5, then to real order
return d[tuple(ordered)][tuple(rev_order)]


def make_slice(idx, dimidx, n=2):
return d[order][inv_order]


@overload
def _index_order_and_inverse(
axis_idx: NDArray[np.integer] | NDArray[np.bool_],
) -> tuple[NDArray[np.integer], NDArray[np.integer]]: ...
@overload
def _index_order_and_inverse(axis_idx: slice) -> tuple[slice, slice]: ...
def _index_order_and_inverse(
axis_idx: Index1DNorm,
) -> tuple[Index1DNorm, NDArray[np.integer] | slice]:
"""Order and get inverse index array."""
if not isinstance(axis_idx, np.ndarray):
return axis_idx, slice(None)
if axis_idx.dtype == bool:
axis_idx = np.flatnonzero(axis_idx)
order = np.argsort(axis_idx)
return axis_idx[order], np.argsort(order)
Review comment (Contributor):
Suggested change:
-    return axis_idx[order], np.argsort(order)
+    return axis_idx[order], np.arange(len(order))

isn't order already sorted so argsort would just do an arange?

Reply (Contributor Author):
order is not already sorted here, it's just the index order that sorts axis_idx. I did try the above suggested change, but the result was that tests would no longer pass.
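
A small worked check of this exchange (illustrative only, not part of the diff): `order` sorts `axis_idx`, and `np.argsort(order)` is the inverse permutation that restores the caller's requested order after reading with sorted indices; replacing it with `np.arange` would leave the result in sorted order.

```python
import numpy as np

axis_idx = np.array([4, 0, 2])        # requested order, not sorted
order = np.argsort(axis_idx)          # [1, 2, 0]; axis_idx[order] == [0, 2, 4]
inv = np.argsort(order)               # [2, 0, 1], the inverse permutation

data = np.arange(10) * 10             # stand-in for an h5py dataset
sorted_read = data[axis_idx[order]]   # read with monotonically increasing indices
assert (sorted_read[inv] == data[axis_idx]).all()               # original order restored
assert not (sorted_read[np.arange(3)] == data[axis_idx]).all()  # arange keeps sorted order
```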



@overload
def _process_index_for_h5py(
idx: NDArray[np.integer] | NDArray[np.bool_],
) -> tuple[NDArray[np.integer], NDArray[np.integer]]: ...
@overload
def _process_index_for_h5py(idx: slice) -> tuple[slice, None]: ...
def _process_index_for_h5py(
idx: Index1DNorm,
) -> tuple[Index1DNorm, NDArray[np.integer] | None]:
"""Process a single index for h5py compatibility, handling sorting and duplicates."""
if not isinstance(idx, np.ndarray):
# Not an array (slice, integer, list) - no special processing needed
return idx, None

if idx.dtype == bool:
idx = np.flatnonzero(idx)

# For h5py fancy indexing, we need sorted indices
# But we also need to track how to reverse the sorting
unique, inverse = np.unique(idx, return_inverse=True)
return (
# Has duplicates - use unique + inverse mapping approach
(unique, inverse)
if len(unique) != len(idx)
# No duplicates - just sort and track reverse mapping
else _index_order_and_inverse(idx)
)
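
A rough illustration (assumed behaviour, not part of the diff) of the duplicate branch: h5py fancy indexing requires unique, increasing indices, so a request like `[3, 1, 3]` is read once via the unique values and then expanded back with the inverse mapping from `np.unique`.

```python
import numpy as np

idx = np.array([3, 1, 3])                               # contains a duplicate
unique, inverse = np.unique(idx, return_inverse=True)   # [1, 3], [1, 0, 1]

data = np.arange(6) * 100        # stand-in for a 1-D h5py dataset
read = data[unique]              # single read of the unique positions
result = read[inverse]           # duplicates and order restored: [300, 100, 300]
assert (result == data[idx]).all()
```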


def _safe_fancy_index_h5py(
dataset: h5py.Dataset,
subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
) -> h5py.Dataset:
# Handle multi-dimensional indexing of h5py dataset
# This avoids h5py's limitation with multi-dimensional fancy indexing
# without loading the entire dataset into memory

# Convert boolean arrays to integer arrays and handle sorting for h5py
processed_indices: tuple[NDArray[np.integer] | slice, ...]
reverse_indices: tuple[NDArray[np.integer] | None, ...]
processed_indices, reverse_indices = zip(
*map(_process_index_for_h5py, subset_idx), strict=True
)

# First find the index that reduces the size of the dataset the most
i_min = np.argmin([
_get_index_size(inds, dataset.shape[i]) / dataset.shape[i]
for i, inds in enumerate(processed_indices)
])

# Apply the most selective index first to h5py dataset
first_index = [slice(None)] * len(processed_indices)
first_index[i_min] = processed_indices[i_min]
in_memory_array = cast("np.ndarray", dataset[tuple(first_index)])

# Apply remaining indices to the numpy array
remaining_indices = list(processed_indices)
remaining_indices[i_min] = slice(None) # Already applied
result = in_memory_array[tuple(remaining_indices)]

# Now apply reverse mappings to get the original order
for dim, reverse_map in enumerate(reverse_indices):
if reverse_map is not None:
result = result.take(reverse_map, axis=dim)

return result


def _get_index_size(idx: Index1DNorm, dim_size: int) -> int:
"""Get size for any index type."""
if isinstance(idx, slice):
return len(range(*idx.indices(dim_size)))
elif isinstance(idx, int):
return 1
else: # For other types, try to get length
return len(idx)
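
To make the strategy concrete, here is a sketch against a small in-memory HDF5 file (file name, shapes, and indices are illustrative, not taken from the diff): the axis whose index keeps the smallest fraction of its dimension is applied in the h5py read, and the remaining fancy index is applied to the resulting numpy array, avoiding both a full read and h5py's multi-axis fancy-indexing limitation.

```python
import h5py
import numpy as np

with h5py.File("demo.h5", "w", driver="core", backing_store=False) as f:
    d = f.create_dataset("x", data=np.arange(1000 * 50).reshape(1000, 50))

    rows = np.arange(0, 1000, 100)   # keeps 1% of axis 0 (most selective)
    cols = np.array([0, 3, 7, 9])    # keeps 8% of axis 1

    # One h5py read with the most selective index; the rest happens in memory.
    in_memory = d[rows, :]           # shape (10, 50)
    result = in_memory[:, cols]      # shape (10, 4)

    # Full read only for comparison.
    expected = d[()][np.ix_(rows, cols)]
    assert np.array_equal(result, expected)
```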


def make_slice(idx, dimidx: int, n: int = 2) -> tuple[slice, ...]:
mut = list(repeat(slice(None), n))
mut[dimidx] = idx
return tuple(mut)
11 changes: 6 additions & 5 deletions src/anndata/_core/merge.py
@@ -38,6 +38,7 @@
from collections.abc import Collection, Generator, Iterable, Sequence
from typing import Any

from numpy.typing import NDArray
from pandas.api.extensions import ExtensionDtype

from anndata._types import Join_T
@@ -553,7 +554,7 @@ class Reindexer:
Together with `old_pos` this forms a mapping.
"""

def __init__(self, old_idx, new_idx):
def __init__(self, old_idx: pd.Index, new_idx: pd.Index) -> None:
self.old_idx = old_idx
self.new_idx = new_idx
self.no_change = new_idx.equals(old_idx)
@@ -753,7 +754,7 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):
return el[self.idx]

@property
def idx(self):
def idx(self) -> NDArray[np.intp]:
return self.old_idx.get_indexer(self.new_idx)
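
For reference, a tiny example (not part of the diff) of what the now-typed `idx` property computes: the positions of `new_idx` labels within `old_idx`, with `-1` marking labels that are absent from `old_idx`.

```python
import pandas as pd

old_idx = pd.Index(["gene_a", "gene_b", "gene_c"])
new_idx = pd.Index(["gene_c", "gene_x", "gene_a"])

# Positions of new_idx within old_idx; -1 flags a missing label.
print(old_idx.get_indexer(new_idx))  # [ 2 -1  0]
```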


@@ -782,7 +783,7 @@ def default_fill_value(els):
return np.nan


def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
def gen_reindexer(new_var: pd.Index, cur_var: pd.Index) -> Reindexer:
"""
Given a new set of var_names, and a current set, generates a function which will reindex
a matrix to be aligned with the new set.
@@ -939,7 +940,7 @@ def inner_concat_aligned_mapping(
return result


def gen_inner_reindexers(els, new_index, axis: Literal[0, 1] = 0):
def gen_inner_reindexers(els, new_index, axis: Literal[0, 1] = 0) -> list[Reindexer]:
alt_axis = 1 - axis
if axis == 0:
df_indices = lambda x: x.columns
@@ -1016,7 +1017,7 @@ def missing_element(
axis: Literal[0, 1] = 0,
fill_value: Any | None = None,
off_axis_size: int = 0,
) -> np.ndarray | DaskArray:
) -> NDArray[np.bool_] | DaskArray:
"""Generates value to use when there is a missing element."""
should_return_dask = any(isinstance(el, DaskArray) for el in els)
# 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
7 changes: 4 additions & 3 deletions src/anndata/_core/sparse_dataset.py
@@ -48,8 +48,7 @@
from scipy.sparse._compressed import _cs_matrix

from .._types import GroupStorageType
from ..compat import H5Array
from .index import Index, Index1D
from ..compat import H5Array, Index, Index1D, Index1DNorm
else:
from scipy.sparse import spmatrix as _cs_matrix

@@ -738,5 +737,7 @@ def sparse_dataset(


@_subset.register(BaseCompressedSparseDataset)
def subset_sparsedataset(d, subset_idx):
def subset_sparsedataset(
d, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
):
return d[subset_idx]
7 changes: 5 additions & 2 deletions src/anndata/experimental/backed/_lazy_arrays.py
@@ -25,9 +25,10 @@
from pathlib import Path
from typing import Literal

from anndata._core.index import Index
from anndata.compat import ZarrGroup

from ...compat import Index1DNorm


K = TypeVar("K", H5Array, ZarrArray)

@@ -199,7 +200,9 @@ def dtype(self):


@_subset.register(XDataArray)
def _subset_masked(a: XDataArray, subset_idx: Index):
def _subset_masked(
a: XDataArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
):
return a[subset_idx]

