Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions docs/dependency-injection.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

Docket includes a dependency injection system that provides access to context, configuration, and custom resources. It's similar to FastAPI's dependency injection but tailored for background task patterns.

As of version 0.18.0, Docket's dependency injection is built on the
[`uncalled-for`](https://github.com/chrisguidry/uncalled-for) package
([PyPI](https://pypi.org/project/uncalled-for/)), which provides the core
resolution engine, `Depends`, `Shared`, and `Dependency` base class. Docket
layers on task-specific context (`CurrentDocket`, `CurrentWorker`, etc.) and
behavioral dependencies (`Retry`, `Perpetual`, `Timeout`, etc.).

## Contextual Dependencies

### Accessing the Current Docket
Expand Down Expand Up @@ -355,12 +362,33 @@ async def fetch_pages(
await process_response(response)
```

Inside `__aenter__`, you can access the current execution context through the class-level context vars `self.docket`, `self.worker`, and `self.execution`:
Inside `__aenter__`, you can access the current execution context through the
module-level context variables `current_docket`, `current_worker`, and
`current_execution`:

```python
from docket.dependencies import Dependency, current_execution, current_worker

class AuditedDependency(Dependency):
async def __aenter__(self) -> AuditLog:
execution = self.execution.get()
worker = self.worker.get()
execution = current_execution.get()
worker = current_worker.get()
return AuditLog(task_key=execution.key, worker_name=worker.name)
```

Or use the higher-level contextual dependencies for cleaner code:

```python
from docket import CurrentExecution, CurrentWorker, Depends, Execution, Worker

async def create_audit_log(
execution: Execution = CurrentExecution(),
worker: Worker = CurrentWorker(),
) -> AuditLog:
return AuditLog(task_key=execution.key, worker_name=worker.name)

async def audited_task(
audit_log: AuditLog = Depends(create_audit_log),
) -> None:
...
```
4 changes: 2 additions & 2 deletions loq.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ max_lines = 750
# Source files that still need exceptions above 750
[[rules]]
path = "src/docket/worker.py"
max_lines = 1133
max_lines = 1141

[[rules]]
path = "src/docket/cli/__init__.py"
max_lines = 945

[[rules]]
path = "src/docket/execution.py"
max_lines = 1020
max_lines = 1019

[[rules]]
path = "src/docket/docket.py"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"typer>=0.15.1",
"typing_extensions>=4.12.0",
"tzdata>=2025.2; sys_platform == 'win32'",
"uncalled-for>=0.1.2",
]

[project.optional-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/docket/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
FailureHandler,
Runtime,
TaskOutcome,
current_docket,
current_execution,
current_worker,
format_duration,
)
from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit
Expand Down Expand Up @@ -53,6 +56,9 @@
"FailureHandler",
"CompletionHandler",
"TaskOutcome",
"current_docket",
"current_execution",
"current_worker",
"format_duration",
# Contextual dependencies
"CurrentDocket",
Expand Down
48 changes: 23 additions & 25 deletions src/docket/dependencies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,32 @@
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import timedelta
from types import TracebackType
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar

from uncalled_for import Dependency as Dependency

if TYPE_CHECKING: # pragma: no cover
from ..docket import Docket
from ..execution import Execution
from ..worker import Worker

T = TypeVar("T", covariant=True)

current_docket: ContextVar[Docket] = ContextVar("current_docket")
current_worker: ContextVar[Worker] = ContextVar("current_worker")
current_execution: ContextVar[Execution] = ContextVar("current_execution")

# Backwards compatibility: prior to 0.18, docket defined its own Dependency base
# class with class-level ContextVars (Dependency.execution, Dependency.docket,
# Dependency.worker). Now that the base Dependency class comes from uncalled-for,
# those ContextVars live at module scope above. However, downstream consumers
# (notably FastMCP) access them as Dependency.execution.get(), so we monkeypatch
# them back onto the class to avoid breaking existing code. This shim can be
# removed once all known consumers have migrated to the module-level ContextVars.
Dependency.execution = current_execution # type: ignore[attr-defined]
Dependency.docket = current_docket # type: ignore[attr-defined]
Dependency.worker = current_worker # type: ignore[attr-defined]


def format_duration(seconds: float) -> str:
"""Format a duration for log output."""
Expand Down Expand Up @@ -45,27 +63,7 @@ def __init__(self, execution: Execution, reason: str = "admission control"):
super().__init__(f"Task {execution.key} blocked by {reason}")


class Dependency(abc.ABC):
"""Base class for all dependencies."""

single: bool = False

docket: ContextVar[Docket] = ContextVar("docket")
worker: ContextVar[Worker] = ContextVar("worker")
execution: ContextVar[Execution] = ContextVar("execution")

@abc.abstractmethod
async def __aenter__(self) -> Any: ... # pragma: no cover

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_value: BaseException | None,
_traceback: TracebackType | None,
) -> bool: ... # pragma: no cover


class Runtime(Dependency):
class Runtime(Dependency[T]):
"""Base class for dependencies that control task execution.

Only one Runtime dependency can be active per task (single=True).
Expand Down Expand Up @@ -93,7 +91,7 @@ async def run(
... # pragma: no cover


class FailureHandler(Dependency):
class FailureHandler(Dependency[T]):
"""Base class for dependencies that control what happens when a task fails.

Called on exceptions. If handle_failure() returns True, the handler
Expand All @@ -120,7 +118,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo
... # pragma: no cover


class CompletionHandler(Dependency):
class CompletionHandler(Dependency[T]):
"""Base class for dependencies that control what happens after task completion.

Called after execution is truly done (success, or failure with no retry).
Expand Down
26 changes: 16 additions & 10 deletions src/docket/dependencies/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from typing import TYPE_CHECKING

from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
from ._base import AdmissionBlocked, Dependency
from ._base import (
AdmissionBlocked,
Dependency,
current_docket,
current_execution,
current_worker,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +44,7 @@ def __init__(self, execution: Execution, concurrency_key: str, max_concurrent: i
super().__init__(execution, reason=reason)


class ConcurrencyLimit(Dependency):
class ConcurrencyLimit(Dependency["ConcurrencyLimit"]):
"""Configures concurrency limits for task execution.

Can limit concurrency globally for a task, or per specific argument value.
Expand Down Expand Up @@ -94,9 +100,9 @@ def __init__(
async def __aenter__(self) -> ConcurrencyLimit:
from ._functional import _Depends

execution = self.execution.get()
docket = self.docket.get()
worker = self.worker.get()
execution = current_execution.get()
docket = current_docket.get()
worker = current_worker.get()

# Build concurrency key based on argument_name (if provided) or function name
scope = self.scope or docket.name
Expand Down Expand Up @@ -151,9 +157,9 @@ async def __aenter__(self) -> ConcurrencyLimit:

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_value: BaseException | None,
_traceback: type[BaseException] | None,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: type[BaseException] | None,
) -> None:
# No-op: The original instance (used as default argument) has no state.
# Actual cleanup is handled by _cleanup() on the per-task instance,
Expand Down Expand Up @@ -256,7 +262,7 @@ async def _release_slot(self) -> None:
# Note: only registered as callback for instances with valid keys
assert self._concurrency_key and self._task_key

docket = self.docket.get()
docket = current_docket.get()
async with docket.redis() as redis:
# Remove this task from the sorted set and delete the key if empty
# KEYS[1]: concurrency_key, ARGV[1]: task_key
Expand All @@ -272,7 +278,7 @@ async def _release_slot(self) -> None:

async def _renew_lease_loop(self, redelivery_timeout: timedelta) -> None:
"""Periodically refresh slot timestamp to prevent expiration."""
docket = self.docket.get()
docket = current_docket.get()
renewal_interval = redelivery_timeout.total_seconds() / LEASE_RENEWAL_FACTOR
key_ttl = max(
MINIMUM_TTL_SECONDS,
Expand Down
30 changes: 15 additions & 15 deletions src/docket/dependencies/_contextual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
import logging
from typing import TYPE_CHECKING, Any, cast

from ._base import Dependency
from ._base import Dependency, current_docket, current_execution, current_worker

if TYPE_CHECKING: # pragma: no cover
from ..docket import Docket
from ..execution import Execution
from ..worker import Worker


class _CurrentWorker(Dependency):
class _CurrentWorker(Dependency["Worker"]):
async def __aenter__(self) -> Worker:
return self.worker.get()
return current_worker.get()


def CurrentWorker() -> Worker:
Expand All @@ -32,9 +32,9 @@ async def my_task(worker: Worker = CurrentWorker()) -> None:
return cast("Worker", _CurrentWorker())


class _CurrentDocket(Dependency):
class _CurrentDocket(Dependency["Docket"]):
async def __aenter__(self) -> Docket:
return self.docket.get()
return current_docket.get()


def CurrentDocket() -> Docket:
Expand All @@ -51,9 +51,9 @@ async def my_task(docket: Docket = CurrentDocket()) -> None:
return cast("Docket", _CurrentDocket())


class _CurrentExecution(Dependency):
class _CurrentExecution(Dependency["Execution"]):
async def __aenter__(self) -> Execution:
return self.execution.get()
return current_execution.get()


def CurrentExecution() -> Execution:
Expand All @@ -70,9 +70,9 @@ async def my_task(execution: Execution = CurrentExecution()) -> None:
return cast("Execution", _CurrentExecution())


class _TaskKey(Dependency):
class _TaskKey(Dependency[str]):
async def __aenter__(self) -> str:
return self.execution.get().key
return current_execution.get().key


def TaskKey() -> str:
Expand All @@ -89,7 +89,7 @@ async def my_task(key: str = TaskKey()) -> None:
return cast(str, _TaskKey())


class _TaskArgument(Dependency):
class _TaskArgument(Dependency[Any]):
parameter: str | None
optional: bool

Expand All @@ -99,7 +99,7 @@ def __init__(self, parameter: str | None = None, optional: bool = False) -> None

async def __aenter__(self) -> Any:
assert self.parameter is not None
execution = self.execution.get()
execution = current_execution.get()
try:
return execution.get_argument(self.parameter)
except KeyError:
Expand Down Expand Up @@ -128,15 +128,15 @@ async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -
return cast(Any, _TaskArgument(parameter, optional))


class _TaskLogger(Dependency):
class _TaskLogger(Dependency["logging.LoggerAdapter[logging.Logger]"]):
async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
execution = self.execution.get()
execution = current_execution.get()
logger = logging.getLogger(f"docket.task.{execution.function_name}")
return logging.LoggerAdapter(
logger,
{
**self.docket.get().labels(),
**self.worker.get().labels(),
**current_docket.get().labels(),
**current_worker.get().labels(),
**execution.specific_labels(),
},
)
Expand Down
3 changes: 2 additions & 1 deletion src/docket/dependencies/_cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from croniter import croniter

from ._base import current_execution
from ._perpetual import Perpetual

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(
self._croniter = croniter(self.expression, datetime.now(self.tz), datetime)

async def __aenter__(self) -> Cron:
execution = self.execution.get()
execution = current_execution.get()
cron = Cron(expression=self.expression, automatic=self.automatic, tz=self.tz)
cron.args = execution.args
cron.kwargs = execution.kwargs
Expand Down
Loading