diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py
index ec5d857c1..3b92a99ac 100644
--- a/src/anndata/_core/index.py
+++ b/src/anndata/_core/index.py
@@ -3,7 +3,7 @@
 from collections.abc import Iterable, Sequence
 from functools import singledispatch
 from itertools import repeat
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast, overload
 
 import h5py
 import numpy as np
@@ -14,6 +14,8 @@
 from .xarray import Dataset2D
 
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..compat import Index, Index1D, Index1DNorm
 
 
@@ -161,7 +163,10 @@ def unpack_index(index: Index) -> tuple[Index1D, Index1D]:
 
 
 @singledispatch
-def _subset(a: np.ndarray | pd.DataFrame, subset_idx: Index):
+def _subset(
+    a: np.ndarray | pd.DataFrame,
+    subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
+):
     # Select as combination of indexes, not coordinates
     # Correcting for indexing behaviour of np.ndarray
     if all(isinstance(x, Iterable) for x in subset_idx):
@@ -170,7 +175,9 @@
 
 
 @_subset.register(DaskArray)
-def _subset_dask(a: DaskArray, subset_idx: Index):
+def _subset_dask(
+    a: DaskArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
+):
     if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
         if issparse(a._meta) and a._meta.format == "csc":
             return a[:, subset_idx[1]][subset_idx[0], :]
@@ -180,24 +187,32 @@
 
 @_subset.register(CSMatrix)
 @_subset.register(CSArray)
-def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index):
+def _subset_sparse(
+    a: CSMatrix | CSArray,
+    subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
+):
     # Correcting for indexing behaviour of sparse.spmatrix
     if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
         first_idx = subset_idx[0]
         if issubclass(first_idx.dtype.type, np.bool_):
-            first_idx = np.where(first_idx)[0]
+            first_idx = np.flatnonzero(first_idx)
         subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:])
     return a[subset_idx]
 
 
 @_subset.register(pd.DataFrame)
 @_subset.register(Dataset2D)
-def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index):
+def _subset_df(
+    df: pd.DataFrame | Dataset2D,
+    subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
+):
     return df.iloc[subset_idx]
 
 
 @_subset.register(AwkArray)
-def _subset_awkarray(a: AwkArray, subset_idx: Index):
+def _subset_awkarray(
+    a: AwkArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
+):
     if all(isinstance(x, Iterable) for x in subset_idx):
         subset_idx = np.ix_(*subset_idx)
     return a[subset_idx]
@@ -205,23 +220,121 @@
 
 # Registration for SparseDataset occurs in sparse_dataset.py
 @_subset.register(h5py.Dataset)
-def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
-    if not isinstance(subset_idx, tuple):
-        subset_idx = (subset_idx,)
-    ordered = list(subset_idx)
-    rev_order = [slice(None) for _ in range(len(subset_idx))]
-    for axis, axis_idx in enumerate(ordered.copy()):
-        if isinstance(axis_idx, np.ndarray):
-            if axis_idx.dtype == bool:
-                axis_idx = np.where(axis_idx)[0]
-            order = np.argsort(axis_idx)
-            ordered[axis] = axis_idx[order]
-            rev_order[axis] = np.argsort(order)
+def _subset_dataset(
+    d: h5py.Dataset, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
+):
+    order: tuple[NDArray[np.integer] | slice, ...]
+    inv_order: tuple[NDArray[np.integer] | slice, ...]
+    order, inv_order = zip(*map(_index_order_and_inverse, subset_idx), strict=True)
+    # Check for duplicates or multi-dimensional fancy indexing
+    array_dims = [i for i in order if isinstance(i, np.ndarray)]
+    has_duplicates = any(len(np.unique(i)) != len(i) for i in array_dims)
+    # Use safe indexing if there are duplicates OR multiple array dimensions
+    # (h5py doesn't support multi-dimensional fancy indexing natively)
+    if has_duplicates or len(array_dims) > 1:
+        # For multi-dimensional indexing, bypass the sorting logic and use original indices
+        return _safe_fancy_index_h5py(d, subset_idx)
     # from hdf5, then to real order
-    return d[tuple(ordered)][tuple(rev_order)]
-
-
-def make_slice(idx, dimidx, n=2):
+    return d[order][inv_order]
+
+
+@overload
+def _index_order_and_inverse(
+    axis_idx: NDArray[np.integer] | NDArray[np.bool_],
+) -> tuple[NDArray[np.integer], NDArray[np.integer]]: ...
+@overload
+def _index_order_and_inverse(axis_idx: slice) -> tuple[slice, slice]: ...
+def _index_order_and_inverse(
+    axis_idx: Index1DNorm,
+) -> tuple[Index1DNorm, NDArray[np.integer] | slice]:
+    """Order and get inverse index array."""
+    if not isinstance(axis_idx, np.ndarray):
+        return axis_idx, slice(None)
+    if axis_idx.dtype == bool:
+        axis_idx = np.flatnonzero(axis_idx)
+    order = np.argsort(axis_idx)
+    return axis_idx[order], np.argsort(order)
+
+
+@overload
+def _process_index_for_h5py(
+    idx: NDArray[np.integer] | NDArray[np.bool_],
+) -> tuple[NDArray[np.integer], NDArray[np.integer]]: ...
+@overload
+def _process_index_for_h5py(idx: slice) -> tuple[slice, None]: ...
+def _process_index_for_h5py(
+    idx: Index1DNorm,
+) -> tuple[Index1DNorm, NDArray[np.integer] | None]:
+    """Process a single index for h5py compatibility, handling sorting and duplicates."""
+    if not isinstance(idx, np.ndarray):
+        # Not an array (e.g. a slice) - no special processing needed
+        return idx, None
+
+    if idx.dtype == bool:
+        idx = np.flatnonzero(idx)
+
+    # For h5py fancy indexing, we need sorted indices,
+    # but we also need to track how to reverse the sorting
+    unique, inverse = np.unique(idx, return_inverse=True)
+    return (
+        # Has duplicates - use unique + inverse mapping approach
+        (unique, inverse)
+        if len(unique) != len(idx)
+        # No duplicates - just sort and track reverse mapping
+        else _index_order_and_inverse(idx)
+    )
+
+
+def _safe_fancy_index_h5py(
+    dataset: h5py.Dataset,
+    subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm],
+) -> np.ndarray:
+    # Handle multi-dimensional indexing of an h5py dataset.
+    # This avoids h5py's limitation with multi-dimensional fancy indexing
+    # without loading the entire dataset into memory.
+
+    # Convert boolean arrays to integer arrays and handle sorting for h5py
+    processed_indices: tuple[NDArray[np.integer] | slice, ...]
+    reverse_indices: tuple[NDArray[np.integer] | None, ...]
+    processed_indices, reverse_indices = zip(
+        *map(_process_index_for_h5py, subset_idx), strict=True
+    )
+
+    # First find the index that reduces the size of the dataset the most
+    i_min = np.argmin([
+        _get_index_size(inds, dataset.shape[i]) / dataset.shape[i]
+        for i, inds in enumerate(processed_indices)
+    ])
+
+    # Apply the most selective index first to the h5py dataset
+    first_index = [slice(None)] * len(processed_indices)
+    first_index[i_min] = processed_indices[i_min]
+    in_memory_array = cast("np.ndarray", dataset[tuple(first_index)])
+
+    # Apply remaining indices to the numpy array
+    remaining_indices = list(processed_indices)
+    remaining_indices[i_min] = slice(None)  # Already applied
+    result = in_memory_array[tuple(remaining_indices)]
+
+    # Now apply reverse mappings to get the original order
+    for dim, reverse_map in enumerate(reverse_indices):
+        if reverse_map is not None:
+            result = result.take(reverse_map, axis=dim)
+
+    return result
+
+
+def _get_index_size(idx: Index1DNorm, dim_size: int) -> int:
+    """Get size for any index type."""
+    if isinstance(idx, slice):
+        return len(range(*idx.indices(dim_size)))
+    elif isinstance(idx, int):
+        return 1
+    else:  # For other types, try to get length
+        return len(idx)
+
+
+def make_slice(idx, dimidx: int, n: int = 2) -> tuple[slice, ...]:
     mut = list(repeat(slice(None), n))
     mut[dimidx] = idx
     return tuple(mut)
diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index e054c75db..5624b25fd 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -38,6 +38,7 @@
     from collections.abc import Collection, Generator, Iterable, Sequence
     from typing import Any
 
+    from numpy.typing import NDArray
     from pandas.api.extensions import ExtensionDtype
 
     from anndata._types import Join_T
@@ -553,7 +554,7 @@ class Reindexer:
     Together with `old_pos` this forms a mapping.
     """
 
-    def __init__(self, old_idx, new_idx):
+    def __init__(self, old_idx: pd.Index, new_idx: pd.Index) -> None:
         self.old_idx = old_idx
         self.new_idx = new_idx
         self.no_change = new_idx.equals(old_idx)
@@ -753,7 +754,7 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):
         return el[self.idx]
 
     @property
-    def idx(self):
+    def idx(self) -> NDArray[np.intp]:
         return self.old_idx.get_indexer(self.new_idx)
 
 
@@ -782,7 +783,7 @@ def default_fill_value(els):
     return np.nan
 
 
-def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
+def gen_reindexer(new_var: pd.Index, cur_var: pd.Index) -> Reindexer:
     """
     Given a new set of var_names, and a current set, generates a function which will reindex
     a matrix to be aligned with the new set.
@@ -939,7 +940,7 @@ def inner_concat_aligned_mapping(
     return result
 
 
-def gen_inner_reindexers(els, new_index, axis: Literal[0, 1] = 0):
+def gen_inner_reindexers(els, new_index, axis: Literal[0, 1] = 0) -> list[Reindexer]:
     alt_axis = 1 - axis
     if axis == 0:
         df_indices = lambda x: x.columns
@@ -1016,7 +1017,7 @@ def missing_element(
     axis: Literal[0, 1] = 0,
     fill_value: Any | None = None,
     off_axis_size: int = 0,
-) -> np.ndarray | DaskArray:
+) -> NDArray[np.bool_] | DaskArray:
     """Generates value to use when there is a missing element."""
     should_return_dask = any(isinstance(el, DaskArray) for el in els)
     # 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
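Note: the argsort/inverse-permutation trick used by `_index_order_and_inverse` and `_safe_fancy_index_h5py` above exists because h5py only accepts fancy indices that are in increasing order and duplicate-free. A minimal NumPy-only sketch of the idea (illustration only, not part of the patch):

    import numpy as np

    idx = np.array([3, 0, 2])        # requested, unsorted row order
    order = np.argsort(idx)          # -> array([1, 2, 0])
    sorted_idx = idx[order]          # -> array([0, 2, 3]), acceptable to h5py
    inv = np.argsort(order)          # permutation that restores the requested order

    data = np.arange(20).reshape(5, 4)
    assert np.array_equal(data[sorted_idx][inv], data[idx])

Reading the sorted selection from disk and reordering the small in-memory result is exactly what `return d[order][inv_order]` does, one axis at a time.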
diff --git a/src/anndata/_core/sparse_dataset.py b/src/anndata/_core/sparse_dataset.py
index feb5868f8..f6beff9c3 100644
--- a/src/anndata/_core/sparse_dataset.py
+++ b/src/anndata/_core/sparse_dataset.py
@@ -48,8 +48,7 @@
     from scipy.sparse._compressed import _cs_matrix
 
     from .._types import GroupStorageType
-    from ..compat import H5Array
-    from .index import Index, Index1D
+    from ..compat import H5Array, Index, Index1D, Index1DNorm
 else:
     from scipy.sparse import spmatrix as _cs_matrix
 
@@ -738,5 +737,7 @@ def sparse_dataset(
 
 
 @_subset.register(BaseCompressedSparseDataset)
-def subset_sparsedataset(d, subset_idx):
+def subset_sparsedataset(
+    d, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
+):
     return d[subset_idx]
diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py
index ca9bde00c..b0f1f4bf7 100644
--- a/src/anndata/experimental/backed/_lazy_arrays.py
+++ b/src/anndata/experimental/backed/_lazy_arrays.py
@@ -25,9 +25,10 @@
     from pathlib import Path
     from typing import Literal
 
-    from anndata._core.index import Index
    from anndata.compat import ZarrGroup
 
+    from ...compat import Index1DNorm
+
 K = TypeVar("K", H5Array, ZarrArray)
 
 
@@ -199,7 +200,9 @@ def dtype(self):
 
 
 @_subset.register(XDataArray)
-def _subset_masked(a: XDataArray, subset_idx: Index):
+def _subset_masked(
+    a: XDataArray, subset_idx: tuple[Index1DNorm] | tuple[Index1DNorm, Index1DNorm]
+):
     return a[subset_idx]
diff --git a/tests/test_backed_hdf5.py b/tests/test_backed_hdf5.py
index 25e7c05b1..789898431 100644
--- a/tests/test_backed_hdf5.py
+++ b/tests/test_backed_hdf5.py
@@ -10,7 +10,6 @@
 from scipy import sparse
 
 import anndata as ad
-from anndata.compat import CSArray, CSMatrix
 from anndata.tests.helpers import (
     GEN_ADATA_DASK_ARGS,
     GEN_ADATA_NO_XARRAY_ARGS,
@@ -241,14 +240,6 @@ def test_backed_raw_subset(
     mem_adata.raw = mem_adata
     obs_idx = subset_func(mem_adata.obs_names)
     var_idx = subset_func2(mem_adata.var_names)
-    if (
-        array_type is asarray
-        and isinstance(obs_idx, list | np.ndarray | CSMatrix | CSArray)
-        and isinstance(var_idx, list | np.ndarray | CSMatrix | CSArray)
-    ):
-        pytest.xfail(
-            "Fancy indexing does not work with multiple arrays on a h5py.Dataset"
-        )
     mem_adata.write(backed_pth)
 
     ### Backed view has same values as in memory view ###
@@ -402,3 +393,105 @@ def test_backed_modification_sparse(
 #     backed_view = backed_adata[[1,2], :]
 #     backed_view.X = 0
 #     assert np.all(backed_adata.X[[1,2], :] == 0)
+
+
+@pytest.mark.parametrize(
+    ("obs_idx", "var_idx"),
+    [
+        pytest.param(np.array([0, 1, 2]), np.array([1, 2]), id="no_dupes"),
+        pytest.param(np.array([0, 1, 0, 2]), slice(None), id="1d_dupes"),
+        pytest.param(np.array([0, 1, 0, 2]), np.array([1, 2, 1]), id="2d_dupes"),
+    ],
+)
+def test_backed_duplicate_indices(tmp_path, obs_idx, var_idx):
+    """Test that backed HDF5 datasets handle duplicate indices correctly."""
+    backed_pth = tmp_path / "backed.h5ad"
+
+    # Create test data
+    mem_adata = gen_adata((6, 4), X_type=asarray, **GEN_ADATA_NO_XARRAY_ARGS)
+    mem_adata.write(backed_pth)
+
+    # Load backed data
+    backed_adata = ad.read_h5ad(backed_pth, backed="r")
+
+    # Test the indexing
+    mem_result_multi = mem_adata[obs_idx, var_idx]
+    backed_result_multi = backed_adata[obs_idx, var_idx]
+    assert_equal(mem_result_multi, backed_result_multi)
+
+
+@pytest.fixture
+def h5py_test_data(tmp_path):
+    """Create test HDF5 file with dataset for _safe_fancy_index_h5py tests."""
+    import h5py
+
+    h5_path = tmp_path / "test_dataset.h5"
+    test_data = np.arange(24).reshape(6, 4)  # 6x4 matrix
+
+    with h5py.File(h5_path, "w") as f:
+        f.create_dataset("test", data=test_data)
+
+    return h5_path, test_data
+
+
+@pytest.mark.parametrize(
+    ("indices", "description"),
+    [
+        pytest.param((np.array([0, 1, 0, 2]),), "single_dimension_with_duplicates"),
+        pytest.param(
+            (np.array([0, 1, 2]), np.array([1, 2])), "multi_dimensional_no_duplicates"
+        ),
+        pytest.param(
+            (np.array([0, 1, 0, 2]), np.array([1, 2])),
+            "multi_dimensional_duplicates_first_dim",
+        ),
+        pytest.param(
+            (np.array([0, 1, 2]), np.array([1, 2, 1])),
+            "multi_dimensional_duplicates_second_dim",
+        ),
+        pytest.param(
+            (np.array([0, 1, 0]), np.array([1, 2, 1])),
+            "multi_dimensional_duplicates_both_dims",
+        ),
+        pytest.param(
+            (np.array([True, False, True, False, False, True]),), "boolean_arrays"
+        ),
+        pytest.param((np.array([0, 1, 0]), slice(1, 3)), "mixed_indexing_with_slices"),
+        pytest.param(
+            (np.array([0, 1, 0]), [1, 2]), "mixed_indexing_with_slices_and_lists"
+        ),
+        pytest.param((np.array([3, 1, 3, 0, 1]),), "unsorted_indices_with_duplicates"),
+    ],
+)
+def test_safe_fancy_index_h5py_function(h5py_test_data, indices, description):
+    """Test the _safe_fancy_index_h5py function directly with various indexing patterns."""
+    import h5py
+
+    from anndata._core.index import _safe_fancy_index_h5py
+
+    h5_path, test_data = h5py_test_data
+
+    with h5py.File(h5_path, "r") as f:
+        dataset = f["test"]
+
+        # Get result from the function
+        result = _safe_fancy_index_h5py(dataset, indices)
+
+        # Calculate expected result using NumPy
+        if len(indices) > 1:
+            # Multi-dimensional case - use np.ix_ for outer (fancy) indexing
+            if isinstance(indices[1], slice):
+                # Normalize the slice to integer positions so np.ix_ can handle it
+                expected = test_data[
+                    np.ix_(indices[0], np.arange(test_data.shape[1])[indices[1]])
+                ]
+            else:
+                expected = test_data[np.ix_(*indices)]
+        else:
+            # Single dimensional case
+            expected = test_data[indices]
+
+        # Assert arrays are equal
+        np.testing.assert_array_equal(
+            result, expected, err_msg=f"Failed for test case: {description}"
+        )
diff --git a/tests/test_inplace_subset.py b/tests/test_inplace_subset.py
index ce0e75c47..075503276 100644
--- a/tests/test_inplace_subset.py
+++ b/tests/test_inplace_subset.py
@@ -90,6 +90,7 @@ def test_inplace_subset_no_X(subset_func, dim):
     subset_idx = subset_func(getattr(orig, f"{dim}_names"))
 
     modified = orig.copy()
+    # TODO: apart from this test, `_subset` is never called with strings, lists, …
     from_view = subset_dim(orig, **{dim: subset_idx}).copy()
     getattr(modified, f"_inplace_subset_{dim}")(subset_idx)
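Note: the duplicate-index path these tests exercise boils down to `np.unique(..., return_inverse=True)`: read each wanted row once (sorted and deduplicated, so h5py accepts the selection), then expand to the requested order in memory. A self-contained sketch using a hypothetical helper `read_with_duplicates` (the name is illustrative; the real implementation is `_safe_fancy_index_h5py` above, which additionally picks the most selective axis first):

    import h5py
    import numpy as np

    def read_with_duplicates(dset: h5py.Dataset, rows: np.ndarray) -> np.ndarray:
        # h5py rejects unsorted or duplicated fancy indices, so read each
        # unique row once and expand to the requested order in memory.
        unique, inverse = np.unique(rows, return_inverse=True)
        return dset[unique][inverse]

    with h5py.File("demo.h5", "w") as f:
        dset = f.create_dataset("x", data=np.arange(24).reshape(6, 4))
        rows = np.array([3, 1, 3, 0, 1])
        np.testing.assert_array_equal(read_with_duplicates(dset, rows), dset[...][rows])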