Skip to content

Commit a66b921

Browse files
committed
support async sleep for sync fn
1 parent 7027da3 commit a66b921

File tree

4 files changed

+68
-5
lines changed

4 files changed

+68
-5
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
fixes:
3+
- |
4+
Passing an async ``sleep`` callable (e.g. ``trio.sleep``) to ``@retry``
5+
now correctly uses ``AsyncRetrying``, even when the decorated function is
6+
synchronous. Previously, the async sleep would silently not be awaited,
7+
resulting in no delay between retries.

tenacity/__init__.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@
8989
if t.TYPE_CHECKING:
9090
import types
9191

92-
from typing_extensions import Self
92+
from typing_extensions import ParamSpec, Self
9393

9494
from . import asyncio as tasyncio
9595
from .retry import RetryBaseT
9696
from .stop import StopBaseT
9797
from .wait import WaitBaseT
9898

99+
P = ParamSpec("P")
100+
R = t.TypeVar("R")
101+
99102

100103
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
101104
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
@@ -600,7 +603,29 @@ def retry(func: WrappedFn) -> WrappedFn: ...
600603

601604
@t.overload
602605
def retry(
603-
sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep,
606+
*,
607+
sleep: t.Callable[[t.Union[int, float]], t.Awaitable[None]],
608+
stop: "StopBaseT" = ...,
609+
wait: "WaitBaseT" = ...,
610+
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = ...,
611+
before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
612+
after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
613+
before_sleep: t.Optional[
614+
t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]
615+
] = ...,
616+
reraise: bool = ...,
617+
retry_error_cls: t.Type["RetryError"] = ...,
618+
retry_error_callback: t.Optional[
619+
t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]
620+
] = ...,
621+
) -> (
622+
"t.Callable[[t.Callable[P, R | t.Awaitable[R]]], t.Callable[P, t.Awaitable[R]]]"
623+
): ...
624+
625+
626+
@t.overload
627+
def retry(
628+
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
604629
stop: "StopBaseT" = stop_never,
605630
wait: "WaitBaseT" = wait_none(),
606631
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(),
@@ -639,7 +664,9 @@ def wrap(f: WrappedFn) -> WrappedFn:
639664
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
640665
)
641666
r: "BaseRetrying"
642-
if _utils.is_coroutine_callable(f):
667+
if _utils.is_coroutine_callable(f) or _utils.is_coroutine_callable(
668+
dkw.get("sleep")
669+
):
643670
r = AsyncRetrying(*dargs, **dkw)
644671
elif (
645672
tornado

tenacity/asyncio/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,15 @@ async def __call__( # type: ignore[override]
107107
self.begin()
108108

109109
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
110+
is_async = _utils.is_coroutine_callable(fn)
110111
while True:
111112
do = await self.iter(retry_state=retry_state)
112113
if isinstance(do, DoAttempt):
113114
try:
114-
result = await fn(*args, **kwargs)
115+
if is_async:
116+
result = await fn(*args, **kwargs)
117+
else:
118+
result = fn(*args, **kwargs)
115119
except BaseException: # noqa: B902
116120
retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
117121
else:

tests/test_asyncio.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt
3535
from tenacity.wait import wait_fixed
3636

37-
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
37+
from .test_tenacity import (
38+
NoIOErrorAfterCount,
39+
NoneReturnUntilAfterCount,
40+
current_time_ms,
41+
)
3842

3943

4044
def asynctest(callable_):
@@ -463,5 +467,26 @@ async def foo():
463467
pass
464468

465469

470+
class TestSyncFunctionWithAsyncSleep(unittest.TestCase):
471+
@asynctest
472+
async def test_sync_function_with_async_sleep(self):
473+
"""A sync function with an async sleep callable uses AsyncRetrying."""
474+
mock_sleep = mock.AsyncMock()
475+
476+
thing = NoneReturnUntilAfterCount(2)
477+
478+
@retry(
479+
sleep=mock_sleep,
480+
wait=wait_fixed(1),
481+
retry=retry_if_result(lambda x: x is None),
482+
)
483+
def sync_function():
484+
return thing.go()
485+
486+
result = await sync_function()
487+
assert result is True
488+
assert mock_sleep.await_count == 2
489+
490+
466491
if __name__ == "__main__":
467492
unittest.main()

0 commit comments

Comments
 (0)