Skip to content

Commit d3ea664

Browse files
chrisguidryclaude
andcommitted
Use uncalled-for for dependency injection plumbing
Docket's DI engine was extracted into the standalone `uncalled-for` library (https://pypi.org/project/uncalled-for/). This swaps out all the internal DI plumbing — parameter introspection, dependency resolution, validation, `Depends()`, `Shared`, `SharedContext` — for imports from that package. The public API is unchanged; everything is re-exported from `docket.dependencies` exactly as before. The docket-specific bits (ContextVars, Retry, Perpetual, ConcurrencyLimit, etc.) stay in docket. The three ambient ContextVars (`current_docket`, `current_worker`, `current_execution`) are now module-level variables in `_base.py` rather than class attributes on a custom `Dependency` subclass, so `Dependency` is just a direct re-export of `uncalled_for.Dependency`. Net result: −394 lines, +167 lines across the dependencies package. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c51bde1 commit d3ea664

File tree

17 files changed

+167
-394
lines changed

17 files changed

+167
-394
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"typer>=0.15.1",
3838
"typing_extensions>=4.12.0",
3939
"tzdata>=2025.2; sys_platform == 'win32'",
40+
"uncalled-for>=0.1.2",
4041
]
4142

4243
[project.optional-dependencies]

src/docket/dependencies/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
FailureHandler,
1414
Runtime,
1515
TaskOutcome,
16+
current_docket,
17+
current_execution,
18+
current_worker,
1619
format_duration,
1720
)
1821
from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit
@@ -53,6 +56,9 @@
5356
"FailureHandler",
5457
"CompletionHandler",
5558
"TaskOutcome",
59+
"current_docket",
60+
"current_execution",
61+
"current_worker",
5662
"format_duration",
5763
# Contextual dependencies
5864
"CurrentDocket",

src/docket/dependencies/_base.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@
66
from contextvars import ContextVar
77
from dataclasses import dataclass, field
88
from datetime import timedelta
9-
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Awaitable, Callable
9+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar
10+
11+
from uncalled_for import Dependency as Dependency
1112

1213
if TYPE_CHECKING: # pragma: no cover
1314
from ..docket import Docket
1415
from ..execution import Execution
1516
from ..worker import Worker
1617

18+
T = TypeVar("T", covariant=True)
19+
20+
current_docket: ContextVar[Docket] = ContextVar("current_docket")
21+
current_worker: ContextVar[Worker] = ContextVar("current_worker")
22+
current_execution: ContextVar[Execution] = ContextVar("current_execution")
23+
1724

1825
def format_duration(seconds: float) -> str:
1926
"""Format a duration for log output."""
@@ -45,27 +52,7 @@ def __init__(self, execution: Execution, reason: str = "admission control"):
4552
super().__init__(f"Task {execution.key} blocked by {reason}")
4653

4754

48-
class Dependency(abc.ABC):
49-
"""Base class for all dependencies."""
50-
51-
single: bool = False
52-
53-
docket: ContextVar[Docket] = ContextVar("docket")
54-
worker: ContextVar[Worker] = ContextVar("worker")
55-
execution: ContextVar[Execution] = ContextVar("execution")
56-
57-
@abc.abstractmethod
58-
async def __aenter__(self) -> Any: ... # pragma: no cover
59-
60-
async def __aexit__(
61-
self,
62-
_exc_type: type[BaseException] | None,
63-
_exc_value: BaseException | None,
64-
_traceback: TracebackType | None,
65-
) -> bool: ... # pragma: no cover
66-
67-
68-
class Runtime(Dependency):
55+
class Runtime(Dependency[T]):
6956
"""Base class for dependencies that control task execution.
7057
7158
Only one Runtime dependency can be active per task (single=True).
@@ -93,7 +80,7 @@ async def run(
9380
... # pragma: no cover
9481

9582

96-
class FailureHandler(Dependency):
83+
class FailureHandler(Dependency[T]):
9784
"""Base class for dependencies that control what happens when a task fails.
9885
9986
Called on exceptions. If handle_failure() returns True, the handler
@@ -120,7 +107,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo
120107
... # pragma: no cover
121108

122109

