Skip to content

Commit 89fbfdc

Browse files
authored
Fix index types (#2067)
1 parent 7ff7e3b commit 89fbfdc

10 files changed

Lines changed: 172 additions & 95 deletions

File tree

src/anndata/_core/anndata.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656

5757
from zarr.storage import StoreLike
5858

59-
from ..compat import Index1D, XDataset
59+
from ..compat import Index1D, Index1DNorm, XDataset
6060
from ..typing import XDataType
6161
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
6262
from .index import Index
@@ -197,6 +197,11 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641
197197

198198
_accessors: ClassVar[set[str]] = set()
199199

200+
# view attributes
201+
_adata_ref: AnnData | None
202+
_oidx: Index1DNorm | None
203+
_vidx: Index1DNorm | None
204+
200205
@old_positionals(
201206
"obsm",
202207
"varm",
@@ -226,8 +231,8 @@ def __init__( # noqa: PLR0913
226231
asview: bool = False,
227232
obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
228233
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
229-
oidx: Index1D | None = None,
230-
vidx: Index1D | None = None,
234+
oidx: Index1DNorm | int | np.integer | None = None,
235+
vidx: Index1DNorm | int | np.integer | None = None,
231236
):
232237
# check for any multi-indices that aren’t later checked in coerce_array
233238
for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
@@ -237,6 +242,8 @@ def __init__( # noqa: PLR0913
237242
if not isinstance(X, AnnData):
238243
msg = "`X` has to be an AnnData object."
239244
raise ValueError(msg)
245+
assert oidx is not None
246+
assert vidx is not None
240247
self._init_as_view(X, oidx, vidx)
241248
else:
242249
self._init_as_actual(
@@ -256,7 +263,12 @@ def __init__( # noqa: PLR0913
256263
filemode=filemode,
257264
)
258265

259-
def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
266+
def _init_as_view(
267+
self,
268+
adata_ref: AnnData,
269+
oidx: Index1DNorm | int | np.integer,
270+
vidx: Index1DNorm | int | np.integer,
271+
):
260272
if adata_ref.isbacked and adata_ref.is_view:
261273
msg = (
262274
"Currently, you cannot index repeatedly into a backed AnnData, "
@@ -277,6 +289,9 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
277289
vidx += adata_ref.n_vars * (vidx < 0)
278290
vidx = slice(vidx, vidx + 1, 1)
279291
if adata_ref.is_view:
292+
assert adata_ref._adata_ref is not None
293+
assert adata_ref._oidx is not None
294+
assert adata_ref._vidx is not None
280295
prev_oidx, prev_vidx = adata_ref._oidx, adata_ref._vidx
281296
adata_ref = adata_ref._adata_ref
282297
oidx, vidx = _resolve_idxs((prev_oidx, prev_vidx), (oidx, vidx), adata_ref)
@@ -1004,7 +1019,9 @@ def _set_backed(self, attr, value):
10041019

10051020
write_attribute(self.file._file, attr, value)
10061021

1007-
def _normalize_indices(self, index: Index | None) -> tuple[slice, slice]:
1022+
def _normalize_indices(
1023+
self, index: Index | None
1024+
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
10081025
return _normalize_indices(index, self.obs_names, self.var_names)
10091026

10101027
# TODO: this is not quite complete...

src/anndata/_core/index.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414
from .xarray import Dataset2D
1515

1616
if TYPE_CHECKING:
17-
from ..compat import Index, Index1D
17+
from ..compat import Index, Index1D, Index1DNorm
1818

1919

2020
def _normalize_indices(
2121
index: Index | None, names0: pd.Index, names1: pd.Index
22-
) -> tuple[slice, slice]:
22+
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
2323
# deal with tuples of length 1
2424
if isinstance(index, tuple) and len(index) == 1:
2525
index = index[0]
2626
# deal with pd.Series
2727
if isinstance(index, pd.Series):
28-
index: Index = index.values
28+
index = index.values
2929
if isinstance(index, tuple):
3030
# TODO: The series should probably be aligned first
3131
index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
@@ -36,15 +36,8 @@ def _normalize_indices(
3636

3737

3838
def _normalize_index( # noqa: PLR0911, PLR0912
39-
indexer: slice
40-
| np.integer
41-
| int
42-
| str
43-
| Sequence[bool | int | np.integer]
44-
| np.ndarray
45-
| pd.Index,
46-
index: pd.Index,
47-
) -> slice | int | np.ndarray: # ndarray of int or bool
39+
indexer: Index1D, index: pd.Index
40+
) -> Index1DNorm | int | np.integer:
4841
# TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
4942
if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
5043
msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
@@ -212,7 +205,7 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):
212205

213206
# Registration for SparseDataset occurs in sparse_dataset.py
214207
@_subset.register(h5py.Dataset)
215-
def _subset_dataset(d, subset_idx):
208+
def _subset_dataset(d: h5py.Dataset, subset_idx: Index):
216209
if not isinstance(subset_idx, tuple):
217210
subset_idx = (subset_idx,)
218211
ordered = list(subset_idx)

src/anndata/_core/raw.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections.abc import Mapping, Sequence
1818
from typing import ClassVar
1919

20-
from ..compat import CSMatrix
20+
from ..compat import CSMatrix, Index, Index1DNorm
2121
from .aligned_mapping import AxisArraysView
2222
from .anndata import AnnData
2323
from .sparse_dataset import BaseCompressedSparseDataset
@@ -121,7 +121,7 @@ def var_names(self) -> pd.Index[str]:
121121
def obs_names(self) -> pd.Index[str]:
122122
return self._adata.obs_names
123123

124-
def __getitem__(self, index):
124+
def __getitem__(self, index: Index) -> Raw:
125125
oidx, vidx = self._normalize_indices(index)
126126

127127
# To preserve two dimensional shape
@@ -169,7 +169,9 @@ def to_adata(self) -> AnnData:
169169
uns=self._adata.uns.copy(),
170170
)
171171

172-
def _normalize_indices(self, packed_index):
172+
def _normalize_indices(
173+
self, packed_index: Index
174+
) -> tuple[Index1DNorm | int | np.integer, Index1DNorm | int | np.integer]:
173175
# deal with slicing with pd.Series
174176
if isinstance(packed_index, pd.Series):
175177
packed_index = packed_index.values

src/anndata/_core/views.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@
2929
from collections.abc import Callable, Iterable, KeysView, Sequence
3030
from typing import Any, ClassVar
3131

32+
from numpy.typing import NDArray
33+
3234
from anndata import AnnData
3335

36+
from ..compat import Index1DNorm
37+
3438

3539
@contextmanager
3640
def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):
@@ -433,18 +437,24 @@ class AwkwardArrayView:
433437
pass
434438

435439

436-
def _resolve_idxs(old, new, adata):
437-
t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
438-
return t
440+
def _resolve_idxs(
441+
old: tuple[Index1DNorm, Index1DNorm],
442+
new: tuple[Index1DNorm, Index1DNorm],
443+
adata: AnnData,
444+
) -> tuple[Index1DNorm, Index1DNorm]:
445+
o, v = (_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
446+
return o, v
439447

440448

441449
@singledispatch
442-
def _resolve_idx(old, new, l):
443-
return old[new]
450+
def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm:
451+
raise NotImplementedError
444452

445453

446454
@_resolve_idx.register(np.ndarray)
447-
def _resolve_idx_ndarray(old, new, l):
455+
def _resolve_idx_ndarray(
456+
old: NDArray[np.bool_] | NDArray[np.integer], new: Index1DNorm, l: Literal[0, 1]
457+
) -> NDArray[np.bool_] | NDArray[np.integer]:
448458
if is_bool_dtype(old) and is_bool_dtype(new):
449459
mask_new = np.zeros_like(old)
450460
mask_new[np.flatnonzero(old)[new]] = True
@@ -454,21 +464,17 @@ def _resolve_idx_ndarray(old, new, l):
454464
return old[new]
455465

456466

457-
@_resolve_idx.register(np.integer)
458-
@_resolve_idx.register(int)
459-
def _resolve_idx_scalar(old, new, l):
460-
return np.array([old])[new]
461-
462-
463467
@_resolve_idx.register(slice)
464-
def _resolve_idx_slice(old, new, l):
468+
def _resolve_idx_slice(
469+
old: slice, new: Index1DNorm, l: Literal[0, 1]
470+
) -> slice | NDArray[np.integer]:
465471
if isinstance(new, slice):
466472
return _resolve_idx_slice_slice(old, new, l)
467473
else:
468474
return np.arange(*old.indices(l))[new]
469475

470476

471-
def _resolve_idx_slice_slice(old, new, l):
477+
def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice:
472478
r = range(*old.indices(l))[new]
473479
# Convert back to slice
474480
start, stop, step = r.start, r.stop, r.step

src/anndata/_core/xarray.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,6 @@ def iloc(self) -> Dataset2DIlocIndexer:
184184
Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`.
185185
"""
186186

187-
@dataclass(frozen=True)
188-
class IlocGetter:
189-
_ds: XDataset
190-
_coord: str
191-
192-
def __getitem__(self, idx) -> Dataset2D:
193-
# xarray seems to have some code looking for a second entry in tuples,
194-
# so we unpack the tuple
195-
if isinstance(idx, tuple) and len(idx) == 1:
196-
idx = idx[0]
197-
return Dataset2D(self._ds.isel(**{self._coord: idx}))
198-
199187
return IlocGetter(self.ds, self.index_dim)
200188

201189
# See https://github.com/pydata/xarray/blob/568f3c1638d2d34373408ce2869028faa3949446/xarray/core/dataset.py#L1239-L1248
@@ -402,3 +390,16 @@ def reindex(
402390
def _items(self):
403391
for col in self:
404392
yield col, self[col]
393+
394+
395+
@dataclass(frozen=True)
396+
class IlocGetter:
397+
_ds: XDataset
398+
_coord: str
399+
400+
def __getitem__(self, idx) -> Dataset2D:
401+
# xarray seems to have some code looking for a second entry in tuples,
402+
# so we unpack the tuple
403+
if isinstance(idx, tuple) and len(idx) == 1:
404+
idx = idx[0]
405+
return Dataset2D(self._ds.isel(**{self._coord: idx}))

src/anndata/compat/__init__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from codecs import decode
4-
from collections.abc import Mapping
4+
from collections.abc import Mapping, Sequence
55
from functools import cache, partial, singledispatch
66
from importlib.util import find_spec
77
from types import EllipsisType
@@ -12,13 +12,15 @@
1212
import numpy as np
1313
import pandas as pd
1414
import scipy
15+
from numpy.typing import NDArray
1516
from packaging.version import Version
1617
from zarr import Array as ZarrArray # noqa: F401
1718
from zarr import Group as ZarrGroup
1819

1920
if TYPE_CHECKING:
2021
from typing import Any
2122

23+
2224
#############################
2325
# scipy sparse array comapt #
2426
#############################
@@ -32,7 +34,26 @@ class Empty:
3234
pass
3335

3436

35-
Index1D = slice | int | str | np.int64 | np.ndarray | pd.Series
37+
Index1DNorm = slice | NDArray[np.bool_] | NDArray[np.integer]
38+
# TODO: pd.Index[???]
39+
Index1D = (
40+
# 0D index
41+
int
42+
| str
43+
| np.int64
44+
# normalized 1D idex
45+
| Index1DNorm
46+
# different containers for mask, obs/varnames, or numerical index
47+
| Sequence[int]
48+
| Sequence[str]
49+
| Sequence[bool]
50+
| pd.Series # bool, int, str
51+
| pd.Index
52+
| NDArray[np.str_]
53+
| np.matrix # bool
54+
| CSMatrix # bool
55+
| CSArray # bool
56+
)
3657
IndexRest = Index1D | EllipsisType
3758
Index = (
3859
IndexRest

0 commit comments

Comments
 (0)