Skip to content

Commit a8154bf

Browse files
authored
Add cap argument to testing mode (#80)
1 parent 0e1010a commit a8154bf

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

CHANGELOG.md

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/
1515

1616
## [Unreleased](https://github.com/hynek/stamina/compare/24.3.0...HEAD)
1717

18+
### Added
19+
20+
- *cap* argument to `stamina.set_testing()`.
21+
By default, the value passed as *attempts* is used strictly.
22+
When `cap=True`, it is used as an upper cap; that means that if the original attempts number is lower, it's not changed.
23+
[#80](https://github.com/hynek/stamina/pull/80)
24+
1825

1926
## [24.3.0](https://github.com/hynek/stamina/compare/24.2.0...24.3.0) - 2024-08-27
2027

src/stamina/_config.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,29 @@ class _Testing:
1919
Strictly private.
2020
"""
2121

22-
__slots__ = ("attempts",)
22+
__slots__ = ("attempts", "cap")
2323

2424
attempts: int
25+
cap: bool
2526

26-
def __init__(self, attempts: int) -> None:
27+
def __init__(self, attempts: int, cap: bool) -> None:
2728
self.attempts = attempts
29+
self.cap = cap
30+
31+
def get_attempts(self, non_testing_attempts: int | None) -> int:
32+
"""
33+
Get the number of attempts to use.
34+
35+
Args:
36+
non_testing_attempts: The number of attempts specified by the user.
37+
38+
Returns:
39+
The number of attempts to use.
40+
"""
41+
if self.cap:
42+
return min(self.attempts, non_testing_attempts or self.attempts)
43+
44+
return self.attempts
2845

2946

3047
class _Config:
@@ -137,14 +154,21 @@ def is_testing() -> bool:
137154
return CONFIG.testing is not None
138155

139156

140-
def set_testing(testing: bool, *, attempts: int = 1) -> None:
157+
def set_testing(
158+
testing: bool, *, attempts: int = 1, cap: bool = False
159+
) -> None:
141160
"""
142161
Activate or deactivate test mode.
143162
144163
In testing mode, backoffs are disabled, and attempts are set to *attempts*.
145164
165+
If *cap* is True, the number of attempts is not set but capped at
166+
*attempts*. This means that if *attempts* is greater than the number of
167+
attempts specified by the user, the user's value is used.
168+
146169
Is idempotent and can be called repeatedly with the same values.
147170
148171
.. versionadded:: 24.3.0
172+
.. versionadded:: 24.4.0 *cap*
149173
"""
150-
CONFIG.testing = _Testing(attempts) if testing else None
174+
CONFIG.testing = _Testing(attempts, cap) if testing else None

src/stamina/_core.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ class _RetryContextIterator:
403403
"_name",
404404
"_args",
405405
"_kw",
406+
"_attempts",
406407
"_wait_jitter",
407408
"_wait_initial",
408409
"_wait_max",
@@ -414,6 +415,7 @@ class _RetryContextIterator:
414415
_args: tuple[object, ...]
415416
_kw: dict[str, object]
416417

418+
_attempts: int | None
417419
_wait_jitter: float
418420
_wait_initial: float
419421
_wait_max: float
@@ -455,6 +457,7 @@ def from_params(
455457
_name=name,
456458
_args=args,
457459
_kw=kw,
460+
_attempts=attempts,
458461
_wait_jitter=wait_jitter,
459462
_wait_initial=wait_initial,
460463
_wait_max=wait_max,
@@ -494,7 +497,9 @@ def _apply_maybe_test_mode_to_tenacity_kw(
494497

495498
t_kw = self._t_kw.copy()
496499

497-
t_kw["stop"] = _t.stop_after_attempt(testing.attempts)
500+
t_kw["stop"] = _t.stop_after_attempt(
501+
testing.get_attempts(self._attempts)
502+
)
498503

499504
return t_kw
500505

tests/test_config.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from threading import Lock
66

77
from stamina import is_active, set_active
8-
from stamina._config import _Config
8+
from stamina._config import _Config, _Testing
99

1010

1111
def test_activate_deactivate():
@@ -38,3 +38,32 @@ def fake_on_retry(self):
3838

3939
assert (1, 2) == cfg._init_on_first_retry()
4040
assert fake_on_retry is cfg._get_on_retry
41+
42+
43+
class TestTesting:
44+
def test_cap_true(self):
45+
"""
46+
If cap is True, get_attempts returns the lower of the two values.
47+
"""
48+
t = _Testing(2, True)
49+
50+
assert 1 == t.get_attempts(1)
51+
assert 2 == t.get_attempts(3)
52+
53+
def test_cap_false(self):
54+
"""
55+
If cap is False, get_attempts always returns the testing value.
56+
"""
57+
t = _Testing(2, False)
58+
59+
assert 2 == t.get_attempts(1)
60+
assert 2 == t.get_attempts(3)
61+
62+
def test_cap_true_with_none(self):
63+
"""
64+
If cap is True and attempts is None, get_attempts returns the
65+
testing value.
66+
"""
67+
t = _Testing(100, True)
68+
69+
assert 100 == t.get_attempts(None)

0 commit comments

Comments
 (0)