Skip to content

Commit 49b4bb4

Browse files
authored
Backport CS* tightening (#1889)
1 parent f996bb1 commit 49b4bb4

21 files changed

Lines changed: 103 additions & 96 deletions

src/anndata/_core/aligned_mapping.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
from scipy.sparse import spmatrix
1312

1413
from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
15-
from ..compat import AwkArray
14+
from ..compat import AwkArray, CSArray, CSMatrix
1615
from ..utils import (
1716
axis_len,
1817
convert_to_dict,
@@ -36,7 +35,7 @@
3635
OneDIdx = Sequence[int] | Sequence[bool] | slice
3736
TwoDIdx = tuple[OneDIdx, OneDIdx]
3837
# TODO: pd.DataFrame only allowed in AxisArrays?
39-
Value = pd.DataFrame | spmatrix | np.ndarray
38+
Value = pd.DataFrame | CSArray | CSMatrix | np.ndarray
4039

4140
P = TypeVar("P", bound="AlignedMappingBase")
4241
"""Parent mapping an AlignedView is based on."""

src/anndata/_core/anndata.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,13 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
195195

196196
def __init__(
197197
self,
198-
X: np.ndarray | sparse.spmatrix | pd.DataFrame | None = None,
198+
X: ArrayDataStructureType | pd.DataFrame | None = None,
199199
obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
200200
var: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
201201
uns: Mapping[str, Any] | None = None,
202202
obsm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
203203
varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
204-
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
204+
layers: Mapping[str, ArrayDataStructureType] | None = None,
205205
raw: Mapping[str, Any] | None = None,
206206
dtype: np.dtype | type | str | None = None,
207207
shape: tuple[int, int] | None = None,
@@ -557,7 +557,7 @@ def X(self) -> ArrayDataStructureType | None:
557557
# return X
558558

559559
@X.setter
560-
def X(self, value: np.ndarray | sparse.spmatrix | SpArray | None):
560+
def X(self, value: ArrayDataStructureType | None):
561561
if value is None:
562562
if self.isbacked:
563563
msg = "Cannot currently remove data matrix from backed object."
@@ -1159,7 +1159,7 @@ def _inplace_subset_obs(self, index: Index1D):
11591159
self._init_as_actual(adata_subset)
11601160

11611161
# TODO: Update, possibly remove
1162-
def __setitem__(self, index: Index, val: float | np.ndarray | sparse.spmatrix):
1162+
def __setitem__(self, index: Index, val: float | ArrayDataStructureType):
11631163
if self.is_view:
11641164
msg = "Object is view and cannot be accessed with `[]`."
11651165
raise ValueError(msg)

src/anndata/_core/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..compat import AwkArray, DaskArray, SpArray
1414

1515
if TYPE_CHECKING:
16-
from ..compat import Index, Index1D
16+
from ..compat import CSArray, CSMatrix, Index, Index1D
1717

1818

1919
def _normalize_indices(
@@ -182,7 +182,7 @@ def _subset_dask(a: DaskArray, subset_idx: Index):
182182

183183
@_subset.register(spmatrix)
184184
@_subset.register(SpArray)
185-
def _subset_sparse(a: spmatrix | SpArray, subset_idx: Index):
185+
def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index):
186186
# Correcting for indexing behaviour of sparse.spmatrix
187187
if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
188188
first_idx = subset_idx[0]

src/anndata/_core/merge.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
from pandas.api.extensions import ExtensionDtype
4343

44+
from ..compat import CSArray, CSMatrix
45+
4446
T = TypeVar("T")
4547

4648
###################
@@ -174,7 +176,7 @@ def equal_sparse(a, b) -> bool:
174176

175177
xp = array_api_compat.array_namespace(a.data)
176178

177-
if isinstance(b, CupySparseMatrix | sparse.spmatrix | SpArray):
179+
if isinstance(b, CupySparseMatrix | spmatrix | SpArray):
178180
if isinstance(a, CupySparseMatrix):
179181
# Comparison broken for CSC matrices
180182
# https://github.com/cupy/cupy/issues/7757
@@ -205,8 +207,8 @@ def equal_awkward(a, b) -> bool:
205207
return ak.almost_equal(a, b)
206208

207209

208-
def as_sparse(x, use_sparse_array=False):
209-
if not isinstance(x, sparse.spmatrix | SpArray):
210+
def as_sparse(x, *, use_sparse_array: bool = False) -> CSMatrix | CSArray:
211+
if not isinstance(x, spmatrix | SpArray):
210212
if CAN_USE_SPARSE_ARRAY and use_sparse_array:
211213
return sparse.csr_array(x)
212214
return sparse.csr_matrix(x)
@@ -537,7 +539,7 @@ def apply(self, el, *, axis, fill_value=None):
537539
return el
538540
if isinstance(el, pd.DataFrame):
539541
return self._apply_to_df(el, axis=axis, fill_value=fill_value)
540-
elif isinstance(el, sparse.spmatrix | SpArray | CupySparseMatrix):
542+
elif isinstance(el, spmatrix | SpArray | CupySparseMatrix):
541543
return self._apply_to_sparse(el, axis=axis, fill_value=fill_value)
542544
elif isinstance(el, AwkArray):
543545
return self._apply_to_awkward(el, axis=axis, fill_value=fill_value)
@@ -615,7 +617,7 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
615617
)
616618

617619
def _apply_to_sparse(
618-
self, el: sparse.spmatrix | SpArray, *, axis, fill_value=None
620+
self, el: CSMatrix | CSArray, *, axis, fill_value=None
619621
) -> spmatrix:
620622
if isinstance(el, CupySparseMatrix):
621623
from cupyx.scipy import sparse
@@ -730,11 +732,8 @@ def default_fill_value(els):
730732
This is largely due to backwards compat, and might not be the ideal solution.
731733
"""
732734
if any(
733-
isinstance(el, sparse.spmatrix | SpArray)
734-
or (
735-
isinstance(el, DaskArray)
736-
and isinstance(el._meta, sparse.spmatrix | SpArray)
737-
)
735+
isinstance(el, spmatrix | SpArray)
736+
or (isinstance(el, DaskArray) and isinstance(el._meta, spmatrix | SpArray))
738737
for el in els
739738
):
740739
return 0
@@ -830,7 +829,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
830829
],
831830
axis=axis,
832831
)
833-
elif any(isinstance(a, sparse.spmatrix | SpArray) for a in arrays):
832+
elif any(isinstance(a, spmatrix | SpArray) for a in arrays):
834833
sparse_stack = (sparse.vstack, sparse.hstack)[axis]
835834
use_sparse_array = any(issubclass(type(a), SpArray) for a in arrays)
836835
return sparse_stack(
@@ -941,7 +940,7 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
941940

942941
def missing_element(
943942
n: int,
944-
els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
943+
els: list[CSArray | CSMatrix | np.ndarray | DaskArray],
945944
axis: Literal[0, 1] = 0,
946945
fill_value: Any | None = None,
947946
off_axis_size: int = 0,

src/anndata/_core/raw.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
from collections.abc import Mapping, Sequence
1818
from typing import ClassVar
1919

20-
from scipy import sparse
21-
20+
from ..compat import CSMatrix
2221
from .aligned_mapping import AxisArraysView
2322
from .anndata import AnnData
2423
from .sparse_dataset import BaseCompressedSparseDataset
@@ -31,7 +30,7 @@ class Raw:
3130
def __init__(
3231
self,
3332
adata: AnnData,
34-
X: np.ndarray | sparse.spmatrix | None = None,
33+
X: np.ndarray | CSMatrix | None = None,
3534
var: pd.DataFrame | Mapping[str, Sequence] | None = None,
3635
varm: AxisArrays | Mapping[str, np.ndarray] | None = None,
3736
):
@@ -67,7 +66,7 @@ def _get_X(self, layer=None):
6766
return self.X
6867

6968
@property
70-
def X(self) -> BaseCompressedSparseDataset | np.ndarray | sparse.spmatrix:
69+
def X(self) -> BaseCompressedSparseDataset | np.ndarray | CSMatrix:
7170
# TODO: Handle unsorted array of integer indices for h5py.Datasets
7271
if not self._adata.isbacked:
7372
return self._X

src/anndata/_core/sparse_dataset.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from scipy.sparse._compressed import _cs_matrix
4141

4242
from .._types import GroupStorageType
43-
from ..compat import H5Array
43+
from ..compat import CSArray, CSMatrix, H5Array
4444
from .index import Index, Index1D
4545
else:
4646
from scipy.sparse import spmatrix as _cs_matrix
@@ -67,7 +67,7 @@ class BackedSparseMatrix(_cs_matrix):
6767
indices: GroupStorageType
6868
indptr: np.ndarray
6969

70-
def copy(self) -> ss.csr_matrix | ss.csc_matrix:
70+
def copy(self) -> CSMatrix:
7171
if isinstance(self.data, h5py.Dataset):
7272
return sparse_dataset(self.data.parent).to_memory()
7373
if isinstance(self.data, ZarrArray):
@@ -433,9 +433,7 @@ def __repr__(self) -> str:
433433
name = type(self).__name__.removeprefix("_")
434434
return f"{name}: backend {self.backend}, shape {self.shape}, data_dtype {self.dtype}"
435435

436-
def __getitem__(
437-
self, index: Index | tuple[()]
438-
) -> float | ss.csr_matrix | ss.csc_matrix | SpArray:
436+
def __getitem__(self, index: Index | tuple[()]) -> float | CSMatrix | CSArray:
439437
indices = self._normalize_index(index)
440438
row, col = indices
441439
mtx = self._to_backed()
@@ -494,7 +492,7 @@ def __setitem__(self, index: Index | tuple[()], value) -> None:
494492
mock_matrix[row, col] = value
495493

496494
# TODO: split to other classes?
497-
def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None:
495+
def append(self, sparse_matrix: CSMatrix | CSArray) -> None:
498496
"""Append an in-memory or on-disk sparse matrix to the current object's store.
499497
500498
Parameters
@@ -620,7 +618,7 @@ def _to_backed(self) -> BackedSparseMatrix:
620618
mtx.indptr = self._indptr
621619
return mtx
622620

623-
def to_memory(self) -> ss.csr_matrix | ss.csc_matrix | SpArray:
621+
def to_memory(self) -> CSMatrix | CSArray:
624622
format_class = get_memory_class(
625623
self.format, use_sparray_in_io=settings.use_sparse_array_on_read
626624
)

src/anndata/_io/h5ad.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from typing import Any, Literal
3939

4040
from .._core.file_backing import AnnDataFileManager
41+
from ..compat import CSMatrix
4142

4243
T = TypeVar("T")
4344

@@ -115,7 +116,7 @@ def write_h5ad(
115116
def write_sparse_as_dense(
116117
f: h5py.Group,
117118
key: str,
118-
value: sparse.spmatrix | BaseCompressedSparseDataset,
119+
value: CSMatrix | BaseCompressedSparseDataset,
119120
*,
120121
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
121122
):
@@ -172,7 +173,7 @@ def read_h5ad(
172173
backed: Literal["r", "r+"] | bool | None = None,
173174
*,
174175
as_sparse: Sequence[str] = (),
175-
as_sparse_fmt: type[sparse.spmatrix] = sparse.csr_matrix,
176+
as_sparse_fmt: type[CSMatrix] = sparse.csr_matrix,
176177
chunk_size: int = 6000, # TODO, probably make this 2d chunks
177178
) -> AnnData:
178179
"""\
@@ -273,7 +274,7 @@ def callback(func, elem_name: str, elem, iospec):
273274
def _read_raw(
274275
f: h5py.File | AnnDataFileManager,
275276
as_sparse: Collection[str] = (),
276-
rdasp: Callable[[h5py.Dataset], sparse.spmatrix] | None = None,
277+
rdasp: Callable[[h5py.Dataset], CSMatrix] | None = None,
277278
*,
278279
attrs: Collection[str] = ("X", "var", "varm"),
279280
) -> dict:
@@ -346,7 +347,7 @@ def read_dataset(dataset: h5py.Dataset):
346347

347348
@report_read_key_on_error
348349
def read_dense_as_sparse(
349-
dataset: h5py.Dataset, sparse_format: sparse.spmatrix, axis_chunk: int
350+
dataset: h5py.Dataset, sparse_format: CSMatrix, axis_chunk: int
350351
):
351352
if sparse_format == sparse.csr_matrix:
352353
return read_dense_as_csr(dataset, axis_chunk)

src/anndata/_io/specs/lazy_methods.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections.abc import Generator, Mapping, Sequence
2121
from typing import Literal, ParamSpec, TypeVar
2222

23-
from ...compat import DaskArray, H5File, SpArray
23+
from ...compat import CSArray, CSMatrix, DaskArray, H5File
2424
from .registry import DaskReader
2525

2626
BlockInfo = Mapping[
@@ -72,7 +72,7 @@ def make_dask_chunk(
7272
path_or_sparse_dataset: Path | D,
7373
elem_name: str,
7474
block_info: BlockInfo | None = None,
75-
) -> sparse.csr_matrix | sparse.csc_matrix | SpArray:
75+
) -> CSMatrix | CSArray:
7676
if block_info is None:
7777
msg = "Block info is required"
7878
raise ValueError(msg)

src/anndata/_io/specs/methods.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@
5151
from numpy import typing as npt
5252
from numpy.typing import NDArray
5353

54-
from anndata._types import ArrayStorageType, GroupStorageType
55-
from anndata.compat import SpArray
56-
from anndata.typing import AxisStorable, InMemoryArrayOrScalarType
57-
54+
from ..._types import ArrayStorageType, GroupStorageType
55+
from ...compat import CSArray, CSMatrix
56+
from ...typing import AxisStorable, InMemoryArrayOrScalarType
5857
from .registry import Reader, Writer
5958

6059
####################
@@ -127,7 +126,7 @@ def wrapper(
127126
@_REGISTRY.register_read(H5Array, IOSpec("", ""))
128127
def read_basic(
129128
elem: H5File | H5Group | H5Array, *, _reader: Reader
130-
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | sparse.spmatrix | SpArray:
129+
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | CSMatrix | CSArray:
131130
from anndata._io import h5ad
132131

133132
warn(
@@ -149,7 +148,7 @@ def read_basic(
149148
@_REGISTRY.register_read(ZarrArray, IOSpec("", ""))
150149
def read_basic_zarr(
151150
elem: ZarrGroup | ZarrArray, *, _reader: Reader
152-
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | sparse.spmatrix | SpArray:
151+
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | CSMatrix | CSArray:
153152
from anndata._io import zarr
154153

155154
warn(
@@ -590,7 +589,7 @@ def write_recarray_zarr(
590589
def write_sparse_compressed(
591590
f: GroupStorageType,
592591
key: str,
593-
value: sparse.spmatrix | SpArray,
592+
value: CSMatrix | CSArray,
594593
*,
595594
_writer: Writer,
596595
fmt: Literal["csr", "csc"],
@@ -756,9 +755,7 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
756755
@_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
757756
@_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
758757
@_REGISTRY.register_read(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
759-
def read_sparse(
760-
elem: GroupStorageType, *, _reader: Reader
761-
) -> sparse.spmatrix | SpArray:
758+
def read_sparse(elem: GroupStorageType, *, _reader: Reader) -> CSMatrix | CSArray:
762759
return sparse_dataset(elem).to_memory()
763760

764761

src/anndata/abc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from typing import ClassVar, Literal
88

99
import numpy as np
10-
from scipy.sparse import csc_matrix, csr_matrix
1110

12-
from .compat import Index, SpArray
11+
from .compat import CSArray, CSMatrix, Index
1312

1413

1514
__all__ = ["CSRDataset", "CSCDataset"]
@@ -31,7 +30,7 @@ class _AbstractCSDataset(ABC):
3130
"""Which file type is used on-disk."""
3231

3332
@abstractmethod
34-
def __getitem__(self, index: Index) -> float | csr_matrix | csc_matrix | SpArray:
33+
def __getitem__(self, index: Index) -> float | CSMatrix | CSArray:
3534
"""Load a slice or an element from the sparse dataset into memory.
3635
3736
Parameters
@@ -45,7 +44,7 @@ def __getitem__(self, index: Index) -> float | csr_matrix | csc_matrix | SpArray
4544
"""
4645

4746
@abstractmethod
48-
def to_memory(self) -> csr_matrix | csc_matrix | SpArray:
47+
def to_memory(self) -> CSMatrix | CSArray:
4948
"""Load the sparse dataset into memory.
5049
5150
Returns

0 commit comments

Comments
 (0)