Skip to content

Commit 7d24f88

Browse files
committed
(chore): ensure views of anndata produce view classes
1 parent 0bc2b39 commit 7d24f88

4 files changed

Lines changed: 56 additions & 15 deletions

File tree

docs/tutorials/notebooks

src/anndata/_core/storage.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from ..compat import (
1313
AwkArray,
1414
CupyArray,
15+
CupyCSCMatrix,
16+
CupyCSRMatrix,
1517
CupySparseMatrix,
1618
DaskArray,
1719
H5Array,
@@ -43,6 +45,8 @@
4345
CSCDataset,
4446
DaskArray,
4547
CupyArray,
48+
CupyCSCMatrix,
49+
CupyCSRMatrix,
4650
CupySparseMatrix,
4751
]
4852

src/anndata/_core/views.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -294,64 +294,64 @@ def as_view(obj, view_args):
294294

295295

296296
@as_view.register(np.ndarray)
297-
def as_view_array(array, view_args):
297+
def as_view_array(array, view_args) -> ArrayView:
298298
return ArrayView(array, view_args=view_args)
299299

300300

301301
@as_view.register(DaskArray)
302-
def as_view_dask_array(array, view_args):
302+
def as_view_dask_array(array, view_args) -> DaskArrayView:
303303
return DaskArrayView(array, view_args=view_args)
304304

305305

306306
@as_view.register(pd.DataFrame)
307-
def as_view_df(df, view_args):
307+
def as_view_df(df, view_args) -> DataFrameView:
308308
return DataFrameView(df, view_args=view_args)
309309

310310

311311
@as_view.register(sparse.csr_matrix)
312-
def as_view_csr_matrix(mtx, view_args):
312+
def as_view_csr_matrix(mtx, view_args) -> SparseCSRMatrixView:
313313
return SparseCSRMatrixView(mtx, view_args=view_args)
314314

315315

316316
@as_view.register(sparse.csc_matrix)
317-
def as_view_csc_matrix(mtx, view_args):
317+
def as_view_csc_matrix(mtx, view_args) -> SparseCSCMatrixView:
318318
return SparseCSCMatrixView(mtx, view_args=view_args)
319319

320320

321321
@as_view.register(sparse.csr_array)
322-
def as_view_csr_array(mtx, view_args):
322+
def as_view_csr_array(mtx, view_args) -> SparseCSRArrayView:
323323
return SparseCSRArrayView(mtx, view_args=view_args)
324324

325325

326326
@as_view.register(sparse.csc_array)
327-
def as_view_csc_array(mtx, view_args):
327+
def as_view_csc_array(mtx, view_args) -> SparseCSCArrayView:
328328
return SparseCSCArrayView(mtx, view_args=view_args)
329329

330330

331331
@as_view.register(dict)
332-
def as_view_dict(d, view_args):
332+
def as_view_dict(d, view_args) -> DictView:
333333
return DictView(d, view_args=view_args)
334334

335335

336336
@as_view.register(ZappyArray)
337-
def as_view_zappy(z, view_args):
337+
def as_view_zappy(z, view_args) -> ZappyArray:
338338
# Previous code says ZappyArray works as view,
339339
# but as far as I can tell they’re immutable.
340340
return z
341341

342342

343343
@as_view.register(CupyArray)
344-
def as_view_cupy(array, view_args):
344+
def as_view_cupy(array, view_args) -> CupyArrayView:
345345
return CupyArrayView(array, view_args=view_args)
346346

347347

348348
@as_view.register(CupyCSRMatrix)
349-
def as_view_cupy_csr(mtx, view_args):
349+
def as_view_cupy_csr(mtx, view_args) -> CupySparseCSRView:
350350
return CupySparseCSRView(mtx, view_args=view_args)
351351

352352

353353
@as_view.register(CupyCSCMatrix)
354-
def as_view_cupy_csc(mtx, view_args):
354+
def as_view_cupy_csc(mtx, view_args) -> CupySparseCSCView:
355355
return CupySparseCSCView(mtx, view_args=view_args)
356356

357357

@@ -373,7 +373,7 @@ def _view_args(self):
373373
to be attached as "behavior". These "behaviors" cannot take any additional parameters (as we do
374374
for other data types to store `_view_args`). Therefore, we need to store `_view_args` using awkward's
375375
parameter mechanism. These parameters need to be json-serializable, which is why we can't store
376-
ElementRef directly, but need to replace the reference to the parent AnnDataView container with a weak
376+
ElementRef directly, but need to replace the reference to the parent AnnData container with a weak
377377
reference.
378378
"""
379379
parent_key, attrname, keys = self.layout.parameter(_PARAM_NAME)
@@ -394,7 +394,7 @@ def __copy__(self) -> AwkArray:
394394
return array
395395

396396
@as_view.register(AwkArray)
397-
def as_view_awkarray(array, view_args):
397+
def as_view_awkarray(array, view_args) -> AwkwardArrayView:
398398
parent, attrname, keys = view_args
399399
parent_key = f"target-{id(parent)}"
400400
_registry[parent_key] = parent

tests/test_views.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import ExitStack
44
from copy import deepcopy
55
from operator import mul
6+
from typing import get_type_hints
67

78
import joblib
89
import numpy as np
@@ -786,6 +787,42 @@ def test_dataframe_view_index_setting():
786787
assert a2.obs.index.values.tolist() == ["a", "b"]
787788

788789

790+
def test_elem_view_class():
791+
"""
792+
Ensure that:
793+
794+
(a) AnnData views actually produce view classes
795+
(b) Produced view classes are subtypes of their original type
796+
which then allows distinguishing views from non-views.
797+
798+
This test tries to then guarantee that `my_adata.is_view and isinstance(my_adata.obsm['my_array'], BaseArrayClass)`
799+
tells a user that they are working with a view class of `obsm['my_array']` that inherits from the base class (and has its methods).
800+
"""
801+
orig = gen_adata((10, 10))
802+
subset = orig[:8, :8]
803+
assert subset.is_view
804+
registry = ad._core.views.as_view.registry
805+
as_view_funcs = registry.values()
806+
base_classes = registry.keys()
807+
# Use set membership to ensure the *actual* view class is used
808+
view_types = set(
809+
get_type_hints(func)["return"]
810+
for func in as_view_funcs
811+
if "return" in func.__annotations__
812+
)
813+
base_types = tuple(base_classes)
814+
assert type(subset.obs) in view_types
815+
assert type(subset.var) in view_types
816+
for view_data in (
817+
*subset.obsm.values(),
818+
*subset.layers.values(),
819+
*subset.obsp.values(),
820+
):
821+
view_data_type = type(view_data)
822+
assert view_data_type in view_types
823+
assert isinstance(view_data, base_types)
824+
825+
789826
# @pytest.mark.parametrize("dim", ["obs", "var"])
790827
# @pytest.mark.parametrize(
791828
# ("idx", "pat"),

0 commit comments

Comments
 (0)