diff --git a/docs/dependency-injection.md b/docs/dependency-injection.md index ca08653..b337627 100644 --- a/docs/dependency-injection.md +++ b/docs/dependency-injection.md @@ -2,6 +2,13 @@ Docket includes a dependency injection system that provides access to context, configuration, and custom resources. It's similar to FastAPI's dependency injection but tailored for background task patterns. +As of version 0.18.0, Docket's dependency injection is built on the +[`uncalled-for`](https://github.com/chrisguidry/uncalled-for) package +([PyPI](https://pypi.org/project/uncalled-for/)), which provides the core +resolution engine, `Depends`, `Shared`, and `Dependency` base class. Docket +layers on task-specific context (`CurrentDocket`, `CurrentWorker`, etc.) and +behavioral dependencies (`Retry`, `Perpetual`, `Timeout`, etc.). + ## Contextual Dependencies ### Accessing the Current Docket @@ -355,12 +362,33 @@ async def fetch_pages( await process_response(response) ``` -Inside `__aenter__`, you can access the current execution context through the class-level context vars `self.docket`, `self.worker`, and `self.execution`: +Inside `__aenter__`, you can access the current execution context through the +module-level context variables `current_docket`, `current_worker`, and +`current_execution`: ```python +from docket.dependencies import Dependency, current_execution, current_worker + class AuditedDependency(Dependency): async def __aenter__(self) -> AuditLog: - execution = self.execution.get() - worker = self.worker.get() + execution = current_execution.get() + worker = current_worker.get() return AuditLog(task_key=execution.key, worker_name=worker.name) ``` + +Or use the higher-level contextual dependencies for cleaner code: + +```python +from docket import CurrentExecution, CurrentWorker, Depends, Execution, Worker + +async def create_audit_log( + execution: Execution = CurrentExecution(), + worker: Worker = CurrentWorker(), +) -> AuditLog: + return AuditLog(task_key=execution.key, worker_name=worker.name) + +async def audited_task( + audit_log: AuditLog = Depends(create_audit_log), +) -> None: + ... +``` diff --git a/loq.toml b/loq.toml index dfee2cd..5b31587 100644 --- a/loq.toml +++ b/loq.toml @@ -10,7 +10,7 @@ max_lines = 750 # Source files that still need exceptions above 750 [[rules]] path = "src/docket/worker.py" -max_lines = 1133 +max_lines = 1141 [[rules]] path = "src/docket/cli/__init__.py" @@ -18,7 +18,7 @@ max_lines = 945 [[rules]] path = "src/docket/execution.py" -max_lines = 1020 +max_lines = 1019 [[rules]] path = "src/docket/docket.py" diff --git a/pyproject.toml b/pyproject.toml index eff9cad..06a830b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "typer>=0.15.1", "typing_extensions>=4.12.0", "tzdata>=2025.2; sys_platform == 'win32'", + "uncalled-for>=0.1.2", ] [project.optional-dependencies] diff --git a/src/docket/dependencies/__init__.py b/src/docket/dependencies/__init__.py index 2a51077..fdadf39 100644 --- a/src/docket/dependencies/__init__.py +++ b/src/docket/dependencies/__init__.py @@ -13,6 +13,9 @@ FailureHandler, Runtime, TaskOutcome, + current_docket, + current_execution, + current_worker, format_duration, ) from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit @@ -53,6 +56,9 @@ "FailureHandler", "CompletionHandler", "TaskOutcome", + "current_docket", + "current_execution", + "current_worker", "format_duration", # Contextual dependencies "CurrentDocket", diff --git a/src/docket/dependencies/_base.py b/src/docket/dependencies/_base.py index 8144409..a085af2 100644 --- a/src/docket/dependencies/_base.py +++ b/src/docket/dependencies/_base.py @@ -6,14 +6,32 @@ from contextvars import ContextVar from dataclasses import dataclass, field from datetime import timedelta -from types import TracebackType -from typing import TYPE_CHECKING, Any, Awaitable, Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar + +from uncalled_for import Dependency as Dependency if TYPE_CHECKING: # pragma: no cover from ..docket import Docket from ..execution import Execution from ..worker import Worker +T = TypeVar("T", covariant=True) + +current_docket: ContextVar[Docket] = ContextVar("current_docket") +current_worker: ContextVar[Worker] = ContextVar("current_worker") +current_execution: ContextVar[Execution] = ContextVar("current_execution") + +# Backwards compatibility: prior to 0.18, docket defined its own Dependency base +# class with class-level ContextVars (Dependency.execution, Dependency.docket, +# Dependency.worker). Now that the base Dependency class comes from uncalled-for, +# those ContextVars live at module scope above. However, downstream consumers +# (notably FastMCP) access them as Dependency.execution.get(), so we monkeypatch +# them back onto the class to avoid breaking existing code. This shim can be +# removed once all known consumers have migrated to the module-level ContextVars. +Dependency.execution = current_execution # type: ignore[attr-defined] +Dependency.docket = current_docket # type: ignore[attr-defined] +Dependency.worker = current_worker # type: ignore[attr-defined] + def format_duration(seconds: float) -> str: """Format a duration for log output.""" @@ -45,27 +63,7 @@ def __init__(self, execution: Execution, reason: str = "admission control"): super().__init__(f"Task {execution.key} blocked by {reason}") -class Dependency(abc.ABC): - """Base class for all dependencies.""" - - single: bool = False - - docket: ContextVar[Docket] = ContextVar("docket") - worker: ContextVar[Worker] = ContextVar("worker") - execution: ContextVar[Execution] = ContextVar("execution") - - @abc.abstractmethod - async def __aenter__(self) -> Any: ... # pragma: no cover - - async def __aexit__( - self, - _exc_type: type[BaseException] | None, - _exc_value: BaseException | None, - _traceback: TracebackType | None, - ) -> bool: ... # pragma: no cover - - -class Runtime(Dependency): +class Runtime(Dependency[T]): """Base class for dependencies that control task execution. Only one Runtime dependency can be active per task (single=True). @@ -93,7 +91,7 @@ async def run( ... # pragma: no cover -class FailureHandler(Dependency): +class FailureHandler(Dependency[T]): """Base class for dependencies that control what happens when a task fails. Called on exceptions. If handle_failure() returns True, the handler @@ -120,7 +118,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo ... # pragma: no cover -class CompletionHandler(Dependency): +class CompletionHandler(Dependency[T]): """Base class for dependencies that control what happens after task completion. Called after execution is truly done (success, or failure with no retry). diff --git a/src/docket/dependencies/_concurrency.py b/src/docket/dependencies/_concurrency.py index 71e74e4..9343860 100644 --- a/src/docket/dependencies/_concurrency.py +++ b/src/docket/dependencies/_concurrency.py @@ -8,7 +8,13 @@ from typing import TYPE_CHECKING from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task -from ._base import AdmissionBlocked, Dependency +from ._base import ( + AdmissionBlocked, + Dependency, + current_docket, + current_execution, + current_worker, +) logger = logging.getLogger(__name__) @@ -38,7 +44,7 @@ def __init__(self, execution: Execution, concurrency_key: str, max_concurrent: i super().__init__(execution, reason=reason) -class ConcurrencyLimit(Dependency): +class ConcurrencyLimit(Dependency["ConcurrencyLimit"]): """Configures concurrency limits for task execution. Can limit concurrency globally for a task, or per specific argument value. @@ -94,9 +100,9 @@ def __init__( async def __aenter__(self) -> ConcurrencyLimit: from ._functional import _Depends - execution = self.execution.get() - docket = self.docket.get() - worker = self.worker.get() + execution = current_execution.get() + docket = current_docket.get() + worker = current_worker.get() # Build concurrency key based on argument_name (if provided) or function name scope = self.scope or docket.name @@ -151,9 +157,9 @@ async def __aenter__(self) -> ConcurrencyLimit: async def __aexit__( self, - _exc_type: type[BaseException] | None, - _exc_value: BaseException | None, - _traceback: type[BaseException] | None, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: type[BaseException] | None, ) -> None: # No-op: The original instance (used as default argument) has no state. # Actual cleanup is handled by _cleanup() on the per-task instance, @@ -256,7 +262,7 @@ async def _release_slot(self) -> None: # Note: only registered as callback for instances with valid keys assert self._concurrency_key and self._task_key - docket = self.docket.get() + docket = current_docket.get() async with docket.redis() as redis: # Remove this task from the sorted set and delete the key if empty # KEYS[1]: concurrency_key, ARGV[1]: task_key @@ -272,7 +278,7 @@ async def _release_slot(self) -> None: async def _renew_lease_loop(self, redelivery_timeout: timedelta) -> None: """Periodically refresh slot timestamp to prevent expiration.""" - docket = self.docket.get() + docket = current_docket.get() renewal_interval = redelivery_timeout.total_seconds() / LEASE_RENEWAL_FACTOR key_ttl = max( MINIMUM_TTL_SECONDS, diff --git a/src/docket/dependencies/_contextual.py b/src/docket/dependencies/_contextual.py index e903264..b42e66a 100644 --- a/src/docket/dependencies/_contextual.py +++ b/src/docket/dependencies/_contextual.py @@ -5,7 +5,7 @@ import logging from typing import TYPE_CHECKING, Any, cast -from ._base import Dependency +from ._base import Dependency, current_docket, current_execution, current_worker if TYPE_CHECKING: # pragma: no cover from ..docket import Docket @@ -13,9 +13,9 @@ from ..worker import Worker -class _CurrentWorker(Dependency): +class _CurrentWorker(Dependency["Worker"]): async def __aenter__(self) -> Worker: - return self.worker.get() + return current_worker.get() def CurrentWorker() -> Worker: @@ -32,9 +32,9 @@ async def my_task(worker: Worker = CurrentWorker()) -> None: return cast("Worker", _CurrentWorker()) -class _CurrentDocket(Dependency): +class _CurrentDocket(Dependency["Docket"]): async def __aenter__(self) -> Docket: - return self.docket.get() + return current_docket.get() def CurrentDocket() -> Docket: @@ -51,9 +51,9 @@ async def my_task(docket: Docket = CurrentDocket()) -> None: return cast("Docket", _CurrentDocket()) -class _CurrentExecution(Dependency): +class _CurrentExecution(Dependency["Execution"]): async def __aenter__(self) -> Execution: - return self.execution.get() + return current_execution.get() def CurrentExecution() -> Execution: @@ -70,9 +70,9 @@ async def my_task(execution: Execution = CurrentExecution()) -> None: return cast("Execution", _CurrentExecution()) -class _TaskKey(Dependency): +class _TaskKey(Dependency[str]): async def __aenter__(self) -> str: - return self.execution.get().key + return current_execution.get().key def TaskKey() -> str: @@ -89,7 +89,7 @@ async def my_task(key: str = TaskKey()) -> None: return cast(str, _TaskKey()) -class _TaskArgument(Dependency): +class _TaskArgument(Dependency[Any]): parameter: str | None optional: bool @@ -99,7 +99,7 @@ def __init__(self, parameter: str | None = None, optional: bool = False) -> None async def __aenter__(self) -> Any: assert self.parameter is not None - execution = self.execution.get() + execution = current_execution.get() try: return execution.get_argument(self.parameter) except KeyError: @@ -128,15 +128,15 @@ async def greet_customer(customer_id: int, name: str = Depends(customer_name)) - return cast(Any, _TaskArgument(parameter, optional)) -class _TaskLogger(Dependency): +class _TaskLogger(Dependency["logging.LoggerAdapter[logging.Logger]"]): async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]: - execution = self.execution.get() + execution = current_execution.get() logger = logging.getLogger(f"docket.task.{execution.function_name}") return logging.LoggerAdapter( logger, { - **self.docket.get().labels(), - **self.worker.get().labels(), + **current_docket.get().labels(), + **current_worker.get().labels(), **execution.specific_labels(), }, ) diff --git a/src/docket/dependencies/_cron.py b/src/docket/dependencies/_cron.py index e521785..4f25d27 100644 --- a/src/docket/dependencies/_cron.py +++ b/src/docket/dependencies/_cron.py @@ -7,6 +7,7 @@ from croniter import croniter +from ._base import current_execution from ._perpetual import Perpetual if TYPE_CHECKING: # pragma: no cover @@ -82,7 +83,7 @@ def __init__( self._croniter = croniter(self.expression, datetime.now(self.tz), datetime) async def __aenter__(self) -> Cron: - execution = self.execution.get() + execution = current_execution.get() cron = Cron(expression=self.expression, automatic=self.automatic, tz=self.tz) cron.args = execution.args cron.kwargs = execution.kwargs diff --git a/src/docket/dependencies/_functional.py b/src/docket/dependencies/_functional.py index 2d1f0ad..5b74b1e 100644 --- a/src/docket/dependencies/_functional.py +++ b/src/docket/dependencies/_functional.py @@ -2,123 +2,37 @@ from __future__ import annotations -import asyncio -import inspect -from contextlib import AsyncExitStack -from contextvars import ContextVar -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - AsyncContextManager, - Awaitable, - Callable, - ClassVar, - ContextManager, - Generic, - TypeVar, - cast, +from collections.abc import Callable +from typing import Any, TypeVar, cast + +from uncalled_for import ( + DependencyFactory, + Shared as Shared, + SharedContext as SharedContext, + _Depends as _UncalledForDepends, + _parameter_cache as _parameter_cache, + get_dependency_parameters, ) -from ..execution import TaskFunction, get_signature -from ..instrumentation import CACHE_SIZE -from ._base import Dependency from ._contextual import _TaskArgument -if TYPE_CHECKING: # pragma: no cover - from ..docket import Docket - from ..worker import Worker - R = TypeVar("R") -DependencyFunction = Callable[ - ..., R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R] -] - - -class _FunctionalDependency(Dependency, Generic[R]): - """Base class for functional dependencies (Depends and Shared). - - Functional dependencies wrap a factory function that returns (or yields) a value. - This base class provides the common factory storage and value resolution logic. - """ - - factory: DependencyFunction[R] - - def __init__(self, factory: DependencyFunction[R]) -> None: - self.factory = factory - - async def _resolve_factory_value( - self, - stack: AsyncExitStack, - raw_value: R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R], - ) -> R: - """Resolve a DependencyFunction's return value to its final form. - - Handles the four possible return types: - - AsyncContextManager: enters and returns yielded value - - ContextManager: enters and returns yielded value - - Awaitable: awaits and returns result - - Plain value: returns as-is - """ - if isinstance(raw_value, AsyncContextManager): - return await stack.enter_async_context(raw_value) - elif isinstance(raw_value, ContextManager): - return stack.enter_context(raw_value) - elif inspect.iscoroutine(raw_value) or isinstance(raw_value, Awaitable): - return await cast(Awaitable[R], raw_value) - else: - return cast(R, raw_value) - +DependencyFunction = DependencyFactory -_parameter_cache: dict[ - TaskFunction | DependencyFunction[Any], - dict[str, Dependency], -] = {} - -def get_dependency_parameters( - function: TaskFunction | DependencyFunction[Any], -) -> dict[str, Dependency]: - if function in _parameter_cache: - CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"}) - return _parameter_cache[function] - - dependencies: dict[str, Dependency] = {} - - signature = get_signature(function) - - for parameter, param in signature.parameters.items(): - if not isinstance(param.default, Dependency): - continue - - dependencies[parameter] = param.default - - _parameter_cache[function] = dependencies - CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"}) - return dependencies - - -class _Depends(_FunctionalDependency[R]): - """Task-scoped dependency resolved fresh for each task.""" - - cache: ClassVar[ContextVar[dict[DependencyFunction[Any], Any]]] = ContextVar( - "cache" - ) - stack: ClassVar[ContextVar[AsyncExitStack]] = ContextVar("stack") +class _Depends(_UncalledForDepends[R]): + """Docket's call-scoped dependency with TaskArgument inference.""" async def _resolve_parameters( self, - function: TaskFunction | DependencyFunction[Any], + function: Callable[..., Any], ) -> dict[str, Any]: stack = self.stack.get() - arguments: dict[str, Any] = {} parameters = get_dependency_parameters(function) for parameter, dependency in parameters.items(): - # Special case for TaskArguments, they are "magical" and infer the parameter - # they refer to from the parameter name (unless otherwise specified) if isinstance(dependency, _TaskArgument) and not dependency.parameter: dependency.parameter = parameter @@ -126,22 +40,8 @@ async def _resolve_parameters( return arguments - async def __aenter__(self) -> R: - cache = self.cache.get() - - if self.factory in cache: - return cache[self.factory] - stack = self.stack.get() - arguments = await self._resolve_parameters(self.factory) - raw_value = self.factory(**arguments) - resolved_value = await self._resolve_factory_value(stack, raw_value) - - cache[self.factory] = resolved_value - return resolved_value - - -def Depends(dependency: DependencyFunction[R]) -> R: +def Depends(dependency: DependencyFactory[R]) -> R: """Include a user-defined function as a dependency. Dependencies may be: - Synchronous functions returning a value - Asynchronous functions returning a value (awaitable) @@ -202,157 +102,3 @@ async def my_task( ``` """ return cast(R, _Depends(dependency)) - - -class _Shared(_FunctionalDependency[R]): - """Worker-scoped dependency resolved once and shared across all tasks. - - Unlike Depends (which resolves per-task), Shared dependencies initialize once - at worker startup (or lazily on first use) and the same instance is provided - to all tasks throughout the worker's lifetime. - """ - - async def __aenter__(self) -> R: - resolved = SharedContext.resolved.get() - - # Fast path: already resolved (keyed by factory function) - if self.factory in resolved: - return resolved[self.factory] - - # Resolve factory's dependencies OUTSIDE the lock to avoid deadlock - # when a Shared depends on another Shared - arguments = await self._resolve_parameters() - - # Now acquire lock to check/store the resolved value - async with SharedContext.lock.get(): - # Double-check after acquiring lock (another task may have resolved) - if self.factory in resolved: # pragma: no cover - return resolved[self.factory] - - stack = SharedContext.stack.get() - raw_value = self.factory(**arguments) - resolved_value = await self._resolve_factory_value(stack, raw_value) - - resolved[self.factory] = resolved_value - return resolved_value - - async def _resolve_parameters(self) -> dict[str, Any]: - """Resolve parameters for the factory function.""" - stack = SharedContext.stack.get() - arguments: dict[str, Any] = {} - parameters = get_dependency_parameters(self.factory) - - for parameter, dependency in parameters.items(): - arguments[parameter] = await stack.enter_async_context(dependency) - - return arguments - - -class SharedContext: - """Manages worker-scoped Shared dependency lifecycle. - - Created by the Worker to set up ContextVars for Shared dependencies. - Handles initialization of the AsyncExitStack and cleanup on worker exit. - """ - - # ContextVars for Shared dependency state - resolved: ClassVar[ContextVar[dict[DependencyFunction[Any], Any]]] = ContextVar( - "shared_resolved" - ) - lock: ClassVar[ContextVar[asyncio.Lock]] = ContextVar("shared_lock") - stack: ClassVar[ContextVar[AsyncExitStack]] = ContextVar("shared_stack") - - def __init__(self, docket: Docket, worker: Worker) -> None: - self._docket = docket - self._worker = worker - - async def __aenter__(self) -> SharedContext: - self._stack = AsyncExitStack() - await self._stack.__aenter__() - - self._docket_token = Dependency.docket.set(self._docket) - self._worker_token = Dependency.worker.set(self._worker) - self._resolved_token = SharedContext.resolved.set({}) - self._lock_token = SharedContext.lock.set(asyncio.Lock()) - self._stack_token = SharedContext.stack.set(self._stack) - - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - # Close Shared dependencies (context managers exit in reverse order) - await self._stack.__aexit__(exc_type, exc_value, traceback) - - SharedContext.stack.reset(self._stack_token) - SharedContext.lock.reset(self._lock_token) - SharedContext.resolved.reset(self._resolved_token) - Dependency.worker.reset(self._worker_token) - Dependency.docket.reset(self._docket_token) - - -def Shared(factory: DependencyFunction[R]) -> R: - """Declare a worker-scoped dependency shared across all tasks. - - The factory initializes once when first needed and the returned/yielded value is - shared by all tasks for the lifetime of the worker. Factories may be: - - Synchronous functions returning a value - - Asynchronous functions returning a value (awaitable) - - Synchronous context managers (using @contextmanager) - - Asynchronous context managers (using @asynccontextmanager) - - Context managers are useful when cleanup is needed at worker shutdown. - - Identity is the factory function - multiple Shared(same_factory) calls anywhere - in the codebase resolve to the same cached value. - - Example with async context manager (for resources needing cleanup): - - ```python - from contextlib import asynccontextmanager - - @asynccontextmanager - async def create_db_pool(): - pool = await AsyncConnectionPool.create(conninfo="...") - try: - yield pool - finally: - await pool.close() - - @task - async def my_task(pool: Pool = Shared(create_db_pool)): - async with pool.connection() as conn: - await conn.execute("SELECT ...") - ``` - - Example with async function (for simple shared values): - - ```python - async def load_config() -> Config: - return await fetch_config_from_remote() - - @task - async def my_task(config: Config = Shared(load_config)): - # Same config instance across all tasks - print(config.api_url) - ``` - - Shared dependencies can depend on other Shared dependencies, Depends, and - contextual dependencies like CurrentDocket and CurrentWorker: - - ```python - @asynccontextmanager - async def create_pool( - docket: Docket = CurrentDocket(), - url: str = Depends(get_connection_string), - ): - logger.info(f"Creating pool for {docket.name}") - pool = await create_pool(url) - yield pool - await pool.close() - ``` - """ - return cast(R, _Shared(factory)) diff --git a/src/docket/dependencies/_perpetual.py b/src/docket/dependencies/_perpetual.py index c7005e3..5c086d0 100644 --- a/src/docket/dependencies/_perpetual.py +++ b/src/docket/dependencies/_perpetual.py @@ -6,7 +6,14 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any -from ._base import CompletionHandler, TaskOutcome, format_duration +from ._base import ( + CompletionHandler, + TaskOutcome, + current_docket, + current_execution, + current_worker, + format_duration, +) if TYPE_CHECKING: # pragma: no cover from ..execution import Execution @@ -16,7 +23,7 @@ logger = logging.getLogger(__name__) -class Perpetual(CompletionHandler): +class Perpetual(CompletionHandler["Perpetual"]): """Declare a task that should be run perpetually. Perpetual tasks are automatically rescheduled for the future after they finish (whether they succeed or fail). A perpetual task can be scheduled at worker startup with the `automatic=True`. @@ -30,8 +37,6 @@ async def my_task(perpetual: Perpetual = Perpetual()) -> None: ``` """ - single = True - every: timedelta automatic: bool @@ -60,7 +65,7 @@ def __init__( self._next_when = None async def __aenter__(self) -> Perpetual: - execution = self.execution.get() + execution = current_execution.get() perpetual = Perpetual(every=self.every, automatic=self.automatic) perpetual.args = execution.args perpetual.kwargs = execution.kwargs @@ -89,13 +94,13 @@ def at(self, when: datetime) -> None: async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool: """Handle completion by scheduling the next execution.""" if self.cancelled: - docket = self.docket.get() + docket = current_docket.get() async with docket.redis() as redis: await docket._cancel(redis, execution.key) return False if await execution.is_superseded(): - worker = self.worker.get() + worker = current_worker.get() TASKS_SUPERSEDED.add( 1, { @@ -111,8 +116,8 @@ async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool: ) return True - docket = self.docket.get() - worker = self.worker.get() + docket = current_docket.get() + worker = current_worker.get() if self._next_when: when = self._next_when diff --git a/src/docket/dependencies/_progress.py b/src/docket/dependencies/_progress.py index 79dcb61..01daed8 100644 --- a/src/docket/dependencies/_progress.py +++ b/src/docket/dependencies/_progress.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING -from ._base import Dependency +from ._base import Dependency, current_execution if TYPE_CHECKING: # pragma: no cover from ..execution import ExecutionProgress -class Progress(Dependency): +class Progress(Dependency["Progress"]): """A dependency to report progress updates for the currently executing task. Tasks can use this to report their current progress (current/total values) and @@ -33,7 +33,7 @@ def __init__(self) -> None: self._progress: ExecutionProgress | None = None async def __aenter__(self) -> Progress: - execution = self.execution.get() + execution = current_execution.get() self._progress = execution.progress return self diff --git a/src/docket/dependencies/_resolution.py b/src/docket/dependencies/_resolution.py index 805244e..dc8a648 100644 --- a/src/docket/dependencies/_resolution.py +++ b/src/docket/dependencies/_resolution.py @@ -3,9 +3,14 @@ from __future__ import annotations from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING, Any, AsyncGenerator, Counter, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, TypeVar -from ._base import Dependency +from uncalled_for import ( + FailedDependency as FailedDependency, + validate_dependencies as validate_dependencies, +) + +from ._base import Dependency, current_docket, current_execution, current_worker from ._contextual import _TaskArgument from ._functional import _Depends, get_dependency_parameters @@ -27,7 +32,7 @@ def get_single_dependency_parameter_of_type( def get_single_dependency_of_type( - dependencies: dict[str, Dependency], dependency_type: type[D] + dependencies: dict[str, Dependency[Any]], dependency_type: type[D] ) -> D | None: assert dependency_type.single, "Dependency must be single" for _, dependency in dependencies.items(): @@ -36,59 +41,13 @@ def get_single_dependency_of_type( return None -def _single_base_classes(dependency: Dependency) -> list[type[Dependency]]: - """Return all base classes (including the concrete type) that have single=True.""" - return [ - cls - for cls in type(dependency).__mro__ - if issubclass(cls, Dependency) - and cls is not Dependency - and getattr(cls, "single", False) - ] - - -def validate_dependencies(function: TaskFunction) -> None: - parameters = get_dependency_parameters(function) - dependencies = list(parameters.values()) - - # Check concrete types (original behavior) - counts = Counter(type(dependency) for dependency in dependencies) - for dependency_type, count in counts.items(): - if dependency_type.single and count > 1: - raise ValueError( - f"Only one {dependency_type.__name__} dependency is allowed per task" - ) - - # Check base classes with single=True (e.g., Runtime) - # Two different subclasses of Runtime should conflict - single_bases: set[type[Dependency]] = set() - for dependency in dependencies: - single_bases.update(_single_base_classes(dependency)) - - for base_class in single_bases: - instances = [d for d in dependencies if isinstance(d, base_class)] - if len(instances) > 1: - types = ", ".join(type(d).__name__ for d in instances) - raise ValueError( - f"Only one {base_class.__name__} dependency is allowed per task, " - f"but found: {types}" - ) - - -class FailedDependency: - def __init__(self, parameter: str, error: Exception) -> None: - self.parameter = parameter - self.error = error - - @asynccontextmanager async def resolved_dependencies( worker: Worker, execution: Execution ) -> AsyncGenerator[dict[str, Any], None]: - # Capture tokens for all contextvar sets to ensure proper cleanup - docket_token = Dependency.docket.set(worker.docket) - worker_token = Dependency.worker.set(worker) - execution_token = Dependency.execution.set(execution) + docket_token = current_docket.set(worker.docket) + worker_token = current_worker.set(worker) + execution_token = current_execution.set(execution) cache_token = _Depends.cache.set({}) try: @@ -104,10 +63,8 @@ async def resolved_dependencies( arguments[parameter] = kwargs[parameter] continue - # Special case for TaskArguments, they are "magical" and infer the parameter - # they refer to from the parameter name (unless otherwise specified). At - # the top-level task function call, it doesn't make sense to specify one - # _without_ a parameter name, so we'll call that a failed dependency. + # At the top-level task function call, a bare TaskArgument without + # a parameter name doesn't make sense, so mark it as failed. if ( isinstance(dependency, _TaskArgument) and not dependency.parameter @@ -129,6 +86,6 @@ async def resolved_dependencies( _Depends.stack.reset(stack_token) finally: _Depends.cache.reset(cache_token) - Dependency.execution.reset(execution_token) - Dependency.worker.reset(worker_token) - Dependency.docket.reset(docket_token) + current_execution.reset(execution_token) + current_worker.reset(worker_token) + current_docket.reset(docket_token) diff --git a/src/docket/dependencies/_retry.py b/src/docket/dependencies/_retry.py index e395a5c..46fef9d 100644 --- a/src/docket/dependencies/_retry.py +++ b/src/docket/dependencies/_retry.py @@ -6,7 +6,13 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, NoReturn -from ._base import FailureHandler, TaskOutcome, format_duration +from ._base import ( + FailureHandler, + TaskOutcome, + current_execution, + current_worker, + format_duration, +) if TYPE_CHECKING: # pragma: no cover from ..execution import Execution @@ -20,7 +26,7 @@ class ForcedRetry(Exception): """Raised when a task requests a retry via `after` or `at`""" -class Retry(FailureHandler): +class Retry(FailureHandler["Retry"]): """Configures linear retries for a task. You can specify the total number of attempts (or `None` to retry indefinitely), and the delay between attempts. @@ -33,8 +39,6 @@ async def my_task(retry: Retry = Retry(attempts=3)) -> None: ``` """ - single: bool = True - attempts: int | None delay: timedelta attempt: int @@ -53,7 +57,7 @@ def __init__( self.attempt = 1 async def __aenter__(self) -> Retry: - execution = self.execution.get() + execution = current_execution.get() retry = Retry(attempts=self.attempts, delay=self.delay) retry.attempt = execution.attempt return retry @@ -83,7 +87,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo execution.attempt += 1 await execution.schedule(replace=True) - worker = self.worker.get() + worker = current_worker.get() TASKS_RETRIED.add(1, {**worker.labels(), **execution.general_labels()}) logger.info( "↫ [%s] %s", @@ -125,7 +129,7 @@ def __init__( self.maximum_delay = maximum_delay async def __aenter__(self) -> ExponentialRetry: - execution = self.execution.get() + execution = current_execution.get() retry = ExponentialRetry( attempts=self.attempts, diff --git a/src/docket/dependencies/_timeout.py b/src/docket/dependencies/_timeout.py index e17761f..a9b4906 100644 --- a/src/docket/dependencies/_timeout.py +++ b/src/docket/dependencies/_timeout.py @@ -11,10 +11,10 @@ from ..execution import Execution from .._cancellation import cancel_task -from ._base import Runtime +from ._base import Runtime, current_docket -class Timeout(Runtime): +class Timeout(Runtime["Timeout"]): """Configures a timeout for a task. You can specify the base timeout, and the task will be cancelled if it exceeds this duration. The timeout may be extended within the context of a single running task. @@ -28,8 +28,6 @@ async def my_task(timeout: Timeout = Timeout(timedelta(seconds=10))) -> None: ``` """ - single: bool = True - base: timedelta _deadline: float @@ -74,7 +72,7 @@ async def run( """Execute the function with timeout enforcement.""" self.start() - docket = self.docket.get() + docket = current_docket.get() task = asyncio.create_task( function(*args, **kwargs), # type: ignore[arg-type] name=f"{docket.name} - task:{execution.key}", diff --git a/src/docket/execution.py b/src/docket/execution.py index c28d659..c371305 100644 --- a/src/docket/execution.py +++ b/src/docket/execution.py @@ -20,6 +20,7 @@ import cloudpickle import opentelemetry.context +import uncalled_for from opentelemetry import propagate, trace from ._telemetry import suppress_instrumentation from typing_extensions import Self @@ -990,13 +991,11 @@ async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]: def compact_signature(signature: inspect.Signature) -> str: - from .dependencies import Dependency - parameters: list[str] = [] dependencies: int = 0 for parameter in signature.parameters.values(): - if isinstance(parameter.default, Dependency): + if isinstance(parameter.default, uncalled_for.Dependency): dependencies += 1 continue diff --git a/src/docket/worker.py b/src/docket/worker.py index dab8894..17ea7bc 100644 --- a/src/docket/worker.py +++ b/src/docket/worker.py @@ -51,6 +51,8 @@ SharedContext, TaskLogger, TaskOutcome, + current_docket, + current_worker, format_duration, get_single_dependency_of_type, get_single_dependency_parameter_of_type, @@ -223,8 +225,14 @@ async def __aenter__(self) -> Self: cancel_task, self._heartbeat_task, CANCEL_MSG_CLEANUP ) + # Worker-scoped ContextVars for ambient access to docket/worker + self._docket_token = current_docket.set(self.docket) + self._stack.callback(lambda: current_docket.reset(self._docket_token)) + self._worker_token = current_worker.set(self) + self._stack.callback(lambda: current_worker.reset(self._worker_token)) + # Shared context is set up last, so it's cleaned up first (LIFO) - self._shared_context = SharedContext(self.docket, self) + self._shared_context = SharedContext() self._stack.callback(lambda: delattr(self, "_shared_context")) await self._stack.enter_async_context(self._shared_context) diff --git a/tests/concurrency_limits/test_basic.py b/tests/concurrency_limits/test_basic.py index 77aba39..65728c9 100644 --- a/tests/concurrency_limits/test_basic.py +++ b/tests/concurrency_limits/test_basic.py @@ -248,7 +248,7 @@ async def test_concurrency_limit_single_dependency_validation(docket: Docket): """Test that only one ConcurrencyLimit dependency is allowed per task.""" with pytest.raises( ValueError, - match="Only one ConcurrencyLimit dependency is allowed per task", + match="Only one ConcurrencyLimit dependency is allowed", ): async def invalid_task( diff --git a/tests/test_dependencies_advanced.py b/tests/test_dependencies_advanced.py index ac8bd71..c7a8f45 100644 --- a/tests/test_dependencies_advanced.py +++ b/tests/test_dependencies_advanced.py @@ -7,7 +7,13 @@ import pytest from docket import CurrentDocket, Docket, Worker -from docket.dependencies import Depends, Dependency, resolved_dependencies +from docket.dependencies import ( + Depends, + current_docket, + current_execution, + current_worker, + resolved_dependencies, +) from docket.dependencies._functional import _Depends # pyright: ignore[reportPrivateUsage] from docket.execution import Execution @@ -206,13 +212,13 @@ async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker async def first_task(a: str): # Capture the execution context during first task - execution = Dependency.execution.get() + execution = current_execution.get() executions_seen.append(("first", execution)) assert a == "first" async def second_task(b: str): # Capture the execution context during second task - execution = Dependency.execution.get() + execution = current_execution.get() executions_seen.append(("second", execution)) assert b == "second" @@ -238,7 +244,7 @@ async def second_task(b: str): async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker): """Task-scoped contextvars are reset after task execution completes. - Worker-scoped contextvars (Dependency.docket, Dependency.worker) remain + Worker-scoped contextvars (current_docket, current_worker) remain set for the entire worker lifetime to support Shared dependencies. """ captured_stack = None @@ -261,12 +267,12 @@ async def capture_task(): _Depends.cache.get() with pytest.raises(LookupError): - Dependency.execution.get() + current_execution.get() # Worker-scoped contextvars (docket, worker) remain set for the worker's # lifetime to support Shared dependency initialization - assert Dependency.docket.get() is docket - assert Dependency.worker.get() is worker + assert current_docket.get() is docket + assert current_worker.get() is worker async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker): @@ -354,17 +360,17 @@ async def task2(): ... captured_stack1 = None async with resolved_dependencies(worker, execution1): - captured_exec1 = Dependency.execution.get() + captured_exec1 = current_execution.get() captured_stack1 = _Depends.stack.get() assert captured_exec1 is execution1 # After exiting, contextvars should be reset (raise LookupError) with pytest.raises(LookupError): - Dependency.execution.get() + current_execution.get() # Now make a second call - should not see values from first call async with resolved_dependencies(worker, execution2): - captured_exec2 = Dependency.execution.get() + captured_exec2 = current_execution.get() captured_stack2 = _Depends.stack.get() assert captured_exec2 is execution2 assert captured_exec2 is not captured_exec1 @@ -376,7 +382,7 @@ async def test_contextvar_not_leaked_to_caller(docket: Docket): """Verify contextvars don't leak outside resolved_dependencies context""" # Before calling resolved_dependencies, contextvars should not be set with pytest.raises(LookupError): - Dependency.execution.get() + current_execution.get() async def dummy_task(): ... @@ -395,11 +401,11 @@ async def dummy_task(): ... # Use resolved_dependencies async with resolved_dependencies(test_worker, execution): # Inside context, we should be able to get values - assert Dependency.execution.get() is execution + assert current_execution.get() is execution # After exiting context, contextvars should be cleaned up with pytest.raises(LookupError): - Dependency.execution.get() + current_execution.get() with pytest.raises(LookupError): _Depends.stack.get() diff --git a/tests/test_dependencies_core.py b/tests/test_dependencies_core.py index d90769a..46075c3 100644 --- a/tests/test_dependencies_core.py +++ b/tests/test_dependencies_core.py @@ -270,3 +270,23 @@ async def dependent_task(a: list[str] = TaskArgument()) -> None: assert "Failed to resolve dependencies for parameter(s): a" in caplog.text assert "ValueError: No parameter name specified" in caplog.text + + +def test_dependency_class_has_backwards_compatible_context_vars(): + """Dependency.execution/docket/worker are available for downstream consumers. + + Prior to 0.18, docket's Dependency class had class-level ContextVars. Now + that Dependency comes from uncalled-for, those ContextVars are module-level + in docket.dependencies._base. We monkeypatch them back onto the class so + existing code (e.g. FastMCP's `Dependency.execution.get()`) keeps working. + """ + from docket.dependencies import ( + Dependency, + current_docket, + current_execution, + current_worker, + ) + + assert Dependency.execution is current_execution # type: ignore[attr-defined] + assert Dependency.docket is current_docket # type: ignore[attr-defined] + assert Dependency.worker is current_worker # type: ignore[attr-defined] diff --git a/tests/test_dependency_uniqueness.py b/tests/test_dependency_uniqueness.py index 6aec9e2..f94090f 100644 --- a/tests/test_dependency_uniqueness.py +++ b/tests/test_dependency_uniqueness.py @@ -28,7 +28,7 @@ async def the_task( with pytest.raises( ValueError, - match="Only one Retry dependency is allowed per task", + match="Only one Retry dependency is allowed", ): await docket.add(the_task)("a") @@ -36,7 +36,7 @@ async def the_task( async def test_runtime_subclasses_must_be_unique(docket: Docket, worker: Worker): """Two different Runtime subclasses should conflict since Runtime.single=True.""" - class CustomRuntime(Runtime): + class CustomRuntime(Runtime["CustomRuntime"]): async def __aenter__(self) -> "CustomRuntime": return self # pragma: no cover @@ -58,7 +58,7 @@ async def the_task( with pytest.raises( ValueError, - match=r"Only one Runtime dependency is allowed per task, but found: .+", + match=r"Only one Runtime dependency is allowed, but found: .+", ): await docket.add(the_task)("a") @@ -68,7 +68,7 @@ async def test_failure_handler_subclasses_must_be_unique( ): """Two different FailureHandler subclasses should conflict since FailureHandler.single=True.""" - class CustomFailureHandler(FailureHandler): + class CustomFailureHandler(FailureHandler["CustomFailureHandler"]): async def __aenter__(self) -> "CustomFailureHandler": return self # pragma: no cover @@ -86,7 +86,7 @@ async def the_task( with pytest.raises( ValueError, - match=r"Only one FailureHandler dependency is allowed per task, but found: .+", + match=r"Only one FailureHandler dependency is allowed, but found: .+", ): await docket.add(the_task)("a") @@ -96,7 +96,7 @@ async def test_completion_handler_subclasses_must_be_unique( ): """Two different CompletionHandler subclasses should conflict since CompletionHandler.single=True.""" - class CustomCompletionHandler(CompletionHandler): + class CustomCompletionHandler(CompletionHandler["CustomCompletionHandler"]): async def __aenter__(self) -> "CustomCompletionHandler": return self # pragma: no cover @@ -112,6 +112,6 @@ async def the_task( with pytest.raises( ValueError, - match=r"Only one CompletionHandler dependency is allowed per task, but found: .+", + match=r"Only one CompletionHandler dependency is allowed, but found: .+", ): await docket.add(the_task)("a") diff --git a/uv.lock b/uv.lock index 4a062c7..80fef04 100644 --- a/uv.lock +++ b/uv.lock @@ -1564,6 +1564,7 @@ dependencies = [ { name = "typer" }, { name = "typing-extensions" }, { name = "tzdata", marker = "sys_platform == 'win32'" }, + { name = "uncalled-for" }, ] [package.optional-dependencies] @@ -1626,6 +1627,7 @@ requires-dist = [ { name = "typer", specifier = ">=0.15.1" }, { name = "typing-extensions", specifier = ">=4.12.0" }, { name = "tzdata", marker = "sys_platform == 'win32'", specifier = ">=2025.2" }, + { name = "uncalled-for", specifier = ">=0.1.2" }, ] provides-extras = ["metrics"] @@ -2190,6 +2192,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, ] +[[package]] +name = "uncalled-for" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/5a/b726094d69b1e1feaffacfddddb42e6f1e4ba2154d6ebea2d2b478154e41/uncalled_for-0.1.2.tar.gz", hash = "sha256:f65e2956e410353c6d56a9c9aa40fd3c2bd01b8176c19510921d8fe1703f0f85", size = 47075, upload-time = "2026-02-25T20:55:19.935Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/13/9bdcf067d60d50c11b64ea24a53cd2bf8531161cd2ed4fd862c747877e1d/uncalled_for-0.1.2-py3-none-any.whl", hash = "sha256:8d1a63ab090046143a576676febe675e15e4543ae461c0bdcbfa7c23396e8667", size = 7290, upload-time = "2026-02-25T20:55:18.747Z" }, +] + [[package]] name = "urllib3" version = "2.6.3"