Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/docket/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from contextlib import AsyncExitStack, asynccontextmanager
from contextvars import ContextVar
from datetime import timedelta
from datetime import datetime, timedelta, timezone
from types import TracebackType
from typing import (
TYPE_CHECKING,
Expand All @@ -14,6 +14,7 @@
Callable,
Counter,
Generic,
NoReturn,
TypeVar,
cast,
)
Expand Down Expand Up @@ -188,6 +189,10 @@ async def my_task(logger: LoggerAdapter[Logger] = TaskLogger()) -> None:
return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())


class ForcedRetry(Exception):
"""Raised when a task requests a retry via `in_` or `at`"""


class Retry(Dependency):
"""Configures linear retries for a task. You can specify the total number of
attempts (or `None` to retry indefinitely), and the delay between attempts.
Expand Down Expand Up @@ -222,6 +227,17 @@ async def __aenter__(self) -> "Retry":
retry.attempt = execution.attempt
return retry

def at(self, when: datetime) -> NoReturn:
now = datetime.now(timezone.utc)
diff = when - now
diff = diff if diff.total_seconds() >= 0 else timedelta(0)

self.in_(diff)

def in_(self, when: timedelta) -> NoReturn:
self.delay: timedelta = when
raise ForcedRetry()


class ExponentialRetry(Retry):
"""Configures exponential retries for a task. You can specify the total number
Expand Down Expand Up @@ -251,22 +267,21 @@ def __init__(
maximum_delay: The maximum delay between attempts.
"""
super().__init__(attempts=attempts, delay=minimum_delay)
self.minimum_delay = minimum_delay
self.maximum_delay = maximum_delay

async def __aenter__(self) -> "ExponentialRetry":
execution = self.execution.get()

retry = ExponentialRetry(
attempts=self.attempts,
minimum_delay=self.minimum_delay,
minimum_delay=self.delay,
maximum_delay=self.maximum_delay,
)
retry.attempt = execution.attempt

if execution.attempt > 1:
backoff_factor = 2 ** (execution.attempt - 1)
calculated_delay = self.minimum_delay * backoff_factor
calculated_delay = self.delay * backoff_factor

if calculated_delay > self.maximum_delay:
retry.delay = self.maximum_delay
Expand Down
124 changes: 123 additions & 1 deletion tests/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from datetime import datetime, timedelta, timezone

import pytest

from docket import CurrentDocket, CurrentWorker, Docket, Worker
from docket.dependencies import Depends, Retry, TaskArgument
from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument


async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker):
Expand Down Expand Up @@ -95,6 +96,127 @@ async def the_task(
assert calls == 2


@pytest.mark.parametrize("retry_cls", [Retry, ExponentialRetry])
async def test_user_can_request_a_retry_in_timedelta_time(
retry_cls: Retry, docket: Docket, worker: Worker
):
calls = 0
first_call_time = None
second_call_time = None

async def the_task(
a: str,
b: str,
retry: Retry = retry_cls(attempts=2), # type: ignore[reportCallIssue]
):
assert a == "a"
assert b == "b"

nonlocal calls
calls += 1

nonlocal first_call_time
if not first_call_time:
first_call_time = datetime.now(timezone.utc)
retry.in_(timedelta(seconds=0.5))
else:
nonlocal second_call_time
second_call_time = datetime.now(timezone.utc)

await docket.add(the_task)("a", "b")

await worker.run_until_finished()

assert calls == 2

assert isinstance(first_call_time, datetime)
assert isinstance(second_call_time, datetime)

delay = second_call_time - first_call_time
assert delay.total_seconds() > 0 < 1


@pytest.mark.parametrize("retry_cls", [Retry, ExponentialRetry])
async def test_user_can_request_a_retry_at_a_specific_time(
retry_cls: Retry, docket: Docket, worker: Worker
):
calls = 0
first_call_time = None
second_call_time = None

async def the_task(
a: str,
b: str,
retry: Retry = retry_cls(attempts=2), # type: ignore[reportCallIssue]
):
assert a == "a"
assert b == "b"

nonlocal calls
calls += 1

nonlocal first_call_time
if not first_call_time:
when = datetime.now(timezone.utc) + timedelta(seconds=0.5)
first_call_time = datetime.now(timezone.utc)
retry.at(when)
else:
nonlocal second_call_time
second_call_time = datetime.now(timezone.utc)

await docket.add(the_task)("a", "b")

await worker.run_until_finished()

assert calls == 2

assert isinstance(first_call_time, datetime)
assert isinstance(second_call_time, datetime)

delay = second_call_time - first_call_time
assert delay.total_seconds() > 0 < 1


async def test_user_can_request_a_retry_at_a_specific_time_in_the_past(
docket: Docket, worker: Worker
):
calls = 0
first_call_time = None
second_call_time = None

async def the_task(
a: str,
b: str,
retry: Retry = Retry(attempts=2),
):
assert a == "a"
assert b == "b"

nonlocal calls
calls += 1

nonlocal first_call_time
if not first_call_time:
when = datetime.now(timezone.utc) - timedelta(days=1)
first_call_time = datetime.now(timezone.utc)
retry.at(when)
else:
nonlocal second_call_time
second_call_time = datetime.now(timezone.utc)

await docket.add(the_task)("a", "b")

await worker.run_until_finished()

assert calls == 2

assert isinstance(first_call_time, datetime)
assert isinstance(second_call_time, datetime)

delay = second_call_time - first_call_time
assert delay.total_seconds() > 0 < 1


async def test_dependencies_error_for_missing_task_argument(
docket: Docket, worker: Worker, caplog: pytest.LogCaptureFixture
):
Expand Down
Loading