Skip to content

Commit 1e202fe

Browse files
committed
support async sleep for sync fn
1 parent cb2ce95 commit 1e202fe

File tree

4 files changed

+69
-5
lines changed

4 files changed

+69
-5
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
fixes:
3+
- |
4+
Passing an async ``sleep`` callable (e.g. ``trio.sleep``) to ``@retry``
5+
now correctly uses ``AsyncRetrying``, even when the decorated function is
6+
synchronous. Previously, the async sleep would silently not be awaited,
7+
resulting in no delay between retries.

tenacity/__init__.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@
8989
if t.TYPE_CHECKING:
9090
import types
9191

92-
from typing_extensions import Self
92+
from typing_extensions import ParamSpec, Self
9393

9494
from . import asyncio as tasyncio
9595
from .retry import RetryBaseT
9696
from .stop import StopBaseT
9797
from .wait import WaitBaseT
9898

99+
P = ParamSpec("P")
100+
R = t.TypeVar("R")
101+
99102

100103
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
101104
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
@@ -589,7 +592,29 @@ def retry(func: WrappedFn) -> WrappedFn: ...
589592

590593
@t.overload
591594
def retry(
592-
sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep,
595+
*,
596+
sleep: t.Callable[[t.Union[int, float]], t.Awaitable[None]],
597+
stop: "StopBaseT" = ...,
598+
wait: "WaitBaseT" = ...,
599+
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = ...,
600+
before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
601+
after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
602+
before_sleep: t.Optional[
603+
t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]
604+
] = ...,
605+
reraise: bool = ...,
606+
retry_error_cls: t.Type["RetryError"] = ...,
607+
retry_error_callback: t.Optional[
608+
t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]
609+
] = ...,
610+
) -> (
611+
"t.Callable[[t.Callable[P, R | t.Awaitable[R]]], t.Callable[P, t.Awaitable[R]]]"
612+
): ...
613+
614+
615+
@t.overload
616+
def retry(
617+
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
593618
stop: "StopBaseT" = stop_never,
594619
wait: "WaitBaseT" = wait_none(),
595620
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(),
@@ -628,7 +653,10 @@ def wrap(f: WrappedFn) -> WrappedFn:
628653
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
629654
)
630655
r: "BaseRetrying"
631-
if _utils.is_coroutine_callable(f):
656+
sleep = dkw.get("sleep")
657+
if _utils.is_coroutine_callable(f) or (
658+
sleep is not None and _utils.is_coroutine_callable(sleep)
659+
):
632660
r = AsyncRetrying(*dargs, **dkw)
633661
elif (
634662
tornado

tenacity/asyncio/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,15 @@ async def __call__( # type: ignore[override]
107107
self.begin()
108108

109109
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
110+
is_async = _utils.is_coroutine_callable(fn)
110111
while True:
111112
do = await self.iter(retry_state=retry_state)
112113
if isinstance(do, DoAttempt):
113114
try:
114-
result = await fn(*args, **kwargs)
115+
if is_async:
116+
result = await fn(*args, **kwargs)
117+
else:
118+
result = fn(*args, **kwargs)
115119
except BaseException: # noqa: B902
116120
retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
117121
else:

tests/test_asyncio.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt
3535
from tenacity.wait import wait_fixed
3636

37-
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
37+
from .test_tenacity import (
38+
NoIOErrorAfterCount,
39+
NoneReturnUntilAfterCount,
40+
current_time_ms,
41+
)
3842

3943

4044
def asynctest(callable_):
@@ -463,5 +467,26 @@ async def foo():
463467
pass
464468

465469

470+
class TestSyncFunctionWithAsyncSleep(unittest.TestCase):
471+
@asynctest
472+
async def test_sync_function_with_async_sleep(self):
473+
"""A sync function with an async sleep callable uses AsyncRetrying."""
474+
mock_sleep = mock.AsyncMock()
475+
476+
thing = NoneReturnUntilAfterCount(2)
477+
478+
@retry(
479+
sleep=mock_sleep,
480+
wait=wait_fixed(1),
481+
retry=retry_if_result(lambda x: x is None),
482+
)
483+
def sync_function():
484+
return thing.go()
485+
486+
result = await sync_function()
487+
assert result is True
488+
assert mock_sleep.await_count == 2
489+
490+
466491
if __name__ == "__main__":
467492
unittest.main()

0 commit comments

Comments
 (0)