Skip to content

Commit 8bafd80

Browse files
authored
make array_type a session-scoped fixture (#91)
1 parent 630b0cb commit 8bafd80

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

src/testing/fast_array_utils/_array_type.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class Flags(enum.Flag):
7777
class ConversionContext:
7878
"""Conversion context required for h5py."""
7979

80-
hdf5_file: h5py.File
80+
hdf5_file: h5py.File # TODO(flying-sheep): ReadOnly <https://peps.python.org/pep-0767/>
8181

8282

8383
@dataclass(frozen=True)

src/testing/fast_array_utils/pytest.py

+40-9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from __future__ import annotations
88

99
import dataclasses
10+
import os
11+
import re
1012
from importlib.util import find_spec
1113
from typing import TYPE_CHECKING, cast
1214

@@ -93,8 +95,8 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator:
9395
]
9496

9597

96-
@pytest.fixture(params=SUPPORTED_TYPE_PARAMS)
97-
def array_type(request: pytest.FixtureRequest, tmp_path: Path) -> Generator[ArrayType, None, None]:
98+
@pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS)
99+
def array_type(request: pytest.FixtureRequest) -> ArrayType:
98100
"""Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`.
99101
100102
Use :class:`testing.fast_array_utils.Flags` to select or skip array types:
@@ -131,13 +133,42 @@ def test_something(array_type: ArrayType) -> None:
131133
...
132134
"""
133135
at = cast("ArrayType", request.param)
134-
f: h5py.File | None = None
135136
if at.cls is types.H5Dataset or (at.inner and at.inner.cls is types.H5Dataset):
137+
at = dataclasses.replace(at, conversion_context=CC(request))
138+
return at
139+
140+
141+
try: # get the exception type
142+
pytest.fail("x")
143+
except BaseException as e: # noqa: BLE001
144+
Failed = type(e)
145+
else:
146+
raise AssertionError
147+
148+
149+
class CC(ConversionContext):
150+
def __init__(self, request: pytest.FixtureRequest) -> None:
151+
self._request = request
152+
153+
@property # This is intentionally not cached and creates a new file on each access
154+
def hdf5_file(self) -> h5py.File: # type: ignore[override]
136155
import h5py
137156

138-
f = h5py.File(tmp_path / f"{request.fixturename}.h5", "w")
139-
ctx = ConversionContext(hdf5_file=f)
140-
at = dataclasses.replace(at, conversion_context=ctx)
141-
yield at
142-
if f:
143-
f.close()
157+
try: # If we’re being called in a test or function-scoped fixture, use the test `tmp_path`
158+
return cast("h5py.File", self._request.getfixturevalue("tmp_hdf5_file"))
159+
except Failed: # We’re being called from a session-scoped fixture or so
160+
factory = cast(
161+
"pytest.TempPathFactory", self._request.getfixturevalue("tmp_path_factory")
162+
)
163+
name = re.sub(r"[^\w_. -()\[\]{}]", "_", os.environ["PYTEST_CURRENT_TEST"])
164+
f = h5py.File(factory.mktemp(name) / "test.h5", "w")
165+
self._request.addfinalizer(f.close)
166+
return f
167+
168+
169+
@pytest.fixture
170+
def tmp_hdf5_file(tmp_path: Path) -> Generator[h5py.File, None, None]:
171+
import h5py
172+
173+
with h5py.File(tmp_path / "test.h5", "w") as f:
174+
yield f

tests/test_test_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from numpy.typing import DTypeLike, NDArray
2020
from scipy.sparse import coo_array, coo_matrix
2121

22-
from testing.fast_array_utils import ArrayType
22+
from testing.fast_array_utils import Array, ArrayType
2323

2424

2525
other_array_type = array_type
@@ -78,3 +78,14 @@ def test_array_types(array_type: ArrayType) -> None:
7878
assert any(
7979
getattr(t, "mod", None) in {"zarr", "h5py"} for t in (array_type, array_type.inner)
8080
) == bool(array_type.flags & Flags.Disk)
81+
82+
83+
@pytest.fixture(scope="session")
84+
def session_scoped_array(array_type: ArrayType) -> Array:
85+
return array_type(np.arange(12).reshape(3, 4), dtype=np.float32)
86+
87+
88+
def test_session_scoped_array(session_scoped_array: Array) -> None:
89+
"""Tests that creating a session-scoped array works."""
90+
assert session_scoped_array.shape == (3, 4)
91+
assert session_scoped_array.dtype == np.float32

typings/h5py.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class Dataset(HLObject):
1717
class Group(HLObject): ...
1818

1919
class File(Group, closing[File]): # not actually a subclass of closing
20+
filename: str
21+
mode: Literal["r", "r+"]
22+
libver: Literal["earliest", "latest", "v108", "v110"]
23+
2024
def __init__(
2125
self,
2226
name: AnyStr | os.PathLike[AnyStr] | IO[bytes],

0 commit comments

Comments
 (0)