Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

from zarr.storage import StoreLike

from ..compat import Index1D, XDataset
from ..compat import Index1D, Index1DNorm, XDataset
from ..typing import XDataType
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
from .index import Index
Expand Down Expand Up @@ -197,6 +197,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641

_accessors: ClassVar[set[str]] = set()

# view attributes
_adata_ref: AnnData | None
_oidx: Index1DNorm | None
_vidx: Index1DNorm | None

@old_positionals(
"obsm",
"varm",
Expand Down Expand Up @@ -226,8 +231,8 @@ def __init__( # noqa: PLR0913
asview: bool = False,
obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
oidx: Index1D | None = None,
vidx: Index1D | None = None,
oidx: Index1DNorm | int | np.integer | None = None,
vidx: Index1DNorm | int | np.integer | None = None,
):
# check for any multi-indices that aren’t later checked in coerce_array
for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
Expand All @@ -237,6 +242,8 @@ def __init__( # noqa: PLR0913
if not isinstance(X, AnnData):
msg = "`X` has to be an AnnData object."
raise ValueError(msg)
assert oidx is not None
assert vidx is not None
self._init_as_view(X, oidx, vidx)
else:
self._init_as_actual(
Expand All @@ -256,7 +263,12 @@ def __init__( # noqa: PLR0913
filemode=filemode,
)

def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
def _init_as_view(
self,
adata_ref: AnnData,
oidx: Index1DNorm | int | np.integer,
vidx: Index1DNorm | int | np.integer,
):
if adata_ref.isbacked and adata_ref.is_view:
msg = (
"Currently, you cannot index repeatedly into a backed AnnData, "
Expand All @@ -277,6 +289,9 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
vidx += adata_ref.n_vars * (vidx < 0)
vidx = slice(vidx, vidx + 1, 1)
if adata_ref.is_view:
assert adata_ref._adata_ref is not None
assert adata_ref._oidx is not None
assert adata_ref._vidx is not None
prev_oidx, prev_vidx = adata_ref._oidx, adata_ref._vidx
adata_ref = adata_ref._adata_ref
oidx, vidx = _resolve_idxs((prev_oidx, prev_vidx), (oidx, vidx), adata_ref)
Expand Down Expand Up @@ -1004,7 +1019,9 @@ def _set_backed(self, attr, value):

write_attribute(self.file._file, attr, value)

def _normalize_indices(self, index: Index | None) -> tuple[slice, slice]:
def _normalize_indices(
self, index: Index | None
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
return _normalize_indices(index, self.obs_names, self.var_names)

# TODO: this is not quite complete...
Expand Down
19 changes: 6 additions & 13 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
from .xarray import Dataset2D

if TYPE_CHECKING:
from ..compat import Index, Index1D
from ..compat import Index, Index1D, Index1DNorm


def _normalize_indices(
index: Index | None, names0: pd.Index, names1: pd.Index
) -> tuple[slice, slice]:
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
# deal with tuples of length 1
if isinstance(index, tuple) and len(index) == 1:
index = index[0]
# deal with pd.Series
if isinstance(index, pd.Series):
index: Index = index.values
index = index.values
if isinstance(index, tuple):
# TODO: The series should probably be aligned first
index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
Expand All @@ -36,15 +36,8 @@ def _normalize_indices(


def _normalize_index( # noqa: PLR0911, PLR0912
indexer: slice
| np.integer
| int
| str
| Sequence[bool | int | np.integer]
| np.ndarray
| pd.Index,
index: pd.Index,
) -> slice | int | np.ndarray: # ndarray of int or bool
indexer: Index1D, index: pd.Index
) -> Index1DNorm | int | np.integer:
# TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
Expand Down Expand Up @@ -212,7 +205,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):

# Registration for SparseDataset occurs in sparse_dataset.py
@_subset.register(h5py.Dataset)
def _subset_dataset(d, subset_idx):
def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
if not isinstance(subset_idx, tuple):
subset_idx = (subset_idx,)
ordered = list(subset_idx)
Expand Down
8 changes: 5 additions & 3 deletions src/anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Mapping, Sequence
from typing import ClassVar

from ..compat import CSMatrix
from ..compat import CSMatrix, Index, Index1DNorm
from .aligned_mapping import AxisArraysView
from .anndata import AnnData
from .sparse_dataset import BaseCompressedSparseDataset
Expand Down Expand Up @@ -121,7 +121,7 @@ def var_names(self) -> pd.Index[str]:
def obs_names(self) -> pd.Index[str]:
return self._adata.obs_names

def __getitem__(self, index):
def __getitem__(self, index: Index) -> Raw:
oidx, vidx = self._normalize_indices(index)

# To preserve two dimensional shape
Expand Down Expand Up @@ -169,7 +169,9 @@ def to_adata(self) -> AnnData:
uns=self._adata.uns.copy(),
)

def _normalize_indices(self, packed_index):
def _normalize_indices(
self, packed_index: Index
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
# deal with slicing with pd.Series
if isinstance(packed_index, pd.Series):
packed_index = packed_index.values
Expand Down
34 changes: 20 additions & 14 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
from collections.abc import Callable, Iterable, KeysView, Sequence
from typing import Any, ClassVar

from numpy.typing import NDArray

from anndata import AnnData

from ..compat import Index1DNorm


@contextmanager
def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):
Expand Down Expand Up @@ -433,18 +437,24 @@ class AwkwardArrayView:
pass


def _resolve_idxs(old, new, adata):
t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
return t
def _resolve_idxs(
    old: tuple[Index1DNorm, Index1DNorm],
    new: tuple[Index1DNorm, Index1DNorm],
    adata: AnnData,
) -> tuple[Index1DNorm, Index1DNorm]:
    """Combine an existing pair of axis indices with a new pair.

    For each of the two axes, the existing index ``old[axis]`` and the new
    index ``new[axis]`` are passed to ``_resolve_idx`` together with
    ``adata.shape[axis]``.

    Returns
    -------
    The resolved index for axis 0 followed by the resolved index for axis 1.
    """
    # Axis 0 is resolved before axis 1.
    resolved_first = _resolve_idx(old[0], new[0], adata.shape[0])
    resolved_second = _resolve_idx(old[1], new[1], adata.shape[1])
    return resolved_first, resolved_second


@singledispatch
def _resolve_idx(old, new, l):
return old[new]
def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: int) -> Index1DNorm:
    """Fallback for index types that have no registered resolver.

    Parameters
    ----------
    old
        Index that is already applied along one axis.
    new
        Index to apply on top of ``old``.
    l
        Length of the axis being indexed (``_resolve_idxs`` passes
        ``adata.shape[i]``), hence an ``int`` rather than a 0/1 literal.

    Raises
    ------
    NotImplementedError
        Always; the visible implementations are registered for
        ``np.ndarray`` and ``slice``.
    """
    msg = f"Cannot resolve an index of type {type(old).__name__}"
    raise NotImplementedError(msg)


@_resolve_idx.register(np.ndarray)
def _resolve_idx_ndarray(old, new, l):
def _resolve_idx_ndarray(
old: NDArray[np.bool_] | NDArray[np.integer], new: Index1DNorm, l: Literal[0, 1]
) -> NDArray[np.bool_] | NDArray[np.integer]:
if is_bool_dtype(old) and is_bool_dtype(new):
mask_new = np.zeros_like(old)
mask_new[np.flatnonzero(old)[new]] = True
Expand All @@ -454,21 +464,17 @@ def _resolve_idx_ndarray(old, new, l):
return old[new]


@_resolve_idx.register(np.integer)
@_resolve_idx.register(int)
def _resolve_idx_scalar(old, new, l):
return np.array([old])[new]


@_resolve_idx.register(slice)
def _resolve_idx_slice(old, new, l):
def _resolve_idx_slice(
    old: slice, new: Index1DNorm, l: int
) -> slice | NDArray[np.integer]:
    """Apply ``new`` on top of an existing slice index.

    Parameters
    ----------
    old
        Slice that is already applied along the axis.
    new
        Index to apply to the elements selected by ``old``.
    l
        Length of the axis ``old`` refers to; it is forwarded to
        ``slice.indices``, so it is an ``int`` rather than a 0/1 literal.

    Returns
    -------
    A slice when ``new`` is a slice (delegated to
    ``_resolve_idx_slice_slice``), otherwise the positions selected by
    ``old`` indexed with ``new``.
    """
    if isinstance(new, slice):
        return _resolve_idx_slice_slice(old, new, l)
    # Materialize the positions `old` selects, then index them with `new`.
    return np.arange(*old.indices(l))[new]


def _resolve_idx_slice_slice(old, new, l):
def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice:
r = range(*old.indices(l))[new]
# Convert back to slice
start, stop, step = r.start, r.stop, r.step
Expand Down
25 changes: 13 additions & 12 deletions src/anndata/_core/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,6 @@ def iloc(self) -> Dataset2DIlocIndexer:
Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`.
"""

@dataclass(frozen=True)
class IlocGetter:
_ds: XDataset
_coord: str

def __getitem__(self, idx) -> Dataset2D:
# xarray seems to have some code looking for a second entry in tuples,
# so we unpack the tuple
if isinstance(idx, tuple) and len(idx) == 1:
idx = idx[0]
return Dataset2D(self._ds.isel(**{self._coord: idx}))

return IlocGetter(self.ds, self.index_dim)

# See https://github.com/pydata/xarray/blob/568f3c1638d2d34373408ce2869028faa3949446/xarray/core/dataset.py#L1239-L1248
Expand Down Expand Up @@ -402,3 +390,16 @@ def reindex(
def _items(self):
for col in self:
yield col, self[col]


@dataclass(frozen=True)
class IlocGetter:
    """Indexer that forwards ``[]`` access to ``isel`` on the wrapped dataset.

    The selection is made along the single dimension named by ``_coord`` and
    the result is wrapped in a ``Dataset2D``.
    """

    # Dataset whose ``isel`` method performs the selection.
    _ds: XDataset
    # Name used as the keyword argument in the ``isel`` call.
    _coord: str

    def __getitem__(self, idx) -> Dataset2D:
        # xarray seems to have some code looking for a second entry in tuples,
        # so a one-element tuple is reduced to its only entry first.
        is_single_tuple = isinstance(idx, tuple) and len(idx) == 1
        selector = idx[0] if is_single_tuple else idx
        selected = self._ds.isel(**{self._coord: selector})
        return Dataset2D(selected)
22 changes: 20 additions & 2 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from codecs import decode
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from functools import cache, partial, singledispatch
from importlib.util import find_spec
from types import EllipsisType
Expand All @@ -12,13 +12,15 @@
import numpy as np
import pandas as pd
import scipy
from numpy.typing import NDArray
from packaging.version import Version
from zarr import Array as ZarrArray # noqa: F401
from zarr import Group as ZarrGroup

if TYPE_CHECKING:
from typing import Any


#############################
# scipy sparse array compat #
#############################
Expand All @@ -32,7 +34,23 @@ class Empty:
pass


Index1D = slice | int | str | np.int64 | np.ndarray | pd.Series
Index1DNorm = slice | NDArray[np.bool] | NDArray[np.integer]
# TODO: pd.Index[???]
Index1D = (
# 0D index
int
| str
| np.int64
# normalized 1D index
| Index1DNorm
# different containers for mask, obs/varnames, or numerical index
| pd.Series # bool, int, str
| pd.Index
Copy link
Copy Markdown
Member Author

@flying-sheep flying-sheep Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t think we test these (at least not via subset_func and therefore with test_views_of_views)

| NDArray[np.str_]
| Sequence[int]
| Sequence[str]
| Sequence[bool]
)
IndexRest = Index1D | EllipsisType
Index = (
IndexRest
Expand Down
Loading
Loading