123-
class CompletionHandler(Dependency):
110+
class CompletionHandler(Dependency[T]):
124111
"""Base class for dependencies that control what happens after task completion.
125112
126113
Called after execution is truly done (success, or failure with no retry).

src/docket/dependencies/_concurrency.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from typing import TYPE_CHECKING
99

1010
from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
11-
from ._base import AdmissionBlocked, Dependency
11+
from ._base import (
12+
AdmissionBlocked,
13+
Dependency,
14+
current_docket,
15+
current_execution,
16+
current_worker,
17+
)
1218

1319
logger = logging.getLogger(__name__)
1420

@@ -38,7 +44,7 @@ def __init__(self, execution: Execution, concurrency_key: str, max_concurrent: i
3844
super().__init__(execution, reason=reason)
3945

4046

41-
class ConcurrencyLimit(Dependency):
47+
class ConcurrencyLimit(Dependency["ConcurrencyLimit"]):
4248
"""Configures concurrency limits for task execution.
4349
4450
Can limit concurrency globally for a task, or per specific argument value.
@@ -94,9 +100,9 @@ def __init__(
94100
async def __aenter__(self) -> ConcurrencyLimit:
95101
from ._functional import _Depends
96102

97-
execution = self.execution.get()
98-
docket = self.docket.get()
99-
worker = self.worker.get()
103+
execution = current_execution.get()
104+
docket = current_docket.get()
105+
worker = current_worker.get()
100106

101107
# Build concurrency key based on argument_name (if provided) or function name
102108
scope = self.scope or docket.name
@@ -151,9 +157,9 @@ async def __aenter__(self) -> ConcurrencyLimit:
151157

152158
async def __aexit__(
153159
self,
154-
_exc_type: type[BaseException] | None,
155-
_exc_value: BaseException | None,
156-
_traceback: type[BaseException] | None,
160+
exc_type: type[BaseException] | None,
161+
exc_value: BaseException | None,
162+
traceback: type[BaseException] | None,
157163
) -> None:
158164
# No-op: The original instance (used as default argument) has no state.
159165
# Actual cleanup is handled by _cleanup() on the per-task instance,
@@ -256,7 +262,7 @@ async def _release_slot(self) -> None:
256262
# Note: only registered as callback for instances with valid keys
257263
assert self._concurrency_key and self._task_key
258264

259-
docket = self.docket.get()
265+
docket = current_docket.get()
260266
async with docket.redis() as redis:
261267
# Remove this task from the sorted set and delete the key if empty
262268
# KEYS[1]: concurrency_key, ARGV[1]: task_key
@@ -272,7 +278,7 @@ async def _release_slot(self) -> None:
272278

273279
async def _renew_lease_loop(self, redelivery_timeout: timedelta) -> None:
274280
"""Periodically refresh slot timestamp to prevent expiration."""
275-
docket = self.docket.get()
281+
docket = current_docket.get()
276282
renewal_interval = redelivery_timeout.total_seconds() / LEASE_RENEWAL_FACTOR
277283
key_ttl = max(
278284
MINIMUM_TTL_SECONDS,

src/docket/dependencies/_contextual.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
import logging
66
from typing import TYPE_CHECKING, Any, cast
77

8-
from ._base import Dependency
8+
from ._base import Dependency, current_docket, current_execution, current_worker
99

1010
if TYPE_CHECKING: # pragma: no cover
1111
from ..docket import Docket
1212
from ..execution import Execution
1313
from ..worker import Worker
1414

1515

16-
class _CurrentWorker(Dependency):
16+
class _CurrentWorker(Dependency["Worker"]):
1717
async def __aenter__(self) -> Worker:
18-
return self.worker.get()
18+
return current_worker.get()
1919

2020

2121
def CurrentWorker() -> Worker:
@@ -32,9 +32,9 @@ async def my_task(worker: Worker = CurrentWorker()) -> None:
3232
return cast("Worker", _CurrentWorker())
3333

3434

35-
class _CurrentDocket(Dependency):
35+
class _CurrentDocket(Dependency["Docket"]):
3636
async def __aenter__(self) -> Docket:
37-
return self.docket.get()
37+
return current_docket.get()
3838

3939

4040
def CurrentDocket() -> Docket:
@@ -51,9 +51,9 @@ async def my_task(docket: Docket = CurrentDocket()) -> None:
5151
return cast("Docket", _CurrentDocket())
5252

5353

54-
class _CurrentExecution(Dependency):
54+
class _CurrentExecution(Dependency["Execution"]):
5555
async def __aenter__(self) -> Execution:
56-
return self.execution.get()
56+
return current_execution.get()
5757

5858

5959
def CurrentExecution() -> Execution:
@@ -70,9 +70,9 @@ async def my_task(execution: Execution = CurrentExecution()) -> None:
7070
return cast("Execution", _CurrentExecution())
7171

7272

73-
class _TaskKey(Dependency):
73+
class _TaskKey(Dependency[str]):
7474
async def __aenter__(self) -> str:
75-
return self.execution.get().key
75+
return current_execution.get().key
7676

7777

7878
def TaskKey() -> str:
@@ -89,7 +89,7 @@ async def my_task(key: str = TaskKey()) -> None:
8989
return cast(str, _TaskKey())
9090

9191

92-
class _TaskArgument(Dependency):
92+
class _TaskArgument(Dependency[Any]):
9393
parameter: str | None
9494
optional: bool
9595

@@ -99,7 +99,7 @@ def __init__(self, parameter: str | None = None, optional: bool = False) -> None
9999

100100
async def __aenter__(self) -> Any:
101101
assert self.parameter is not None
102-
execution = self.execution.get()
102+
execution = current_execution.get()
103103
try:
104104
return execution.get_argument(self.parameter)
105105
except KeyError:
@@ -128,15 +128,15 @@ async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -
128128
return cast(Any, _TaskArgument(parameter, optional))
129129

130130

131-
class _TaskLogger(Dependency):
131+
class _TaskLogger(Dependency["logging.LoggerAdapter[logging.Logger]"]):
132132
async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
133-
execution = self.execution.get()
133+
execution = current_execution.get()
134134
logger = logging.getLogger(f"docket.task.{execution.function_name}")
135135
return logging.LoggerAdapter(
136136
logger,
137137
{
138-
**self.docket.get().labels(),
139-
**self.worker.get().labels(),
138+
**current_docket.get().labels(),
139+
**current_worker.get().labels(),
140140
**execution.specific_labels(),
141141
},
142142
)

src/docket/dependencies/_cron.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from croniter import croniter
99

10+
from ._base import current_execution
1011
from ._perpetual import Perpetual
1112

1213
if TYPE_CHECKING: # pragma: no cover
@@ -82,7 +83,7 @@ def __init__(
8283
self._croniter = croniter(self.expression, datetime.now(self.tz), datetime)
8384

8485
async def __aenter__(self) -> Cron:
85-
execution = self.execution.get()
86+
execution = current_execution.get()
8687
cron = Cron(expression=self.expression, automatic=self.automatic, tz=self.tz)
8788
cron.args = execution.args
8889
cron.kwargs = execution.kwargs

0 commit comments

Comments
 (0)