Skip to content

Commit 0062982

Browse files
chrisguidryclaude
andauthored
Add FailureHandler and CompletionHandler base classes for dependency hooks (#299)
## Summary Introduces two new abstract base classes that let dependencies control post-execution flow: - `FailureHandler` - controls what happens when a task fails. `Retry` now inherits from this and implements `handle_failure()` to schedule retries. - `CompletionHandler` - controls what happens after task completion. `Perpetual` now inherits from this and implements `on_complete()` to schedule the next execution. This follows the pattern established by `Runtime` for `Timeout`: the Worker delegates control to specialized dependencies rather than knowing the details itself. The Worker's `_retry_if_requested()` and `_perpetuate_if_requested()` methods are removed (~50 lines), with that logic now living in the dependencies. The dependency hierarchy now looks like: ``` Dependency (base - can observe via __aexit__) ├── Runtime (controls HOW task executes - Timeout) ├── FailureHandler (controls WHAT HAPPENS on failure - Retry) └── CompletionHandler (controls WHAT HAPPENS after completion - Perpetual) ``` All three have `single = True` because only one thing can control each aspect. Also adds `after(timedelta)` and `at(datetime)` methods to both `Perpetual` and `Retry`, giving developers flexibility to schedule the next execution using whatever they have handy. For `Retry`, the old `in_()` method is kept as a backwards-compatible alias. Closes #297 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c8b09e7 commit 0062982

File tree

14 files changed

+627
-191
lines changed

14 files changed

+627
-191
lines changed

loq.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ max_lines = 750
1010
# Source files that still need exceptions above 750
1111
[[rules]]
1212
path = "src/docket/worker.py"
13-
max_lines = 1150
13+
max_lines = 1125
1414

1515
[[rules]]
1616
path = "src/docket/cli.py"
1717
max_lines = 945
1818

1919
[[rules]]
2020
path = "src/docket/execution.py"
21-
max_lines = 900
21+
max_lines = 903
2222

2323
[[rules]]
2424
path = "src/docket/docket.py"
25-
max_lines = 900
25+
max_lines = 866
2626

2727
[[rules]]
2828
path = "src/docket/strikelist.py"
29-
max_lines = 650
29+
max_lines = 616

src/docket/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
Timeout,
2828
)
2929
from .docket import Docket
30-
from .execution import Execution, ExecutionState
30+
from .execution import Execution, ExecutionCancelled, ExecutionState
3131
from .strikelist import StrikeList
3232
from .worker import Worker
3333
from . import testing
@@ -42,6 +42,7 @@
4242
"Depends",
4343
"Docket",
4444
"Execution",
45+
"ExecutionCancelled",
4546
"ExecutionState",
4647
"ExponentialRetry",
4748
"Logged",

src/docket/dependencies/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@
66

77
from __future__ import annotations
88

