From 85fb1f2f26a08480e8afaad1d2db3429b9a110b3 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Thu, 31 Jul 2025 14:56:59 +0200 Subject: [PATCH] Backport PR #2067: Fix index types --- src/anndata/_core/anndata.py | 27 ++++++++++--- src/anndata/_core/index.py | 19 +++------ src/anndata/_core/raw.py | 8 ++-- src/anndata/_core/views.py | 34 ++++++++++------- src/anndata/_core/xarray.py | 25 ++++++------ src/anndata/compat/__init__.py | 25 +++++++++++- src/anndata/tests/helpers.py | 70 +++++++++++++++++++++------------- tests/test_backed_sparse.py | 9 +++-- tests/test_base.py | 2 +- tests/test_views.py | 48 +++++++++++++++-------- 10 files changed, 172 insertions(+), 95 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ba39880db..2fbfa8325 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -56,7 +56,7 @@ from zarr.storage import StoreLike - from ..compat import Index1D, XDataset + from ..compat import Index1D, Index1DNorm, XDataset from ..typing import XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView from .index import Index @@ -197,6 +197,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641 _accessors: ClassVar[set[str]] = set() + # view attributes + _adata_ref: AnnData | None + _oidx: Index1DNorm | None + _vidx: Index1DNorm | None + @old_positionals( "obsm", "varm", @@ -226,8 +231,8 @@ def __init__( # noqa: PLR0913 asview: bool = False, obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None, varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None, - oidx: Index1D | None = None, - vidx: Index1D | None = None, + oidx: Index1DNorm | int | np.integer | None = None, + vidx: Index1DNorm | int | np.integer | None = None, ): # check for any multi-indices that aren’t later checked in coerce_array for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]: @@ -237,6 +242,8 @@ def __init__( # noqa: PLR0913 if not isinstance(X, AnnData): msg = "`X` has to be an AnnData object." raise ValueError(msg) + assert oidx is not None + assert vidx is not None self._init_as_view(X, oidx, vidx) else: self._init_as_actual( @@ -256,7 +263,12 @@ def __init__( # noqa: PLR0913 filemode=filemode, ) - def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index): + def _init_as_view( + self, + adata_ref: AnnData, + oidx: Index1DNorm | int | np.integer, + vidx: Index1DNorm | int | np.integer, + ): if adata_ref.isbacked and adata_ref.is_view: msg = ( "Currently, you cannot index repeatedly into a backed AnnData, " @@ -277,6 +289,9 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index): vidx += adata_ref.n_vars * (vidx < 0) vidx = slice(vidx, vidx + 1, 1) if adata_ref.is_view: + assert adata_ref._adata_ref is not None + assert adata_ref._oidx is not None + assert adata_ref._vidx is not None prev_oidx, prev_vidx = adata_ref._oidx, adata_ref._vidx adata_ref = adata_ref._adata_ref oidx, vidx = _resolve_idxs((prev_oidx, prev_vidx), (oidx, vidx), adata_ref) @@ -1004,7 +1019,9 @@ def _set_backed(self, attr, value): write_attribute(self.file._file, attr, value) - def _normalize_indices(self, index: Index | None) -> tuple[slice, slice]: + def _normalize_indices( + self, index: Index | None + ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]: return _normalize_indices(index, self.obs_names, self.var_names) # TODO: this is not quite complete... diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 5ed271add..ec5d857c1 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -14,18 +14,18 @@ from .xarray import Dataset2D if TYPE_CHECKING: - from ..compat import Index, Index1D + from ..compat import Index, Index1D, Index1DNorm def _normalize_indices( index: Index | None, names0: pd.Index, names1: pd.Index -) -> tuple[slice, slice]: +) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]: # deal with tuples of length 1 if isinstance(index, tuple) and len(index) == 1: index = index[0] # deal with pd.Series if isinstance(index, pd.Series): - index: Index = index.values + index = index.values if isinstance(index, tuple): # TODO: The series should probably be aligned first index = tuple(i.values if isinstance(i, pd.Series) else i for i in index) @@ -36,15 +36,8 @@ def _normalize_indices( def _normalize_index( # noqa: PLR0911, PLR0912 - indexer: slice - | np.integer - | int - | str - | Sequence[bool | int | np.integer] - | np.ndarray - | pd.Index, - index: pd.Index, -) -> slice | int | np.ndarray: # ndarray of int or bool + indexer: Index1D, index: pd.Index +) -> Index1DNorm | int | np.integer: # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" @@ -212,7 +205,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): # Registration for SparseDataset occurs in sparse_dataset.py @_subset.register(h5py.Dataset) -def _subset_dataset(d, subset_idx): +def _subset_dataset(d: h5py.Dataset, subset_idx: Index): if not isinstance(subset_idx, tuple): subset_idx = (subset_idx,) ordered = list(subset_idx) diff --git a/src/anndata/_core/raw.py b/src/anndata/_core/raw.py index 96a6fe06e..90325948f 100644 --- a/src/anndata/_core/raw.py +++ b/src/anndata/_core/raw.py @@ -17,7 +17,7 @@ from collections.abc import Mapping, Sequence from typing import ClassVar - from ..compat import CSMatrix + from ..compat import CSMatrix, Index, Index1DNorm from .aligned_mapping import AxisArraysView from .anndata import AnnData from .sparse_dataset import BaseCompressedSparseDataset @@ -121,7 +121,7 @@ def var_names(self) -> pd.Index[str]: def obs_names(self) -> pd.Index[str]: return self._adata.obs_names - def __getitem__(self, index): + def __getitem__(self, index: Index) -> Raw: oidx, vidx = self._normalize_indices(index) # To preserve two dimensional shape @@ -169,7 +169,9 @@ def to_adata(self) -> AnnData: uns=self._adata.uns.copy(), ) - def _normalize_indices(self, packed_index): + def _normalize_indices( + self, packed_index: Index + ) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]: # deal with slicing with pd.Series if isinstance(packed_index, pd.Series): packed_index = packed_index.values diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index ac9a0dd0f..9801d28af 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -29,8 +29,12 @@ from collections.abc import Callable, Iterable, KeysView, Sequence from typing import Any, ClassVar + from numpy.typing import NDArray + from anndata import AnnData + from ..compat import Index1DNorm + @contextmanager def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]): @@ -433,18 +437,24 @@ class AwkwardArrayView: pass -def _resolve_idxs(old, new, adata): - t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1)) - return t +def _resolve_idxs( + old: tuple[Index1DNorm, Index1DNorm], + new: tuple[Index1DNorm, Index1DNorm], + adata: AnnData, +) -> tuple[Index1DNorm, Index1DNorm]: + o, v = (_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1)) + return o, v @singledispatch -def _resolve_idx(old, new, l): - return old[new] +def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm: + raise NotImplementedError @_resolve_idx.register(np.ndarray) -def _resolve_idx_ndarray(old, new, l): +def _resolve_idx_ndarray( + old: NDArray[np.bool_] | NDArray[np.integer], new: Index1DNorm, l: Literal[0, 1] +) -> NDArray[np.bool_] | NDArray[np.integer]: if is_bool_dtype(old) and is_bool_dtype(new): mask_new = np.zeros_like(old) mask_new[np.flatnonzero(old)[new]] = True @@ -454,21 +464,17 @@ def _resolve_idx_ndarray(old, new, l): return old[new] -@_resolve_idx.register(np.integer) -@_resolve_idx.register(int) -def _resolve_idx_scalar(old, new, l): - return np.array([old])[new] - - @_resolve_idx.register(slice) -def _resolve_idx_slice(old, new, l): +def _resolve_idx_slice( + old: slice, new: Index1DNorm, l: Literal[0, 1] +) -> slice | NDArray[np.integer]: if isinstance(new, slice): return _resolve_idx_slice_slice(old, new, l) else: return np.arange(*old.indices(l))[new] -def _resolve_idx_slice_slice(old, new, l): +def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice: r = range(*old.indices(l))[new] # Convert back to slice start, stop, step = r.start, r.stop, r.step diff --git a/src/anndata/_core/xarray.py b/src/anndata/_core/xarray.py index 9b93258d5..28754b628 100644 --- a/src/anndata/_core/xarray.py +++ b/src/anndata/_core/xarray.py @@ -184,18 +184,6 @@ def iloc(self) -> Dataset2DIlocIndexer: Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`. """ - @dataclass(frozen=True) - class IlocGetter: - _ds: XDataset - _coord: str - - def __getitem__(self, idx) -> Dataset2D: - # xarray seems to have some code looking for a second entry in tuples, - # so we unpack the tuple - if isinstance(idx, tuple) and len(idx) == 1: - idx = idx[0] - return Dataset2D(self._ds.isel(**{self._coord: idx})) - return IlocGetter(self.ds, self.index_dim) # See https://github.com/pydata/xarray/blob/568f3c1638d2d34373408ce2869028faa3949446/xarray/core/dataset.py#L1239-L1248 @@ -402,3 +390,16 @@ def reindex( def _items(self): for col in self: yield col, self[col] + + +@dataclass(frozen=True) +class IlocGetter: + _ds: XDataset + _coord: str + + def __getitem__(self, idx) -> Dataset2D: + # xarray seems to have some code looking for a second entry in tuples, + # so we unpack the tuple + if isinstance(idx, tuple) and len(idx) == 1: + idx = idx[0] + return Dataset2D(self._ds.isel(**{self._coord: idx})) diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 6eb4da48b..8835804e2 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from codecs import decode -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from functools import cache, partial, singledispatch from importlib.util import find_spec from types import EllipsisType @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import scipy +from numpy.typing import NDArray from packaging.version import Version from zarr import Array as ZarrArray # noqa: F401 from zarr import Group as ZarrGroup @@ -19,6 +20,7 @@ if TYPE_CHECKING: from typing import Any + ############################# # scipy sparse array comapt # ############################# @@ -32,7 +34,26 @@ class Empty: pass -Index1D = slice | int | str | np.int64 | np.ndarray | pd.Series +Index1DNorm = slice | NDArray[np.bool_] | NDArray[np.integer] +# TODO: pd.Index[???] +Index1D = ( + # 0D index + int + | str + | np.int64 + # normalized 1D idex + | Index1DNorm + # different containers for mask, obs/varnames, or numerical index + | Sequence[int] + | Sequence[str] + | Sequence[bool] + | pd.Series # bool, int, str + | pd.Index + | NDArray[np.str_] + | np.matrix # bool + | CSMatrix # bool + | CSArray # bool +) IndexRest = Index1D | EllipsisType Index = ( IndexRest diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 6752b0708..44b66b927 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -42,12 +42,15 @@ from collections.abc import Callable, Collection, Iterable from typing import Literal, TypeGuard, TypeVar + from numpy.typing import NDArray from zarr.abc.store import ByteRequest from zarr.core.buffer import BufferPrototype from .._types import ArrayStorageType + from ..compat import Index1D DT = TypeVar("DT") + _SubsetFunc = Callable[[pd.Index[str], int], Index1D] try: @@ -428,7 +431,7 @@ def gen_adata( # noqa: PLR0913 return adata -def array_bool_subset(index, min_size=2): +def array_bool_subset(index: pd.Index[str], min_size: int = 2) -> NDArray[np.bool_]: b = np.zeros(len(index), dtype=bool) selected = np.random.choice( range(len(index)), @@ -439,11 +442,11 @@ def array_bool_subset(index, min_size=2): return b -def list_bool_subset(index, min_size=2): +def list_bool_subset(index: pd.Index[str], min_size: int = 2) -> list[bool]: return array_bool_subset(index, min_size=min_size).tolist() -def matrix_bool_subset(index, min_size=2): +def matrix_bool_subset(index: pd.Index[str], min_size: int = 2) -> np.matrix: with warnings.catch_warnings(): warnings.simplefilter("ignore", PendingDeprecationWarning) indexer = np.matrix( @@ -452,19 +455,26 @@ def matrix_bool_subset(index, min_size=2): return indexer -def spmatrix_bool_subset(index, min_size=2): +def spmatrix_bool_subset(index: pd.Index[str], min_size: int = 2) -> sparse.csr_matrix: return sparse.csr_matrix( array_bool_subset(index, min_size=min_size).reshape(len(index), 1) ) -def sparray_bool_subset(index, min_size=2): +def sparray_bool_subset(index: pd.Index[str], min_size: int = 2) -> sparse.csr_array: return sparse.csr_array( array_bool_subset(index, min_size=min_size).reshape(len(index), 1) ) -def array_subset(index, min_size=2): +def single_subset(index: pd.Index[str], min_size: int = 1) -> str: + if min_size > 1: + msg = "max_size must be ≤1" + raise AssertionError(msg) + return index[np.random.randint(0, len(index))] + + +def array_subset(index: pd.Index[str], min_size: int = 2) -> NDArray[np.str_]: if len(index) < min_size: msg = f"min_size (={min_size}) must be smaller than len(index) (={len(index)}" raise ValueError(msg) @@ -473,7 +483,7 @@ def array_subset(index, min_size=2): ) -def array_int_subset(index, min_size=2): +def array_int_subset(index: pd.Index[str], min_size: int = 2) -> NDArray[np.int64]: if len(index) < min_size: msg = f"min_size (={min_size}) must be smaller than len(index) (={len(index)}" raise ValueError(msg) @@ -484,11 +494,11 @@ def array_int_subset(index, min_size=2): ) -def list_int_subset(index, min_size=2): +def list_int_subset(index: pd.Index[str], min_size: int = 2) -> list[int]: return array_int_subset(index, min_size=min_size).tolist() -def slice_subset(index, min_size=2): +def slice_int_subset(index: pd.Index[str], min_size: int = 2) -> slice: while True: points = np.random.choice(np.arange(len(index) + 1), size=2, replace=False) s = slice(*sorted(points)) @@ -497,25 +507,33 @@ def slice_subset(index, min_size=2): return s -def single_subset(index): - return index[np.random.randint(0, len(index))] +def single_int_subset(index: pd.Index[str], min_size: int = 1) -> int: + if min_size > 1: + msg = "max_size must be ≤1" + raise AssertionError(msg) + return np.random.randint(0, len(index)) + + +_SUBSET_FUNCS: list[_SubsetFunc] = [ + # str (obs/var name) + single_subset, + array_subset, + # int (numeric index) + single_int_subset, + slice_int_subset, + array_int_subset, + list_int_subset, + # bool (mask) + array_bool_subset, + list_bool_subset, + matrix_bool_subset, + spmatrix_bool_subset, + sparray_bool_subset, +] -@pytest.fixture( - params=[ - array_subset, - slice_subset, - single_subset, - array_int_subset, - list_int_subset, - array_bool_subset, - list_bool_subset, - matrix_bool_subset, - spmatrix_bool_subset, - sparray_bool_subset, - ] -) -def subset_func(request): +@pytest.fixture(params=_SUBSET_FUNCS) +def subset_func(request: pytest.FixtureRequest) -> _SubsetFunc: return request.param diff --git a/tests/test_backed_sparse.py b/tests/test_backed_sparse.py index 2f8381279..79d64ee4a 100644 --- a/tests/test_backed_sparse.py +++ b/tests/test_backed_sparse.py @@ -17,6 +17,7 @@ from anndata._io.zarr import open_write_group from anndata.compat import CSArray, CSMatrix, DaskArray, ZarrGroup, is_zarr_v2 from anndata.experimental import read_dispatched +from anndata.tests import helpers as test_helpers from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func if TYPE_CHECKING: @@ -312,10 +313,10 @@ def test_append_array_cache_bust(tmp_path: Path, diskfmt: Literal["h5ad", "zarr" ("subset_func", "subset_func2"), product( [ - ad.tests.helpers.array_subset, - ad.tests.helpers.slice_subset, - ad.tests.helpers.array_int_subset, - ad.tests.helpers.array_bool_subset, + test_helpers.array_subset, + test_helpers.slice_int_subset, + test_helpers.array_int_subset, + test_helpers.array_bool_subset, ], repeat=2, ), diff --git a/tests/test_base.py b/tests/test_base.py index 98d75a45e..f5d18448c 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -767,4 +767,4 @@ def test_create_adata_from_single_axis_elem( assert in_memory.shape == (10, 0) if axis == "obs" else (0, 10) in_memory.write_h5ad(tmp_path / "adata.h5ad") from_disk = ad.read_h5ad(tmp_path / "adata.h5ad") - ad.tests.helpers.assert_equal(from_disk, in_memory) + assert_equal(from_disk, in_memory) diff --git a/tests/test_views.py b/tests/test_views.py index 29a02b503..d52f9adfc 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -30,8 +30,9 @@ GEN_ADATA_DASK_ARGS, assert_equal, gen_adata, + single_int_subset, single_subset, - slice_subset, + slice_int_subset, subset_func, ) from anndata.utils import asarray @@ -70,31 +71,44 @@ def view(self, dtype=None, typ=None): @pytest.fixture -def adata(): +def adata() -> ad.AnnData: adata = ad.AnnData(np.zeros((100, 100))) adata.obsm["o"] = np.zeros((100, 50)) adata.varm["o"] = np.zeros((100, 50)) return adata +@pytest.fixture(scope="session") +def adata_gen_session(matrix_type) -> ad.AnnData: + adata = gen_adata((30, 15), X_type=matrix_type) + adata.raw = adata.copy() + return adata + + +@pytest.fixture +def adata_gen(adata_gen_session: ad.AnnData) -> ad.AnnData: + return adata_gen_session.copy() + + @pytest.fixture( params=BASE_MATRIX_PARAMS + DASK_MATRIX_PARAMS + CUPY_MATRIX_PARAMS, + scope="session", ) def matrix_type(request): return request.param -@pytest.fixture(params=BASE_MATRIX_PARAMS + DASK_MATRIX_PARAMS) +@pytest.fixture(params=BASE_MATRIX_PARAMS + DASK_MATRIX_PARAMS, scope="session") def matrix_type_no_gpu(request): return request.param -@pytest.fixture(params=BASE_MATRIX_PARAMS) +@pytest.fixture(params=BASE_MATRIX_PARAMS, scope="session") def matrix_type_base(request): return request.param -@pytest.fixture(params=["layers", "obsm", "varm"]) +@pytest.fixture(params=["layers", "obsm", "varm"], scope="session") def mapping_name(request): return request.param @@ -316,7 +330,7 @@ def test_not_set_subset_X(matrix_type_base, subset_func): init_hash = joblib.hash(adata) orig_X_val = adata.X.copy() while True: - subset_idx = slice_subset(adata.obs_names) + subset_idx = slice_int_subset(adata.obs_names) if len(adata[subset_idx, :]) > 2: break subset = adata[subset_idx, :] @@ -344,7 +358,7 @@ def test_not_set_subset_X_dask(matrix_type_no_gpu, subset_func): init_hash = tokenize(adata) orig_X_val = adata.X.copy() while True: - subset_idx = slice_subset(adata.obs_names) + subset_idx = slice_int_subset(adata.obs_names) if len(adata[subset_idx, :]) > 2: break subset = adata[subset_idx, :] @@ -392,7 +406,7 @@ def test_set_subset_obsm(adata, subset_func): orig_obsm_val = adata.obsm["o"].copy() while True: - subset_idx = slice_subset(adata.obs_names) + subset_idx = slice_int_subset(adata.obs_names) if len(adata[subset_idx, :]) > 2: break subset = adata[subset_idx, :] @@ -415,7 +429,7 @@ def test_set_subset_varm(adata, subset_func): orig_varm_val = adata.varm["o"].copy() while True: - subset_idx = slice_subset(adata.var_names) + subset_idx = slice_int_subset(adata.var_names) if (adata[:, subset_idx]).shape[1] > 2: break subset = adata[:, subset_idx] @@ -526,27 +540,31 @@ def test_layers_view(): assert view_hash != joblib.hash(view_adata) -# TODO: This can be flaky. Make that stop -def test_view_of_view(matrix_type, subset_func, subset_func2): - adata = gen_adata((30, 15), X_type=matrix_type) - adata.raw = adata.copy() - if subset_func is single_subset: +# TODO: less combinatoric; split up into 2 tests: +# 1. each subset func produces the right `oidx`/`vidx` kind (slice, array[int], array[bool]) +# 2. each `oidx`/`vidx` kind can be sliced with each subset func +# going from #subset_func² to #subset_func × 3 {ov}idx kinds × 2 tests +def test_view_of_view(adata_gen: ad.AnnData, subset_func, subset_func2) -> None: + adata = adata_gen + if subset_func in {single_subset, single_int_subset}: pytest.xfail("Other subset generating functions have trouble with this") var_s1 = subset_func(adata.var_names, min_size=4) var_view1 = adata[:, var_s1] adata[:, var_s1].X # noqa: B018 var_s2 = subset_func2(var_view1.var_names) var_view2 = var_view1[:, var_s2] + assert var_view2._adata_ref is adata assert isinstance(var_view2.X, type(adata.X)) + obs_s1 = subset_func(adata.obs_names, min_size=4) obs_view1 = adata[obs_s1, :] obs_s2 = subset_func2(obs_view1.obs_names) + assert adata[obs_s1, :][:, var_s1][obs_s2, :]._adata_ref is adata assert isinstance(obs_view1.X, type(adata.X)) view_of_actual_copy = adata[:, var_s1].copy()[obs_s1, :].copy()[:, var_s2].copy() - view_of_view_copy = adata[:, var_s1][obs_s1, :][:, var_s2].copy() assert_equal(view_of_actual_copy, view_of_view_copy, exact=True)