Skip to content

Commit a736a68

Browse files
authored
feat: Support task backoff (#81)
* feat: Support task backoff
1 parent b6a250b commit a736a68

File tree

10 files changed

+194
-28
lines changed

10 files changed

+194
-28
lines changed

poetry.lock

Lines changed: 14 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pytest-mock = "^3.14.0"
7979
ruff = "==0.11.9"
8080
setuptools = "^78.1.1"
8181
types-simplejson = "^3.20.0.20250326"
82+
types-python-dateutil = "^2.9.0.20250516"
8283

8384
[build-system]
8485
requires = ["poetry-core"]

settings/dev.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
ENABLE_CLEAN_UP_OLD_TASKS = True
5959
ENABLE_TASK_PROCESSOR_HEALTH_CHECK = True
6060
RECURRING_TASK_RUN_RETENTION_DAYS = 15
61+
TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 5
6162
TASK_DELETE_BATCH_SIZE = 2000
6263
TASK_DELETE_INCLUDE_FAILED_TASKS = False
6364
TASK_DELETE_RETENTION_DAYS = 15

src/task_processor/exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from datetime import datetime
2+
3+
14
class TaskProcessingError(Exception):
25
pass
36

@@ -6,5 +9,20 @@ class InvalidArgumentsError(TaskProcessingError):
69
pass
710

811

12+
class TaskBackoffError(TaskProcessingError):
13+
"""
14+
Raise this exception inside a task to indicate that it should be retried after a delay.
15+
This is typically used when a task fails due to a temporary issue, such as
16+
a network error or a service being unavailable.
17+
"""
18+
19+
def __init__(
20+
self,
21+
delay_until: datetime | None = None,
22+
) -> None:
23+
super().__init__()
24+
self.delay_until = delay_until
25+
26+
927
class TaskQueueFullError(Exception):
1028
pass

src/task_processor/models.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from django.db import models
88
from django.utils import timezone
99

10-
from task_processor.exceptions import TaskProcessingError, TaskQueueFullError
10+
from task_processor.exceptions import TaskQueueFullError
1111
from task_processor.managers import RecurringTaskManager, TaskManager
12-
from task_processor.task_registry import registered_tasks
12+
from task_processor.task_registry import get_task, registered_tasks
1313
from task_processor.types import TaskCallable
1414

1515
_django_json_encoder_default = DjangoJSONEncoder().default
@@ -36,13 +36,11 @@ class Meta:
3636
abstract = True
3737

3838
@property
39-
def args(self) -> typing.List[typing.Any]:
39+
def args(self) -> tuple[typing.Any, ...]:
4040
if self.serialized_args:
4141
args = self.deserialize_data(self.serialized_args)
42-
if typing.TYPE_CHECKING:
43-
assert isinstance(args, list)
44-
return args
45-
return []
42+
return tuple(args)
43+
return ()
4644

4745
@property
4846
def kwargs(self) -> typing.Dict[str, typing.Any]:
@@ -75,15 +73,8 @@ def run(self) -> None:
7573

7674
@property
7775
def callable(self) -> TaskCallable[typing.Any]:
78-
try:
79-
task = registered_tasks[self.task_identifier]
80-
return task.task_function
81-
except KeyError as e:
82-
raise TaskProcessingError(
83-
"No task registered with identifier '%s'. Ensure your task is "
84-
"decorated with @register_task_handler.",
85-
self.task_identifier,
86-
) from e
76+
task = get_task(self.task_identifier)
77+
return task.task_function
8778

8879

8980
class Task(AbstractBaseTask):

src/task_processor/processor.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from django.utils import timezone
1010

1111
from task_processor import metrics
12+
from task_processor.exceptions import TaskBackoffError
1213
from task_processor.managers import RecurringTaskManager, TaskManager
1314
from task_processor.models import (
1415
AbstractBaseTask,
@@ -18,7 +19,7 @@
1819
TaskResult,
1920
TaskRun,
2021
)
21-
from task_processor.task_registry import get_task
22+
from task_processor.task_registry import TaskType, get_task
2223

2324
T = typing.TypeVar("T", bound=AbstractBaseTask)
2425
AnyTaskRun = TaskRun | RecurringTaskRun
@@ -50,7 +51,7 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
5051
if executed_tasks:
5152
Task.objects.using(database).bulk_update(
5253
executed_tasks,
53-
fields=["completed", "num_failures", "is_locked"],
54+
fields=["completed", "num_failures", "is_locked", "scheduled_for"],
5455
)
5556

5657
if task_runs:
@@ -120,6 +121,7 @@ def _run_task(
120121
ctx.enter_context(timer)
121122

122123
task_identifier = task.task_identifier
124+
registered_task = get_task(task_identifier)
123125

124126
logger.debug(
125127
f"Running task {task_identifier} id={task.pk} args={task.args} kwargs={task.kwargs}"
@@ -157,9 +159,26 @@ def _run_task(
157159
exc_info=True,
158160
)
159161

162+
if isinstance(e, TaskBackoffError):
163+
assert registered_task.task_type == TaskType.STANDARD, (
164+
"Attempt to back off a recurring task (currently not supported)"
165+
)
166+
if typing.TYPE_CHECKING:
167+
assert isinstance(task, Task)
168+
if task.num_failures <= 3:
169+
delay_until = e.delay_until or timezone.now() + timedelta(
170+
seconds=settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS,
171+
)
172+
task.scheduled_for = delay_until
173+
logger.info(
174+
"Backoff requested. Task '%s' set to retry at %s",
175+
task_identifier,
176+
delay_until,
177+
)
178+
160179
labels = {
161180
"task_identifier": task_identifier,
162-
"task_type": get_task(task_identifier).task_type.value.lower(),
181+
"task_type": registered_task.task_type.value.lower(),
163182
"result": result.lower(),
164183
}
165184

src/task_processor/task_registry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing
44
from dataclasses import dataclass
55

6+
from task_processor.exceptions import TaskProcessingError
67
from task_processor.types import TaskCallable
78

89
logger = logging.getLogger(__name__)
@@ -43,7 +44,14 @@ def initialise() -> None:
4344
def get_task(task_identifier: str) -> RegisteredTask:
4445
global registered_tasks
4546

46-
return registered_tasks[task_identifier]
47+
try:
48+
return registered_tasks[task_identifier]
49+
except KeyError:
50+
raise TaskProcessingError(
51+
"No task registered with identifier '%s'. Ensure your task is "
52+
"decorated with @register_task_handler.",
53+
task_identifier,
54+
)
4755

4856

4957
def register_task(

tests/unit/task_processor/test_unit_task_processor_decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
register_recurring_task,
1515
register_task_handler,
1616
)
17-
from task_processor.exceptions import InvalidArgumentsError
17+
from task_processor.exceptions import InvalidArgumentsError, TaskProcessingError
1818
from task_processor.models import RecurringTask, Task, TaskPriority
1919
from task_processor.task_registry import get_task, initialise
2020
from task_processor.task_run_method import TaskRunMethod
@@ -143,7 +143,7 @@ def some_function(first_arg: str, second_arg: str) -> None:
143143

144144
# Then
145145
assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists()
146-
with pytest.raises(KeyError):
146+
with pytest.raises(TaskProcessingError):
147147
assert get_task(task_identifier)
148148

149149

tests/unit/task_processor/test_unit_task_processor_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ def test_task_run(mocker: MockerFixture) -> None:
4040
mock.assert_called_once_with(*args, **kwargs)
4141

4242

43+
def test_task_args__no_data__return_expected() -> None:
44+
# Given
45+
task = Task(
46+
task_identifier="test_task",
47+
scheduled_for=timezone.now(),
48+
)
49+
50+
# When & Then
51+
assert task.args == ()
52+
53+
4354
@pytest.mark.parametrize(
4455
"input, expected_output",
4556
(

tests/unit/task_processor/test_unit_task_processor_processor.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import time
33
import uuid
4-
from datetime import timedelta
4+
from datetime import datetime, timedelta
55
from threading import Thread
66

77
import pytest
@@ -11,12 +11,13 @@
1111
from pytest_django.fixtures import SettingsWrapper
1212
from pytest_mock import MockerFixture
1313

14-
from common.test_tools import AssertMetricFixture
14+
from common.test_tools.types import AssertMetricFixture
1515
from task_processor.decorators import (
1616
TaskHandler,
1717
register_recurring_task,
1818
register_task_handler,
1919
)
20+
from task_processor.exceptions import TaskBackoffError
2021
from task_processor.models import (
2122
RecurringTask,
2223
RecurringTaskRun,
@@ -500,7 +501,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure(
500501
),
501502
(
502503
logging.DEBUG,
503-
f"Running task {task.task_identifier} id={task.id} args=['{msg}'] kwargs={{}}",
504+
f"Running task {task.task_identifier} id={task.id} args=('{msg}',) kwargs={{}}",
504505
),
505506
(
506507
logging.ERROR,
@@ -636,6 +637,111 @@ def test_run_task_runs_tasks_in_correct_priority(
636637
assert task_runs_3[0].task == task_2
637638

638639

640+
@pytest.mark.parametrize(
641+
"exception, expected_scheduled_for",
642+
[
643+
(TaskBackoffError(), datetime.fromisoformat("2023-12-08T06:05:57+00:00")),
644+
(
645+
TaskBackoffError(
646+
delay_until=datetime.fromisoformat("2023-12-08T06:15:52+00:00")
647+
),
648+
datetime.fromisoformat("2023-12-08T06:15:52+00:00"),
649+
),
650+
],
651+
)
652+
@pytest.mark.freeze_time("2023-12-08T06:05:47+00:00")
653+
@pytest.mark.multi_database
654+
@pytest.mark.task_processor_mode
655+
def test_run_task__backoff__persists_expected(
656+
exception: TaskBackoffError,
657+
expected_scheduled_for: datetime,
658+
current_database: str,
659+
settings: SettingsWrapper,
660+
caplog: pytest.LogCaptureFixture,
661+
) -> None:
662+
# Given
663+
settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 10
664+
665+
@register_task_handler()
666+
def backoff_task() -> None:
667+
raise exception
668+
669+
task = Task.create(
670+
backoff_task.task_identifier,
671+
scheduled_for=timezone.now(),
672+
args=(),
673+
priority=TaskPriority.HIGH,
674+
)
675+
task.save(using=current_database)
676+
677+
caplog.set_level(logging.INFO)
678+
expected_log_message = f"Backoff requested. Task '{backoff_task.task_identifier}' set to retry at {expected_scheduled_for}"
679+
680+
# When
681+
run_tasks(current_database)
682+
683+
# Then
684+
assert [
685+
record.message for record in caplog.records if record.levelno == logging.INFO
686+
] == [expected_log_message]
687+
task.refresh_from_db(using=current_database)
688+
assert task.scheduled_for == expected_scheduled_for
689+
690+
691+
@pytest.mark.multi_database
692+
@pytest.mark.task_processor_mode
693+
def test_run_task__backoff__recurring__raises_expected(
694+
current_database: str,
695+
) -> None:
696+
# Given
697+
@register_recurring_task(run_every=timedelta(seconds=1))
698+
def backoff_task() -> None:
699+
raise TaskBackoffError()
700+
701+
initialise()
702+
703+
# When & Then
704+
with pytest.raises(AssertionError) as exc_info:
705+
run_recurring_tasks(current_database)
706+
707+
assert (
708+
str(exc_info.value)
709+
== "Attempt to back off a recurring task (currently not supported)"
710+
)
711+
712+
713+
@pytest.mark.multi_database
714+
@pytest.mark.task_processor_mode
715+
def test_run_task__backoff__max_num_failures__noop(
716+
current_database: str,
717+
caplog: pytest.LogCaptureFixture,
718+
) -> None:
719+
# Given
720+
@register_task_handler()
721+
def backoff_task() -> None:
722+
raise TaskBackoffError()
723+
724+
expected_scheduled_for = timezone.now()
725+
task = Task.create(
726+
backoff_task.task_identifier,
727+
scheduled_for=expected_scheduled_for,
728+
args=(),
729+
priority=TaskPriority.HIGH,
730+
)
731+
task.num_failures = 4
732+
task.save(using=current_database)
733+
734+
caplog.set_level(logging.INFO)
735+
736+
# When
737+
run_tasks(current_database)
738+
739+
# Then
740+
task.refresh_from_db(using=current_database)
741+
assert task.scheduled_for == expected_scheduled_for
742+
assert not [record for record in caplog.records if record.levelno == logging.INFO]
743+
744+
639745
@pytest.mark.multi_database
640746
def test_run_tasks__fails_if_not_in_task_processor_mode(
641747
current_database: str,

0 commit comments

Comments
 (0)