Skip to content

Commit 3484ea9

Browse files
chrisguidryclaude
andauthored
Use uncalled-for for dependency injection plumbing (#353)
Docket's DI engine was extracted into the standalone [`uncalled-for`](https://pypi.org/project/uncalled-for/) library. This swaps out all the internal DI plumbing — parameter introspection, dependency resolution, validation, `Depends()`, `Shared`, `SharedContext` — for imports from that package. The public API is unchanged; everything is re-exported from `docket.dependencies` exactly as before. The docket-specific bits (ContextVars, Retry, Perpetual, ConcurrencyLimit, etc.) stay in docket. The three ambient ContextVars (`current_docket`, `current_worker`, `current_execution`) are now module-level variables in `_base.py` rather than class attributes on a custom `Dependency` subclass, so `Dependency` is just a direct re-export of `uncalled_for.Dependency`. Net result: −394 lines, +167 lines across the dependencies package. Closes #352 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cd4d96b commit 3484ea9

21 files changed

+227
-433
lines changed

docs/dependency-injection.md

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
Docket includes a dependency injection system that provides access to context, configuration, and custom resources. It's similar to FastAPI's dependency injection but tailored for background task patterns.
44

5+
As of version 0.18.0, Docket's dependency injection is built on the
6+
[`uncalled-for`](https://github.com/chrisguidry/uncalled-for) package
7+
([PyPI](https://pypi.org/project/uncalled-for/)), which provides the core
8+
resolution engine, `Depends`, `Shared`, and `Dependency` base class. Docket
9+
layers on task-specific context (`CurrentDocket`, `CurrentWorker`, etc.) and
10+
behavioral dependencies (`Retry`, `Perpetual`, `Timeout`, etc.).
11+
512
## Contextual Dependencies
613

714
### Accessing the Current Docket
@@ -355,12 +362,33 @@ async def fetch_pages(
355362
await process_response(response)
356363
```
357364

358-
Inside `__aenter__`, you can access the current execution context through the class-level context vars `self.docket`, `self.worker`, and `self.execution`:
365+
Inside `__aenter__`, you can access the current execution context through the
366+
module-level context variables `current_docket`, `current_worker`, and
367+
`current_execution`:
359368

360369
```python
370+
from docket.dependencies import Dependency, current_execution, current_worker
371+
361372
class AuditedDependency(Dependency):
362373
async def __aenter__(self) -> AuditLog:
363-
execution = self.execution.get()
364-
worker = self.worker.get()
374+
execution = current_execution.get()
375+
worker = current_worker.get()
365376
return AuditLog(task_key=execution.key, worker_name=worker.name)
366377
```
378+
379+
Or use the higher-level contextual dependencies for cleaner code:
380+
381+
```python
382+
from docket import CurrentExecution, CurrentWorker, Depends, Execution, Worker
383+
384+
async def create_audit_log(
385+
execution: Execution = CurrentExecution(),
386+
worker: Worker = CurrentWorker(),
387+
) -> AuditLog:
388+
return AuditLog(task_key=execution.key, worker_name=worker.name)
389+
390+
async def audited_task(
391+
audit_log: AuditLog = Depends(create_audit_log),
392+
) -> None:
393+
...
394+
```

loq.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ max_lines = 750
1010
# Source files that still need exceptions above 750
1111
[[rules]]
1212
path = "src/docket/worker.py"
13-
max_lines = 1133
13+
max_lines = 1141
1414

1515
[[rules]]
1616
path = "src/docket/cli/__init__.py"
1717
max_lines = 945
1818

1919
[[rules]]
2020
path = "src/docket/execution.py"
21-
max_lines = 1020
21+
max_lines = 1019
2222

2323
[[rules]]
2424
path = "src/docket/docket.py"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"typer>=0.15.1",
3838
"typing_extensions>=4.12.0",
3939
"tzdata>=2025.2; sys_platform == 'win32'",
40+
"uncalled-for>=0.1.2",
4041
]
4142

4243
[project.optional-dependencies]

src/docket/dependencies/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
FailureHandler,
1414
Runtime,
1515
TaskOutcome,
16+
current_docket,
17+
current_execution,
18+
current_worker,
1619
format_duration,
1720
)
1821
from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit
@@ -53,6 +56,9 @@
5356
"FailureHandler",
5457
"CompletionHandler",
5558
"TaskOutcome",
59+
"current_docket",
60+
"current_execution",
61+
"current_worker",
5662
"format_duration",
5763
# Contextual dependencies
5864
"CurrentDocket",

src/docket/dependencies/_base.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,32 @@
66
from contextvars import ContextVar
77
from dataclasses import dataclass, field
88
from datetime import timedelta
9-
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Awaitable, Callable
9+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar
10+
11+
from uncalled_for import Dependency as Dependency
1112

1213
if TYPE_CHECKING: # pragma: no cover
1314
from ..docket import Docket
1415
from ..execution import Execution
1516
from ..worker import Worker
1617

18+
T = TypeVar("T", covariant=True)
19+
20+
current_docket: ContextVar[Docket] = ContextVar("current_docket")
21+
current_worker: ContextVar[Worker] = ContextVar("current_worker")
22+
current_execution: ContextVar[Execution] = ContextVar("current_execution")
23+
24+
# Backwards compatibility: prior to 0.18, docket defined its own Dependency base
25+
# class with class-level ContextVars (Dependency.execution, Dependency.docket,
26+
# Dependency.worker). Now that the base Dependency class comes from uncalled-for,
27+
# those ContextVars live at module scope above. However, downstream consumers
28+
# (notably FastMCP) access them as Dependency.execution.get(), so we monkeypatch
29+
# them back onto the class to avoid breaking existing code. This shim can be
30+
# removed once all known consumers have migrated to the module-level ContextVars.
31+
Dependency.execution = current_execution # type: ignore[attr-defined]
32+
Dependency.docket = current_docket # type: ignore[attr-defined]
33+
Dependency.worker = current_worker # type: ignore[attr-defined]
34+
1735

1836
def format_duration(seconds: float) -> str:
1937
"""Format a duration for log output."""
@@ -45,27 +63,7 @@ def __init__(self, execution: Execution, reason: str = "admission control"):
4563
super().__init__(f"Task {execution.key} blocked by {reason}")
4664

4765

48-
class Dependency(abc.ABC):
49-
"""Base class for all dependencies."""
50-
51-
single: bool = False
52-
53-
docket: ContextVar[Docket] = ContextVar("docket")
54-
worker: ContextVar[Worker] = ContextVar("worker")
55-
execution: ContextVar[Execution] = ContextVar("execution")
56-
57-
@abc.abstractmethod
58-
async def __aenter__(self) -> Any: ... # pragma: no cover
59-
60-
async def __aexit__(
61-
self,
62-
_exc_type: type[BaseException] | None,
63-
_exc_value: BaseException | None,
64-
_traceback: TracebackType | None,
65-
) -> bool: ... # pragma: no cover
66-
67-
68-
class Runtime(Dependency):
66+
class Runtime(Dependency[T]):
6967
"""Base class for dependencies that control task execution.
7068
7169
Only one Runtime dependency can be active per task (single=True).
@@ -93,7 +91,7 @@ async def run(
9391
... # pragma: no cover
9492

9593

96-
class FailureHandler(Dependency):
94+
class FailureHandler(Dependency[T]):
9795
"""Base class for dependencies that control what happens when a task fails.
9896
9997
Called on exceptions. If handle_failure() returns True, the handler
@@ -120,7 +118,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo
120118
... # pragma: no cover
121119

122120

123-
class CompletionHandler(Dependency):
121+
class CompletionHandler(Dependency[T]):
124122
"""Base class for dependencies that control what happens after task completion.
125123
126124
Called after execution is truly done (success, or failure with no retry).

src/docket/dependencies/_concurrency.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from typing import TYPE_CHECKING
99

1010
from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
11-
from ._base import AdmissionBlocked, Dependency
11+
from ._base import (
12+
AdmissionBlocked,
13+
Dependency,
14+
current_docket,
15+
current_execution,
16+
current_worker,
17+
)
1218

1319
logger = logging.getLogger(__name__)
1420

@@ -38,7 +44,7 @@ def __init__(self, execution: Execution, concurrency_key: str, max_concurrent: i
3844
super().__init__(execution, reason=reason)
3945

4046

41-
class ConcurrencyLimit(Dependency):
47+
class ConcurrencyLimit(Dependency["ConcurrencyLimit"]):
4248
"""Configures concurrency limits for task execution.
4349
4450
Can limit concurrency globally for a task, or per specific argument value.
@@ -94,9 +100,9 @@ def __init__(
94100
async def __aenter__(self) -> ConcurrencyLimit:
95101
from ._functional import _Depends
96102

97-
execution = self.execution.get()
98-
docket = self.docket.get()
99-
worker = self.worker.get()
103+
execution = current_execution.get()
104+
docket = current_docket.get()
105+
worker = current_worker.get()
100106

101107
# Build concurrency key based on argument_name (if provided) or function name
102108
scope = self.scope or docket.name
@@ -151,9 +157,9 @@ async def __aenter__(self) -> ConcurrencyLimit:
151157

152158
async def __aexit__(
153159
self,
154-
_exc_type: type[BaseException] | None,
155-
_exc_value: BaseException | None,
156-
_traceback: type[BaseException] | None,
160+
exc_type: type[BaseException] | None,
161+
exc_value: BaseException | None,
162+
traceback: type[BaseException] | None,
157163
) -> None:
158164
# No-op: The original instance (used as default argument) has no state.
159165
# Actual cleanup is handled by _cleanup() on the per-task instance,
@@ -256,7 +262,7 @@ async def _release_slot(self) -> None:
256262
# Note: only registered as callback for instances with valid keys
257263
assert self._concurrency_key and self._task_key
258264

259-
docket = self.docket.get()
265+
docket = current_docket.get()
260266
async with docket.redis() as redis:
261267
# Remove this task from the sorted set and delete the key if empty
262268
# KEYS[1]: concurrency_key, ARGV[1]: task_key
@@ -272,7 +278,7 @@ async def _release_slot(self) -> None:
272278

273279
async def _renew_lease_loop(self, redelivery_timeout: timedelta) -> None:
274280
"""Periodically refresh slot timestamp to prevent expiration."""
275-
docket = self.docket.get()
281+
docket = current_docket.get()
276282
renewal_interval = redelivery_timeout.total_seconds() / LEASE_RENEWAL_FACTOR
277283
key_ttl = max(
278284
MINIMUM_TTL_SECONDS,

src/docket/dependencies/_contextual.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
import logging
66
from typing import TYPE_CHECKING, Any, cast
77

8-
from ._base import Dependency
8+
from ._base import Dependency, current_docket, current_execution, current_worker
99

1010
if TYPE_CHECKING: # pragma: no cover
1111
from ..docket import Docket
1212
from ..execution import Execution
1313
from ..worker import Worker
1414

1515

16-
class _CurrentWorker(Dependency):
16+
class _CurrentWorker(Dependency["Worker"]):
1717
async def __aenter__(self) -> Worker:
18-
return self.worker.get()
18+
return current_worker.get()
1919

2020

2121
def CurrentWorker() -> Worker:
@@ -32,9 +32,9 @@ async def my_task(worker: Worker = CurrentWorker()) -> None:
3232
return cast("Worker", _CurrentWorker())
3333

3434

35-
class _CurrentDocket(Dependency):
35+
class _CurrentDocket(Dependency["Docket"]):
3636
async def __aenter__(self) -> Docket:
37-
return self.docket.get()
37+
return current_docket.get()
3838

3939

4040
def CurrentDocket() -> Docket:
@@ -51,9 +51,9 @@ async def my_task(docket: Docket = CurrentDocket()) -> None:
5151
return cast("Docket", _CurrentDocket())
5252

5353

54-
class _CurrentExecution(Dependency):
54+
class _CurrentExecution(Dependency["Execution"]):
5555
async def __aenter__(self) -> Execution:
56-
return self.execution.get()
56+
return current_execution.get()
5757

5858

5959
def CurrentExecution() -> Execution:
@@ -70,9 +70,9 @@ async def my_task(execution: Execution = CurrentExecution()) -> None:
7070
return cast("Execution", _CurrentExecution())
7171

7272

73-
class _TaskKey(Dependency):
73+
class _TaskKey(Dependency[str]):
7474
async def __aenter__(self) -> str:
75-
return self.execution.get().key
75+
return current_execution.get().key
7676

7777

7878
def TaskKey() -> str:
@@ -89,7 +89,7 @@ async def my_task(key: str = TaskKey()) -> None:
8989
return cast(str, _TaskKey())
9090

9191

92-
class _TaskArgument(Dependency):
92+
class _TaskArgument(Dependency[Any]):
9393
parameter: str | None
9494
optional: bool
9595

@@ -99,7 +99,7 @@ def __init__(self, parameter: str | None = None, optional: bool = False) -> None
9999

100100
async def __aenter__(self) -> Any:
101101
assert self.parameter is not None
102-
execution = self.execution.get()
102+
execution = current_execution.get()
103103
try:
104104
return execution.get_argument(self.parameter)
105105
except KeyError:
@@ -128,15 +128,15 @@ async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -
128128
return cast(Any, _TaskArgument(parameter, optional))
129129

130130

131-
class _TaskLogger(Dependency):
131+
class _TaskLogger(Dependency["logging.LoggerAdapter[logging.Logger]"]):
132132
async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
133-
execution = self.execution.get()
133+
execution = current_execution.get()
134134
logger = logging.getLogger(f"docket.task.{execution.function_name}")
135135
return logging.LoggerAdapter(
136136
logger,
137137
{
138-
**self.docket.get().labels(),
139-
**self.worker.get().labels(),
138+
**current_docket.get().labels(),
139+
**current_worker.get().labels(),
140140
**execution.specific_labels(),
141141
},
142142
)

src/docket/dependencies/_cron.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from croniter import croniter
99

10+
from ._base import current_execution
1011
from ._perpetual import Perpetual
1112

1213
if TYPE_CHECKING: # pragma: no cover
@@ -82,7 +83,7 @@ def __init__(
8283
self._croniter = croniter(self.expression, datetime.now(self.tz), datetime)
8384

8485
async def __aenter__(self) -> Cron:
85-
execution = self.execution.get()
86+
execution = current_execution.get()
8687
cron = Cron(expression=self.expression, automatic=self.automatic, tz=self.tz)
8788
cron.args = execution.args
8889
cron.kwargs = execution.kwargs

0 commit comments

Comments
 (0)