Skip to content

Commit a92bda8

Browse files
(fix): correct typing of AnnData.X (#1616)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 8f3299b commit a92bda8

10 files changed

Lines changed: 62 additions & 86 deletions

File tree

docs/concatenation.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ When the variables present in the objects to be concatenated aren't exactly the
5454
This is otherwise called taking the `"inner"` (intersection) or `"outer"` (union) join.
5555
For example, given two anndata objects with differing variables:
5656

57-
>>> a = AnnData(sparse.eye(3), var=pd.DataFrame(index=list("abc")))
58-
>>> b = AnnData(sparse.eye(2), var=pd.DataFrame(index=list("ba")))
57+
>>> a = AnnData(sparse.eye(3, format="csr"), var=pd.DataFrame(index=list("abc")))
58+
>>> b = AnnData(sparse.eye(2, format="csr"), var=pd.DataFrame(index=list("ba")))
5959
>>> ad.concat([a, b], join="inner").X.toarray()
6060
array([[1., 0.],
6161
[0., 1.],
@@ -208,11 +208,11 @@ Note that comparisons are made after indices are aligned.
208208
That is, if the objects only share a subset of indices on the alternative axis, it's only required that values for those indices match when using a strategy like `"same"`.
209209

210210
>>> a = AnnData(
211-
... sparse.eye(3),
211+
... sparse.eye(3, format="csr"),
212212
... var=pd.DataFrame({"nums": [1, 2, 3]}, index=list("abc"))
213213
... )
214214
>>> b = AnnData(
215-
... sparse.eye(2),
215+
... sparse.eye(2, format="csr"),
216216
... var=pd.DataFrame({"nums": [2, 1]}, index=list("ba"))
217217
... )
218218
>>> ad.concat([a, b], merge="same").var

docs/release-notes/1616.doc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Correct {attr}`anndata.AnnData.X` type to include {class}`~anndata.experimental.CSRDataset` and {class}`~anndata.experimental.CSCDataset` as possible types {user}`ilan-gold`

src/anndata/_core/anndata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252
from os import PathLike
5353
from typing import Any, Literal
5454

55+
from .._types import ArrayDataStructureType
5556
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
5657
from .index import Index, Index1D
57-
from .views import ArrayView
5858

5959

6060
# for backwards compat
@@ -541,7 +541,7 @@ def shape(self) -> tuple[int, int]:
541541
return self.n_obs, self.n_vars
542542

543543
@property
544-
def X(self) -> np.ndarray | sparse.spmatrix | SpArray | ArrayView | None:
544+
def X(self) -> ArrayDataStructureType | None:
545545
"""Data matrix of shape :attr:`n_obs` × :attr:`n_vars`."""
546546
if self.isbacked:
547547
if not self.file.is_open:

src/anndata/_core/merge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,8 +737,8 @@ def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
737737
Usage
738738
-----
739739
740-
>>> a = AnnData(sparse.eye(3), var=pd.DataFrame(index=list("abc")))
741-
>>> b = AnnData(sparse.eye(2), var=pd.DataFrame(index=list("ba")))
740+
>>> a = AnnData(sparse.eye(3, format="csr"), var=pd.DataFrame(index=list("abc")))
741+
>>> b = AnnData(sparse.eye(2, format="csr"), var=pd.DataFrame(index=list("ba")))
742742
>>> reindexer = gen_reindexer(a.var_names, b.var_names)
743743
>>> sparse.vstack([a.X, reindexer(b.X)]).toarray()
744744
array([[1., 0., 0.],

src/anndata/_core/storage.py

Lines changed: 30 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from enum import Enum
5-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Union, get_args
65

76
import numpy as np
87
import pandas as pd
@@ -25,48 +24,27 @@
2524
join_english,
2625
raise_value_error_if_multiindex_columns,
2726
)
28-
from .sparse_dataset import BaseCompressedSparseDataset
27+
from .sparse_dataset import CSCDataset, CSRDataset
2928

3029
if TYPE_CHECKING:
31-
from collections.abc import Generator
32-
from typing import Any
30+
from typing import Any, TypeAlias
3331

34-
35-
class ArrayDataStructureType(Enum):
36-
# Memory
37-
Array = (np.ndarray, "np.ndarray")
38-
Masked = (ma.MaskedArray, "numpy.ma.core.MaskedArray")
39-
Sparse = (sparse.spmatrix, "scipy.sparse.spmatrix")
40-
SparseArray = (SpArray, "scipy.sparse.sparray")
41-
AwkArray = (AwkArray, "awkward.Array")
42-
# Backed
43-
HDF5Dataset = (H5Array, "h5py.Dataset")
44-
ZarrArray = (ZarrArray, "zarr.Array")
45-
ZappyArray = (ZappyArray, "zappy.base.ZappyArray")
46-
BackedSparseMatrix = (
47-
BaseCompressedSparseDataset,
48-
"anndata.experimental.[CSC,CSR]Dataset",
49-
)
50-
# Distributed
51-
DaskArray = (DaskArray, "dask.array.Array")
52-
CupyArray = (CupyArray, "cupy.ndarray")
53-
CupySparseMatrix = (CupySparseMatrix, "cupyx.scipy.sparse.spmatrix")
54-
55-
@property
56-
def cls(self):
57-
return self.value[0]
58-
59-
@property
60-
def qualname(self):
61-
return self.value[1]
62-
63-
@classmethod
64-
def classes(cls) -> tuple[type, ...]:
65-
return tuple(v.cls for v in cls)
66-
67-
@classmethod
68-
def qualnames(cls) -> Generator[str, None, None]:
69-
yield from (v.qualname for v in cls)
32+
ArrayDataStructureType: TypeAlias = Union[
33+
np.ndarray,
34+
ma.MaskedArray,
35+
sparse.csr_matrix,
36+
sparse.csc_matrix,
37+
SpArray,
38+
AwkArray,
39+
H5Array,
40+
ZarrArray,
41+
ZappyArray,
42+
CSRDataset,
43+
CSCDataset,
44+
DaskArray,
45+
CupyArray,
46+
CupySparseMatrix,
47+
]
7048

7149

7250
def coerce_array(
@@ -81,12 +59,21 @@ def coerce_array(
8159
if allow_array_like and np.isscalar(value):
8260
return value
8361
# If value is one of the allowed types, return it
84-
if isinstance(value, ArrayDataStructureType.classes()):
62+
array_data_structure_types = get_args(ArrayDataStructureType)
63+
if isinstance(value, array_data_structure_types):
8564
if isinstance(value, np.matrix):
8665
msg = f"{name} should not be a np.matrix, use np.ndarray instead."
8766
warnings.warn(msg, ImplicitModificationWarning)
8867
value = value.A
8968
return value
69+
elif isinstance(value, sparse.spmatrix):
70+
msg = (
71+
f"AnnData previously had undefined behavior around matrices of type {type(value)}. "
72+
"In 0.12, passing in this type will throw an error. Please convert to a supported type. "
73+
"Continue using for this minor version at your own risk."
74+
)
75+
warnings.warn(msg, FutureWarning)
76+
return value
9077
if isinstance(value, pd.DataFrame):
9178
if allow_df:
9279
raise_value_error_if_multiindex_columns(value, name)
@@ -100,7 +87,7 @@ def coerce_array(
10087
except (ValueError, TypeError) as _e:
10188
e = _e
10289
# if value isn’t the right type or convertible, raise an error
103-
msg = f"{name} needs to be of one of {join_english(ArrayDataStructureType.qualnames())}, not {type(value)}."
90+
msg = f"{name} needs to be of one of {join_english(map(str, array_data_structure_types))}, not {type(value)}."
10491
if e is not None:
10592
msg += " (Failed to convert it to an array, see above for details.)"
10693
raise ValueError(msg) from e

src/anndata/_types.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,13 @@
88

99
import numpy as np
1010
import pandas as pd
11-
from numpy.typing import NDArray
12-
from scipy import sparse
1311

1412
from anndata._core.anndata import AnnData
1513

16-
from ._core.sparse_dataset import BaseCompressedSparseDataset
14+
from ._core.storage import ArrayDataStructureType
1715
from .compat import (
18-
AwkArray,
19-
CupyArray,
20-
CupySparseMatrix,
21-
DaskArray,
2216
H5Array,
2317
H5Group,
24-
SpArray,
25-
ZappyArray,
2618
ZarrArray,
2719
ZarrGroup,
2820
)
@@ -34,6 +26,7 @@
3426
from anndata._io.specs.registry import DaskReader
3527

3628
from ._io.specs.registry import IOSpec, Reader, Writer
29+
from .compat import DaskArray
3730

3831
__all__ = [
3932
"ArrayStorageType",
@@ -42,21 +35,7 @@
4235
]
4336

4437
InMemoryArrayOrScalarType: TypeAlias = Union[
45-
NDArray,
46-
np.ma.MaskedArray,
47-
sparse.spmatrix,
48-
SpArray,
49-
H5Array,
50-
ZarrArray,
51-
ZappyArray,
52-
BaseCompressedSparseDataset,
53-
DaskArray,
54-
CupyArray,
55-
CupySparseMatrix,
56-
AwkArray,
57-
pd.DataFrame,
58-
np.number,
59-
str,
38+
pd.DataFrame, np.number, str, ArrayDataStructureType
6039
]
6140
RWAble: TypeAlias = Union[
6241
InMemoryArrayOrScalarType, dict[str, "RWAble"], list["RWAble"]

tests/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_creation():
3030
AnnData(np.array([[1, 2], [3, 4]]))
3131
AnnData(np.array([[1, 2], [3, 4]]), {}, {})
3232
AnnData(ma.array([[1, 2], [3, 4]]), uns=dict(mask=[0, 1, 1, 0]))
33-
AnnData(sp.eye(2))
33+
AnnData(sp.eye(2, format="csr"))
3434
if CAN_USE_SPARSE_ARRAY:
3535
AnnData(sp.eye_array(2))
3636
X = np.array([[1, 2, 3], [4, 5, 6]])
@@ -95,7 +95,7 @@ def test_creation_error(src, src_arg, dim_msg, dim, dim_arg, msg: str | None):
9595
def test_invalid_X():
9696
with pytest.raises(
9797
ValueError,
98-
match=r"X needs to be of one of np\.ndarray.*not <class 'str'>\.",
98+
match=r"X needs to be of one of <class 'numpy.ndarray'>.*not <class 'str'>\.",
9999
):
100100
AnnData("string is not a valid X")
101101

@@ -126,7 +126,7 @@ def test_error_create_from_multiindex_df(attr):
126126

127127

128128
def test_create_from_sparse_df():
129-
s = sp.random(20, 30, density=0.2)
129+
s = sp.random(20, 30, density=0.2, format="csr")
130130
obs_names = [f"obs{i}" for i in range(20)]
131131
var_names = [f"var{i}" for i in range(30)]
132132
df = pd.DataFrame.sparse.from_spmatrix(s, index=obs_names, columns=var_names)

tests/test_obsmvarm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,21 @@ def test_setting_dataframe(adata: AnnData):
8585

8686

8787
def test_setting_sparse(adata: AnnData):
88-
obsm_sparse = sparse.random(M, 100)
88+
obsm_sparse = sparse.random(M, 100, format="csr")
8989
adata.obsm["a"] = obsm_sparse
9090
assert not np.any((adata.obsm["a"] != obsm_sparse).data)
9191

92-
varm_sparse = sparse.random(N, 100)
92+
varm_sparse = sparse.random(N, 100, format="csr")
9393
adata.varm["a"] = varm_sparse
9494
assert not np.any((adata.varm["a"] != varm_sparse).data)
9595

9696
h = joblib.hash(adata)
9797

98-
bad_obsm_sparse = sparse.random(M * 2, M)
98+
bad_obsm_sparse = sparse.random(M * 2, M, format="csr")
9999
with pytest.raises(ValueError, match=r"incorrect shape"):
100100
adata.obsm["b"] = bad_obsm_sparse
101101

102-
bad_varm_sparse = sparse.random(N * 2, N)
102+
bad_varm_sparse = sparse.random(N * 2, N, format="csr")
103103
with pytest.raises(ValueError, match=r"incorrect shape"):
104104
adata.varm["b"] = bad_varm_sparse
105105

tests/test_obspvarp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,21 @@ def test_setting_ndarray(adata: AnnData):
6565

6666

6767
def test_setting_sparse(adata: AnnData):
68-
obsp_sparse = sparse.random(M, M)
68+
obsp_sparse = sparse.random(M, M, format="csr")
6969
adata.obsp["a"] = obsp_sparse
7070
assert not np.any((adata.obsp["a"] != obsp_sparse).data)
7171

72-
varp_sparse = sparse.random(N, N)
72+
varp_sparse = sparse.random(N, N, format="csr")
7373
adata.varp["a"] = varp_sparse
7474
assert not np.any((adata.varp["a"] != varp_sparse).data)
7575

7676
h = joblib.hash(adata)
7777

78-
bad_obsp_sparse = sparse.random(M * 2, M)
78+
bad_obsp_sparse = sparse.random(M * 2, M, format="csr")
7979
with pytest.raises(ValueError, match=r"incorrect shape"):
8080
adata.obsp["b"] = bad_obsp_sparse
8181

82-
bad_varp_sparse = sparse.random(N * 2, N)
82+
bad_varp_sparse = sparse.random(N * 2, N, format="csr")
8383
with pytest.raises(ValueError, match=r"incorrect shape"):
8484
adata.varp["b"] = bad_varp_sparse
8585

tests/test_x.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,12 @@ def test_set_dense_x_view_from_sparse():
182182
assert_equal(view.X, x1[:30])
183183
assert_equal(orig.X[:30], x1[:30]) # change propagates through
184184
assert_equal(orig.X[30:], x[30:]) # change propagates through
185+
186+
187+
def test_warn_on_non_csr_csc_matrix():
188+
X = sparse.eye(100)
189+
with pytest.warns(
190+
FutureWarning,
191+
match=rf"AnnData previously had undefined behavior around matrices of type {type(X)}.*",
192+
):
193+
ad.AnnData(X=X)

0 commit comments

Comments
 (0)