diff --git a/backoff/_common.py b/backoff/_common.py index 2b2e54e..e867ed6 100644 --- a/backoff/_common.py +++ b/backoff/_common.py @@ -32,12 +32,18 @@ def _init_wait_gen(wait_gen, wait_gen_kwargs): def _next_wait(wait, send_value, jitter, elapsed, max_time): + remaining_time = None + if max_time is not None: + remaining_time = max_time - elapsed + if remaining_time <= 0: # we equal/exceed time limit + return 0 + value = wait.send(send_value) + if remaining_time and value >= remaining_time: + return remaining_time + try: - if jitter is not None: - seconds = jitter(value) - else: - seconds = value + seconds = jitter(value) if jitter is not None else value except TypeError: warnings.warn( "Nullary jitter function signature is deprecated. Use " @@ -46,12 +52,11 @@ def _next_wait(wait, send_value, jitter, elapsed, max_time): DeprecationWarning, stacklevel=2, ) - seconds = value + jitter() - # don't sleep longer than remaining allotted max_time - if max_time is not None: - seconds = min(seconds, max_time - elapsed) + # adding jitter may push value over max_limit + if remaining_time and seconds >= remaining_time: + return remaining_time return seconds diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..32dec7b --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,71 @@ +# coding:utf-8 +import unittest.mock + +from backoff._common import _next_wait + + +def test_next_wait_trunc_wait_fn(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + # 9 + 2 > 10 + assert _next_wait(wait_mock, 9, None, 0, 10) == 10 + wait_mock.send.assert_called_once_with(9) + + +def test_next_wait_trunc_wait_fn_elapsed(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + # 4 + 2 > 10 - 5 + assert _next_wait(wait_mock, 4, None, 5, 10) == 5 + wait_mock.send.assert_called_once_with(4) + + +def test_next_wait_elapsed_wait(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + assert _next_wait(wait_mock, 0, None, 10, 10) == 0 + wait_mock.send.assert_not_called() + assert _next_wait(wait_mock, 0, None, 11, 10) == 0 + wait_mock.send.assert_not_called() + + +def test_next_wait_jitter_over(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + jitter_fn_mock = unittest.mock.Mock() + jitter_fn_mock.side_effect = lambda x: x + 2 + + # 8 + 2 + 1 > 10 + assert _next_wait(wait_mock, 7, jitter_fn_mock, 0, 10) == 10 + wait_mock.send.assert_called_once_with(7) + jitter_fn_mock.assert_called_once_with(9) + + +def test_next_wait_jitter_skipped(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + jitter_fn_mock = unittest.mock.Mock() + jitter_fn_mock.side_effect = lambda x: x - 1 + + # 8 + 2 == 10 + assert _next_wait(wait_mock, 8, jitter_fn_mock, 0, 10) == 10 + wait_mock.send.assert_called_once_with(8) + jitter_fn_mock.assert_not_called() + + +def test_next_wait_jitter_under(): + wait_mock = unittest.mock.Mock() + wait_mock.send.side_effect = lambda x: x + 2 + + jitter_fn_mock = unittest.mock.Mock() + jitter_fn_mock.side_effect = lambda x: x - 1 + + # 7 + 2 - 1 == 8 + assert _next_wait(wait_mock, 7, jitter_fn_mock, 0, 10) == 8 + wait_mock.send.assert_called_once_with(7) + jitter_fn_mock.assert_called_once_with(9)