1818if TYPE_CHECKING :
1919 from collections .abc import Callable , Generator
2020
21- from _pytest .nodes import Node
22- else :
23- Node = object
24-
2521
2622__all__ = ["array_type" , "conversion_context" ]
2723
@@ -32,6 +28,34 @@ def pytest_configure(config: pytest.Config) -> None:
3228 )
3329
3430
31+ def _resolve_sel (select : Flags = ~ Flags (0 ), skip : Flags = Flags (0 )) -> tuple [Flags , Flags ]:
32+ return select , skip
33+
34+
35+ def pytest_collection_modifyitems (
36+ session : pytest .Session , # noqa: ARG001
37+ config : pytest .Config , # noqa: ARG001
38+ items : list [pytest .Item ],
39+ ) -> None :
40+ """Filter tests using `pytest.mark.array_type` based on `testing.fast_array_utils.Flags`."""
41+ # reverse so we can .pop() items from the back without changing others’ index
42+ for i , item in reversed (list (enumerate (items ))):
43+ if not (
44+ isinstance (item , pytest .Function ) and (mark := item .get_closest_marker ("array_type" ))
45+ ):
46+ continue
47+
48+ msg = "Test function marked with `pytest.mark.array_type` must have `array_type` parameter"
49+ if not (at := item .callspec .params .get ("array_type" )):
50+ raise TypeError (msg )
51+ if not isinstance (at , ArrayType ):
52+ msg = f"{ msg } of type { ArrayType .__name__ } , got { type (at ).__name__ } "
53+ raise TypeError (msg )
54+ select , skip = _resolve_sel (* mark .args , ** mark .kwargs )
55+ if not (at .flags & select ) or (at .flags & skip ):
56+ del items [i ]
57+
58+
3559def _skip_if_unimportable (array_type : ArrayType ) -> pytest .MarkDecorator :
3660 dist = None
3761 skip = False
@@ -41,12 +65,6 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator:
4165 return pytest .mark .skipif (skip , reason = f"{ dist } not installed" )
4266
4367
44- def _resolve_sel (
45- select : Flags = ~ Flags (0 ), skip : Flags = Flags (0 ), * , reason : str | None = None
46- ) -> tuple [Flags , Flags , str | None ]:
47- return select , skip , reason
48-
49-
5068@pytest .fixture (
5169 params = [pytest .param (t , id = str (t ), marks = _skip_if_unimportable (t )) for t in SUPPORTED_TYPES ],
5270)
@@ -59,7 +77,7 @@ def array_type(request: pytest.FixtureRequest) -> ArrayType:
5977
6078 .. code:: python
6179
62- @pytest.mark.array_type(Flags.Sparse, reason="`something` only supports sparse arrays" )
80+ @pytest.mark.array_type(Flags.Sparse)
6381 def test_something(array_type: ArrayType) -> None:
6482 ...
6583
@@ -74,17 +92,9 @@ def test_something(array_type: ArrayType) -> None:
7492 from fast_array_utils .types import H5Dataset
7593
7694 at = cast (ArrayType , request .param )
77-
78- mark = cast (Node , request .node ).get_closest_marker ("array_type" )
79- if mark :
80- select , skip , reason = _resolve_sel (* mark .args , ** mark .kwargs )
81- if not (at .flags & select ) or (at .flags & skip ):
82- pytest .skip (reason or f"{ at } not included in { select = } , { skip = } " )
83-
8495 if at .cls is H5Dataset :
8596 ctx = request .getfixturevalue ("conversion_context" )
8697 at = dataclasses .replace (at , conversion_context = ctx )
87-
8898 return at
8999
90100
@@ -101,7 +111,7 @@ def conversion_context(
101111 """
102112 import h5py
103113
104- node = cast (Node , request .node )
114+ node = cast (pytest . Item , request .node )
105115 tmp_path = tmp_path_factory .mktemp ("backed_adata" )
106116 tmp_path = tmp_path / f"test_{ node .name } _{ worker_id } .h5ad"
107117
@@ -122,7 +132,7 @@ def viz(obj: object) -> None:
122132 else :
123133 from dask import visualize
124134
125- path = cache .mkdir ("dask-viz" ) / cast (Node , request .node ).name
135+ path = cache .mkdir ("dask-viz" ) / cast (pytest . Item , request .node ).name
126136 visualize (obj , filename = str (path ), engine = "ipycytoscape" )
127137
128138 return viz
0 commit comments