9-
from ._base import AdmissionBlocked, Dependency, Runtime
9+
from ._base import (
10+
AdmissionBlocked,
11+
CompletionHandler,
12+
Dependency,
13+
FailureHandler,
14+
Runtime,
15+
TaskOutcome,
16+
format_duration,
17+
)
1018
from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit
1119
from ._contextual import (
1220
CurrentDocket,
@@ -41,6 +49,10 @@
4149
# Base
4250
"Dependency",
4351
"Runtime",
52+
"FailureHandler",
53+
"CompletionHandler",
54+
"TaskOutcome",
55+
"format_duration",
4456
# Contextual dependencies
4557
"CurrentDocket",
4658
"CurrentExecution",

src/docket/dependencies/_base.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import abc
66
from contextvars import ContextVar
7+
from dataclasses import dataclass, field
8+
from datetime import timedelta
79
from types import TracebackType
810
from typing import TYPE_CHECKING, Any, Awaitable, Callable
911

@@ -13,6 +15,23 @@
1315
from ..worker import Worker
1416

1517

18+
def format_duration(seconds: float) -> str:
19+
"""Format a duration for log output."""
20+
if seconds < 100:
21+
return f"{seconds * 1000:6.0f}ms"
22+
else:
23+
return f"{seconds:6.0f}s "
24+
25+
26+
@dataclass
27+
class TaskOutcome:
28+
"""Captures the outcome of a task execution for handlers."""
29+
30+
duration: timedelta
31+
result: Any = field(default=None)
32+
exception: BaseException | None = field(default=None)
33+
34+
1635
class AdmissionBlocked(Exception):
1736
"""Raised when a task cannot start due to admission control.
1837
@@ -72,3 +91,57 @@ async def run(
7291
kwargs: Keyword arguments including resolved dependencies
7392
"""
7493
... # pragma: no cover
94+
95+
96+
class FailureHandler(Dependency):
97+
"""Base class for dependencies that control what happens when a task fails.
98+
99+
Called on exceptions. If handle_failure() returns True, the handler
100+
took responsibility (e.g., scheduled a retry) and Worker won't mark
101+
the execution as failed.
102+
103+
Only one FailureHandler per task (single=True).
104+
"""
105+
106+
single = True
107+
108+
@abc.abstractmethod
109+
async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bool:
110+
"""Handle a task failure.
111+
112+
Args:
113+
execution: The task execution context
114+
outcome: The task outcome containing duration and exception
115+
116+
Returns:
117+
True if handled (Worker won't mark as failed)
118+
False if not handled (Worker proceeds normally)
119+
"""
120+
... # pragma: no cover
121+
122+
123+
class CompletionHandler(Dependency):
124+
"""Base class for dependencies that control what happens after task completion.
125+
126+
Called after execution is truly done (success, or failure with no retry).
127+
If on_complete() returns True, the handler took responsibility (e.g.,
128+
scheduled follow-up work) and did its own logging.
129+
130+
Only one CompletionHandler per task (single=True).
131+
"""
132+
133+
single = True
134+
135+
@abc.abstractmethod
136+
async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool:
137+
"""Handle task completion.
138+
139+
Args:
140+
execution: The task execution context
141+
outcome: The task outcome containing duration, result, and exception
142+
143+
Returns:
144+
True if handled (did own logging/metrics)
145+
False if not handled (Worker does normal logging)
146+
"""
147+
... # pragma: no cover

src/docket/dependencies/_perpetual.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@
22

33
from __future__ import annotations
44

5-
from datetime import timedelta
6-
from typing import Any
5+
import logging
6+
from datetime import datetime, timedelta, timezone
7+
from typing import TYPE_CHECKING, Any
78

8-
from ._base import Dependency
9+
from ._base import CompletionHandler, TaskOutcome, format_duration
910

11+
if TYPE_CHECKING: # pragma: no cover
12+
from ..execution import Execution
1013

11-
class Perpetual(Dependency):
14+
from ..instrumentation import TASKS_PERPETUATED
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class Perpetual(CompletionHandler):
1220
"""Declare a task that should be run perpetually. Perpetual tasks are automatically
1321
rescheduled for the future after they finish (whether they succeed or fail). A
1422
perpetual task can be scheduled at worker startup with the `automatic=True`.
@@ -31,6 +39,7 @@ async def my_task(perpetual: Perpetual = Perpetual()) -> None:
3139
kwargs: dict[str, Any]
3240

3341
cancelled: bool
42+
_next_when: datetime | None
3443

3544
def __init__(
3645
self,
@@ -48,6 +57,7 @@ def __init__(
4857
self.every = every
4958
self.automatic = automatic
5059
self.cancelled = False
60+
self._next_when = None
5161

5262
async def __aenter__(self) -> Perpetual:
5363
execution = self.execution.get()
@@ -62,3 +72,42 @@ def cancel(self) -> None:
6272
def perpetuate(self, *args: Any, **kwargs: Any) -> None:
6373
self.args = args
6474
self.kwargs = kwargs
75+
76+
def after(self, delay: timedelta) -> None:
77+
"""Schedule the next execution after the given delay."""
78+
self._next_when = datetime.now(timezone.utc) + delay
79+
80+
def at(self, when: datetime) -> None:
81+
"""Schedule the next execution at the given time."""
82+
self._next_when = when
83+
84+
async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool:
85+
"""Handle completion by scheduling the next execution."""
86+
if self.cancelled:
87+
docket = self.docket.get()
88+
async with docket.redis() as redis:
89+
await docket._cancel(redis, execution.key)
90+
return False
91+
92+
docket = self.docket.get()
93+
worker = self.worker.get()
94+
95+
if self._next_when:
96+
when = self._next_when
97+
else:
98+
now = datetime.now(timezone.utc)
99+
when = max(now, now + self.every - outcome.duration)
100+
101+
await docket.replace(execution.function, when, execution.key)(
102+
*self.args,
103+
**self.kwargs,
104+
)
105+
106+
TASKS_PERPETUATED.add(1, {**worker.labels(), **execution.general_labels()})
107+
logger.info(
108+
"↫ [%s] %s",
109+
format_duration(outcome.duration.total_seconds()),
110+
execution.call_repr(),
111+
)
112+
113+
return True

src/docket/dependencies/_retry.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,25 @@
22

33
from __future__ import annotations
44

5+
import logging
56
from datetime import datetime, timedelta, timezone
6-
from typing import NoReturn
7+
from typing import TYPE_CHECKING, NoReturn
78

8-
from ._base import Dependency
9+
from ._base import FailureHandler, TaskOutcome, format_duration
10+
11+
if TYPE_CHECKING: # pragma: no cover
12+
from ..execution import Execution
13+
14+
from ..instrumentation import TASKS_RETRIED
15+
16+
logger = logging.getLogger(__name__)
917

1018

1119
class ForcedRetry(Exception):
12-
"""Raised when a task requests a retry via `in_` or `at`"""
20+
"""Raised when a task requests a retry via `after` or `at`"""
1321

1422

15-
class Retry(Dependency):
23+
class Retry(FailureHandler):
1624
"""Configures linear retries for a task. You can specify the total number of
1725
attempts (or `None` to retry indefinitely), and the delay between attempts.
1826
@@ -50,16 +58,40 @@ async def __aenter__(self) -> Retry:
5058
retry.attempt = execution.attempt
5159
return retry
5260

61+
def after(self, delay: timedelta) -> NoReturn:
62+
"""Request a retry after the given delay."""
63+
self.delay = delay
64+
raise ForcedRetry()
65+
5366
def at(self, when: datetime) -> NoReturn:
67+
"""Request a retry at the given time."""
5468
now = datetime.now(timezone.utc)
5569
diff = when - now
5670
diff = diff if diff.total_seconds() >= 0 else timedelta(0)
71+
self.after(diff)
72+
73+
def in_(self, delay: timedelta) -> NoReturn:
74+
"""Deprecated: use after() instead."""
75+
self.after(delay)
76+
77+
async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bool:
78+
"""Handle failure by scheduling a retry if attempts remain."""
79+
if self.attempts is not None and execution.attempt >= self.attempts:
80+
return False
81+
82+
execution.when = datetime.now(timezone.utc) + self.delay
83+
execution.attempt += 1
84+
await execution.schedule(replace=True)
85+
86+
worker = self.worker.get()
87+
TASKS_RETRIED.add(1, {**worker.labels(), **execution.general_labels()})
88+
logger.info(
89+
"↫ [%s] %s",
90+
format_duration(outcome.duration.total_seconds()),
91+
execution.call_repr(),
92+
)
5793

58-
self.in_(diff)
59-
60-
def in_(self, when: timedelta) -> NoReturn:
61-
self.delay = when
62-
raise ForcedRetry()
94+
return True
6395

6496

6597
class ExponentialRetry(Retry):

src/docket/execution.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@
3333

3434
logger: logging.Logger = logging.getLogger(__name__)
3535

36+
37+
class ExecutionCancelled(Exception):
38+
"""Raised when get_result() is called on a cancelled execution."""
39+
40+
pass
41+
42+
3643
TaskFunction = Callable[..., Awaitable[Any]]
3744
Message = dict[bytes, bytes]
3845

@@ -682,8 +689,14 @@ async def get_result(
682689
if timeout is not None:
683690
deadline = datetime.now(timezone.utc) + timeout
684691

692+
terminal_states = (
693+
ExecutionState.COMPLETED,
694+
ExecutionState.FAILED,
695+
ExecutionState.CANCELLED,
696+
)
697+
685698
# Wait for execution to complete if not already done
686-
if self.state not in (ExecutionState.COMPLETED, ExecutionState.FAILED):
699+
if self.state not in terminal_states:
687700
# Calculate timeout duration if absolute deadline provided
688701
timeout_seconds = None
689702
if deadline is not None:
@@ -701,10 +714,7 @@ async def wait_for_completion():
701714
async for event in self.subscribe(): # pragma: no branch
702715
if event["type"] == "state":
703716
state = ExecutionState(event["state"])
704-
if state in (
705-
ExecutionState.COMPLETED,
706-
ExecutionState.FAILED,
707-
):
717+
if state in terminal_states:
708718
# Sync to get latest data including result key
709719
await self.sync()
710720
break
@@ -716,6 +726,10 @@ async def wait_for_completion():
716726
f"Timeout waiting for execution {self.key} to complete"
717727
)
718728

729+
# If cancelled, raise ExecutionCancelled
730+
if self.state == ExecutionState.CANCELLED:
731+
raise ExecutionCancelled(f"Execution {self.key} was cancelled")
732+
719733
# If failed, retrieve and raise the exception
720734
if self.state == ExecutionState.FAILED:
721735
if self.result_key:

0 commit comments

Comments
 (0)