Skip to content

Commit df17fa7

Browse files
authored
Modify collection instead of skipping (#39)
1 parent a7e216b commit df17fa7

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

src/testing/fast_array_utils/pytest.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
if TYPE_CHECKING:
1919
from collections.abc import Callable, Generator
2020

21-
from _pytest.nodes import Node
22-
else:
23-
Node = object
24-
2521

2622
__all__ = ["array_type", "conversion_context"]
2723

@@ -32,6 +28,34 @@ def pytest_configure(config: pytest.Config) -> None:
3228
)
3329

3430

31+
def _resolve_sel(select: Flags = ~Flags(0), skip: Flags = Flags(0)) -> tuple[Flags, Flags]:
32+
return select, skip
33+
34+
35+
def pytest_collection_modifyitems(
36+
session: pytest.Session, # noqa: ARG001
37+
config: pytest.Config, # noqa: ARG001
38+
items: list[pytest.Item],
39+
) -> None:
40+
"""Filter tests using `pytest.mark.array_type` based on `testing.fast_array_utils.Flags`."""
41+
# reverse so we can .pop() items from the back without changing others’ index
42+
for i, item in reversed(list(enumerate(items))):
43+
if not (
44+
isinstance(item, pytest.Function) and (mark := item.get_closest_marker("array_type"))
45+
):
46+
continue
47+
48+
msg = "Test function marked with `pytest.mark.array_type` must have `array_type` parameter"
49+
if not (at := item.callspec.params.get("array_type")):
50+
raise TypeError(msg)
51+
if not isinstance(at, ArrayType):
52+
msg = f"{msg} of type {ArrayType.__name__}, got {type(at).__name__}"
53+
raise TypeError(msg)
54+
select, skip = _resolve_sel(*mark.args, **mark.kwargs)
55+
if not (at.flags & select) or (at.flags & skip):
56+
del items[i]
57+
58+
3559
def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator:
3660
dist = None
3761
skip = False
@@ -41,12 +65,6 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator:
4165
return pytest.mark.skipif(skip, reason=f"{dist} not installed")
4266

4367

44-
def _resolve_sel(
45-
select: Flags = ~Flags(0), skip: Flags = Flags(0), *, reason: str | None = None
46-
) -> tuple[Flags, Flags, str | None]:
47-
return select, skip, reason
48-
49-
5068
@pytest.fixture(
5169
params=[pytest.param(t, id=str(t), marks=_skip_if_unimportable(t)) for t in SUPPORTED_TYPES],
5270
)
@@ -59,7 +77,7 @@ def array_type(request: pytest.FixtureRequest) -> ArrayType:
5977
6078
.. code:: python
6179
62-
@pytest.mark.array_type(Flags.Sparse, reason="`something` only supports sparse arrays")
80+
@pytest.mark.array_type(Flags.Sparse)
6381
def test_something(array_type: ArrayType) -> None:
6482
...
6583
@@ -74,17 +92,9 @@ def test_something(array_type: ArrayType) -> None:
7492
from fast_array_utils.types import H5Dataset
7593

7694
at = cast(ArrayType, request.param)
77-
78-
mark = cast(Node, request.node).get_closest_marker("array_type")
79-
if mark:
80-
select, skip, reason = _resolve_sel(*mark.args, **mark.kwargs)
81-
if not (at.flags & select) or (at.flags & skip):
82-
pytest.skip(reason or f"{at} not included in {select=}, {skip=}")
83-
8495
if at.cls is H5Dataset:
8596
ctx = request.getfixturevalue("conversion_context")
8697
at = dataclasses.replace(at, conversion_context=ctx)
87-
8898
return at
8999

90100

@@ -101,7 +111,7 @@ def conversion_context(
101111
"""
102112
import h5py
103113

104-
node = cast(Node, request.node)
114+
node = cast(pytest.Item, request.node)
105115
tmp_path = tmp_path_factory.mktemp("backed_adata")
106116
tmp_path = tmp_path / f"test_{node.name}_{worker_id}.h5ad"
107117

@@ -122,7 +132,7 @@ def viz(obj: object) -> None:
122132
else:
123133
from dask import visualize
124134

125-
path = cache.mkdir("dask-viz") / cast(Node, request.node).name
135+
path = cache.mkdir("dask-viz") / cast(pytest.Item, request.node).name
126136
visualize(obj, filename=str(path), engine="ipycytoscape")
127137

128138
return viz

0 commit comments

Comments
 (0)