Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions releasenotes/notes/async-sleep-retrying-32de5866f5d041.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Passing an async ``sleep`` callable (e.g. ``trio.sleep``) to ``@retry``
now correctly uses ``AsyncRetrying``, even when the decorated function is
synchronous. Previously, the async sleep would silently not be awaited,
resulting in no delay between retries.
29 changes: 27 additions & 2 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@

WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
P = t.ParamSpec("P")
R = t.TypeVar("R")


@dataclasses.dataclass(slots=True)
Expand Down Expand Up @@ -589,7 +591,27 @@ def retry(func: WrappedFn) -> WrappedFn: ...

@t.overload
def retry(
sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep,
*,
sleep: t.Callable[[t.Union[int, float]], t.Awaitable[None]],
stop: "StopBaseT" = ...,
wait: "WaitBaseT" = ...,
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = ...,
before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = ...,
before_sleep: t.Optional[
t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]
] = ...,
reraise: bool = ...,
retry_error_cls: t.Type["RetryError"] = ...,
retry_error_callback: t.Optional[
t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]
] = ...,
) -> t.Callable[[t.Callable[P, R | t.Awaitable[R]]], t.Callable[P, t.Awaitable[R]]]: ...


@t.overload
def retry(
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
stop: "StopBaseT" = stop_never,
wait: "WaitBaseT" = wait_none(),
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(),
Expand Down Expand Up @@ -628,7 +650,10 @@ def wrap(f: WrappedFn) -> WrappedFn:
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
r: "BaseRetrying"
if _utils.is_coroutine_callable(f):
sleep = dkw.get("sleep")
if _utils.is_coroutine_callable(f) or (
sleep is not None and _utils.is_coroutine_callable(sleep)
):
r = AsyncRetrying(*dargs, **dkw)
elif (
tornado
Expand Down
6 changes: 5 additions & 1 deletion tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,15 @@ async def __call__( # type: ignore[override]
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
is_async = _utils.is_coroutine_callable(fn)
while True:
do = await self.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = await fn(*args, **kwargs)
if is_async:
result = await fn(*args, **kwargs)
else:
result = fn(*args, **kwargs)
except BaseException: # noqa: B902
retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]
else:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt
from tenacity.wait import wait_fixed

from .test_tenacity import NoIOErrorAfterCount, current_time_ms
from .test_tenacity import (
NoIOErrorAfterCount,
NoneReturnUntilAfterCount,
current_time_ms,
)


def asynctest(callable_):
Expand Down Expand Up @@ -463,5 +467,26 @@ async def foo():
pass


class TestSyncFunctionWithAsyncSleep(unittest.TestCase):
@asynctest
async def test_sync_function_with_async_sleep(self):
"""A sync function with an async sleep callable uses AsyncRetrying."""
mock_sleep = mock.AsyncMock()

thing = NoneReturnUntilAfterCount(2)

@retry(
sleep=mock_sleep,
wait=wait_fixed(1),
retry=retry_if_result(lambda x: x is None),
)
def sync_function():
return thing.go()

result = await sync_function()
assert result is True
assert mock_sleep.await_count == 2


if __name__ == "__main__":
unittest.main()