Skip to content

Commit ca58e02

Browse files
Add in_ and at helpers to allow for directly requesting a retry (#130)
Closes #129
1 parent 0244315 commit ca58e02

File tree

2 files changed

+142
-5
lines changed

2 files changed

+142
-5
lines changed

src/docket/dependencies.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
from contextlib import AsyncExitStack, asynccontextmanager
55
from contextvars import ContextVar
6-
from datetime import timedelta
6+
from datetime import datetime, timedelta, timezone
77
from types import TracebackType
88
from typing import (
99
TYPE_CHECKING,
@@ -14,6 +14,7 @@
1414
Callable,
1515
Counter,
1616
Generic,
17+
NoReturn,
1718
TypeVar,
1819
cast,
1920
)
@@ -188,6 +189,10 @@ async def my_task(logger: LoggerAdapter[Logger] = TaskLogger()) -> None:
188189
return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())
189190

190191

192+
class ForcedRetry(Exception):
193+
"""Raised when a task requests a retry via `in_` or `at`"""
194+
195+
191196
class Retry(Dependency):
192197
"""Configures linear retries for a task. You can specify the total number of
193198
attempts (or `None` to retry indefinitely), and the delay between attempts.
@@ -222,6 +227,17 @@ async def __aenter__(self) -> "Retry":
222227
retry.attempt = execution.attempt
223228
return retry
224229

230+
def at(self, when: datetime) -> NoReturn:
231+
now = datetime.now(timezone.utc)
232+
diff = when - now
233+
diff = diff if diff.total_seconds() >= 0 else timedelta(0)
234+
235+
self.in_(diff)
236+
237+
def in_(self, when: timedelta) -> NoReturn:
238+
self.delay: timedelta = when
239+
raise ForcedRetry()
240+
225241

226242
class ExponentialRetry(Retry):
227243
"""Configures exponential retries for a task. You can specify the total number
@@ -251,22 +267,21 @@ def __init__(
251267
maximum_delay: The maximum delay between attempts.
252268
"""
253269
super().__init__(attempts=attempts, delay=minimum_delay)
254-
self.minimum_delay = minimum_delay
255270
self.maximum_delay = maximum_delay
256271

257272
async def __aenter__(self) -> "ExponentialRetry":
258273
execution = self.execution.get()
259274

260275
retry = ExponentialRetry(
261276
attempts=self.attempts,
262-
minimum_delay=self.minimum_delay,
277+
minimum_delay=self.delay,
263278
maximum_delay=self.maximum_delay,
264279
)
265280
retry.attempt = execution.attempt
266281

267282
if execution.attempt > 1:
268283
backoff_factor = 2 ** (execution.attempt - 1)
269-
calculated_delay = self.minimum_delay * backoff_factor
284+
calculated_delay = self.delay * backoff_factor
270285

271286
if calculated_delay > self.maximum_delay:
272287
retry.delay = self.maximum_delay

tests/test_dependencies.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
2+
from datetime import datetime, timedelta, timezone
23

34
import pytest
45

56
from docket import CurrentDocket, CurrentWorker, Docket, Worker
6-
from docket.dependencies import Depends, Retry, TaskArgument
7+
from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument
78

89

910
async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker):
@@ -95,6 +96,127 @@ async def the_task(
9596
assert calls == 2
9697

9798

99+
@pytest.mark.parametrize("retry_cls", [Retry, ExponentialRetry])
100+
async def test_user_can_request_a_retry_in_timedelta_time(
101+
retry_cls: Retry, docket: Docket, worker: Worker
102+
):
103+
calls = 0
104+
first_call_time = None
105+
second_call_time = None
106+
107+
async def the_task(
108+
a: str,
109+
b: str,
110+
retry: Retry = retry_cls(attempts=2), # type: ignore[reportCallIssue]
111+
):
112+
assert a == "a"
113+
assert b == "b"
114+
115+
nonlocal calls
116+
calls += 1
117+
118+
nonlocal first_call_time
119+
if not first_call_time:
120+
first_call_time = datetime.now(timezone.utc)
121+
retry.in_(timedelta(seconds=0.5))
122+
else:
123+
nonlocal second_call_time
124+
second_call_time = datetime.now(timezone.utc)
125+
126+
await docket.add(the_task)("a", "b")
127+
128+
await worker.run_until_finished()
129+
130+
assert calls == 2
131+
132+
assert isinstance(first_call_time, datetime)
133+
assert isinstance(second_call_time, datetime)
134+
135+
delay = second_call_time - first_call_time
136+
assert delay.total_seconds() > 0 < 1
137+
138+
139+
@pytest.mark.parametrize("retry_cls", [Retry, ExponentialRetry])
140+
async def test_user_can_request_a_retry_at_a_specific_time(
141+
retry_cls: Retry, docket: Docket, worker: Worker
142+
):
143+
calls = 0
144+
first_call_time = None
145+
second_call_time = None
146+
147+
async def the_task(
148+
a: str,
149+
b: str,
150+
retry: Retry = retry_cls(attempts=2), # type: ignore[reportCallIssue]
151+
):
152+
assert a == "a"
153+
assert b == "b"
154+
155+
nonlocal calls
156+
calls += 1
157+
158+
nonlocal first_call_time
159+
if not first_call_time:
160+
when = datetime.now(timezone.utc) + timedelta(seconds=0.5)
161+
first_call_time = datetime.now(timezone.utc)
162+
retry.at(when)
163+
else:
164+
nonlocal second_call_time
165+
second_call_time = datetime.now(timezone.utc)
166+
167+
await docket.add(the_task)("a", "b")
168+
169+
await worker.run_until_finished()
170+
171+
assert calls == 2
172+
173+
assert isinstance(first_call_time, datetime)
174+
assert isinstance(second_call_time, datetime)
175+
176+
delay = second_call_time - first_call_time
177+
assert delay.total_seconds() > 0 < 1
178+
179+
180+
async def test_user_can_request_a_retry_at_a_specific_time_in_the_past(
181+
docket: Docket, worker: Worker
182+
):
183+
calls = 0
184+
first_call_time = None
185+
second_call_time = None
186+
187+
async def the_task(
188+
a: str,
189+
b: str,
190+
retry: Retry = Retry(attempts=2),
191+
):
192+
assert a == "a"
193+
assert b == "b"
194+
195+
nonlocal calls
196+
calls += 1
197+
198+
nonlocal first_call_time
199+
if not first_call_time:
200+
when = datetime.now(timezone.utc) - timedelta(days=1)
201+
first_call_time = datetime.now(timezone.utc)
202+
retry.at(when)
203+
else:
204+
nonlocal second_call_time
205+
second_call_time = datetime.now(timezone.utc)
206+
207+
await docket.add(the_task)("a", "b")
208+
209+
await worker.run_until_finished()
210+
211+
assert calls == 2
212+
213+
assert isinstance(first_call_time, datetime)
214+
assert isinstance(second_call_time, datetime)
215+
216+
delay = second_call_time - first_call_time
217+
assert delay.total_seconds() > 0 < 1
218+
219+
98220
async def test_dependencies_error_for_missing_task_argument(
99221
docket: Docket, worker: Worker, caplog: pytest.LogCaptureFixture
100222
):

0 commit comments

Comments
 (0)