Skip to content

Commit 51b8072

Browse files
authored
Allow set_testing to be used as a context manager (#94)
* Allow set_testing to be used as a context manager implements #85 * Add PR link
1 parent 5b51f2c commit 51b8072

File tree

5 files changed

+116
-3
lines changed

5 files changed

+116
-3
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/
2222
When `cap=True`, it is used as an upper cap; that means that if the original attempts number is lower, it's not changed.
2323
[#80](https://github.com/hynek/stamina/pull/80)
2424

25+
- `stamina.set_testing()` can now be used as a context manager.
26+
[#94](https://github.com/hynek/stamina/pull/94)
27+
2528

2629
## [24.3.0](https://github.com/hynek/stamina/compare/24.2.0...24.3.0) - 2024-08-27
2730

src/stamina/_config.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
from threading import Lock
8+
from types import TracebackType
89
from typing import Callable
910

1011
from .instrumentation import RetryHookFactory
@@ -154,9 +155,25 @@ def is_testing() -> bool:
154155
return CONFIG.testing is not None
155156

156157

158+
class _RestoreTestingCM:
159+
def __init__(self, old: _Testing | None) -> None:
160+
self.old = old
161+
162+
def __enter__(self) -> None:
163+
pass
164+
165+
def __exit__(
166+
self,
167+
exc_type: type[BaseException] | None,
168+
exc_val: BaseException | None,
169+
exc_tb: TracebackType | None,
170+
) -> None:
171+
CONFIG.testing = self.old
172+
173+
157174
def set_testing(
158175
testing: bool, *, attempts: int = 1, cap: bool = False
159-
) -> None:
176+
) -> _RestoreTestingCM:
160177
"""
161178
Activate or deactivate test mode.
162179
@@ -170,5 +187,9 @@ def set_testing(
170187
171188
.. versionadded:: 24.3.0
172189
.. versionadded:: 25.1.0 *cap*
190+
.. versionadded:: 25.1.0 Can be used as a context manager.
173191
"""
192+
old = CONFIG.testing
174193
CONFIG.testing = _Testing(attempts, cap) if testing else None
194+
195+
return _RestoreTestingCM(old)

tests/test_async.py

+21
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,24 @@ def test_repr(self):
306306
assert f"<BoundAsyncRetryingCaller(ValueError, {r})>" == repr(
307307
arc.on(ValueError)
308308
)
309+
310+
311+
async def test_testing_mode_context():
312+
"""
313+
Testing mode context manager works with async code.
314+
"""
315+
assert not stamina.is_testing()
316+
317+
with stamina.set_testing(True, attempts=3):
318+
assert stamina.is_testing()
319+
320+
with pytest.raises(ValueError): # noqa: PT012
321+
async for attempt in stamina.retry_context(on=ValueError):
322+
assert 0.0 == attempt.next_wait
323+
324+
with attempt:
325+
raise ValueError
326+
327+
assert 3 == attempt.num
328+
329+
assert not stamina.is_testing()

tests/test_config.py

+49-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
from contextlib import suppress
56
from threading import Lock
67

7-
from stamina import is_active, set_active
8-
from stamina._config import _Config, _Testing
8+
from stamina import is_active, is_testing, set_active, set_testing
9+
from stamina._config import CONFIG, _Config, _Testing
910

1011

1112
def test_activate_deactivate():
@@ -67,3 +68,49 @@ def test_cap_true_with_none(self):
6768
t = _Testing(100, True)
6869

6970
assert 100 == t.get_attempts(None)
71+
72+
def test_context_manager(self):
73+
"""
74+
set_testing works as a context manager.
75+
"""
76+
assert not is_testing()
77+
78+
with set_testing(True, attempts=3):
79+
assert is_testing()
80+
assert 3 == CONFIG.testing.get_attempts(None)
81+
assert not CONFIG.testing.cap
82+
83+
assert not is_testing()
84+
85+
def test_context_manager_nested(self):
86+
"""
87+
set_testing context managers can be nested.
88+
"""
89+
assert not is_testing()
90+
91+
with set_testing(True, attempts=3):
92+
assert is_testing()
93+
assert CONFIG.testing.attempts == 3
94+
95+
with set_testing(True, attempts=5, cap=True):
96+
assert is_testing()
97+
assert CONFIG.testing.attempts == 5
98+
assert CONFIG.testing.cap
99+
100+
assert is_testing()
101+
assert CONFIG.testing.attempts == 3
102+
assert not CONFIG.testing.cap
103+
104+
assert not is_testing()
105+
106+
def test_context_manager_exception(self):
107+
"""
108+
set_testing context manager restores state even if an exception occurs.
109+
"""
110+
assert not is_testing()
111+
112+
with suppress(ValueError), set_testing(True, attempts=3):
113+
assert is_testing()
114+
raise ValueError("test")
115+
116+
assert not is_testing()

tests/test_sync.py

+21
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,24 @@ def test_repr(self):
290290
assert f"<BoundRetryingCaller(ValueError, {r})>" == repr(
291291
rc.on(ValueError)
292292
)
293+
294+
295+
def test_testing_mode_context():
296+
"""
297+
Testing mode context manager works with sync code.
298+
"""
299+
assert not stamina.is_testing()
300+
301+
with stamina.set_testing(True, attempts=3):
302+
assert stamina.is_testing()
303+
304+
with pytest.raises(ValueError): # noqa: PT012
305+
for attempt in stamina.retry_context(on=ValueError):
306+
assert 0.0 == attempt.next_wait
307+
308+
with attempt:
309+
raise ValueError
310+
311+
assert 3 == attempt.num
312+
313+
assert not stamina.is_testing()

0 commit comments

Comments
 (0)