Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"typer>=0.15.1",
"typing_extensions>=4.12.0",
"tzdata>=2025.2; sys_platform == 'win32'",
"uncalled-for>=0.1.2",
]

[project.optional-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/docket/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
FailureHandler,
Runtime,
TaskOutcome,
current_docket,
current_execution,
current_worker,
format_duration,
)
from ._concurrency import ConcurrencyBlocked, ConcurrencyLimit
Expand Down Expand Up @@ -53,6 +56,9 @@
"FailureHandler",
"CompletionHandler",
"TaskOutcome",
"current_docket",
"current_execution",
"current_worker",
"format_duration",
# Contextual dependencies
"CurrentDocket",
Expand Down
37 changes: 12 additions & 25 deletions src/docket/dependencies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import timedelta
from types import TracebackType
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar

from uncalled_for import Dependency as Dependency

if TYPE_CHECKING: # pragma: no cover
from ..docket import Docket
from ..execution import Execution
from ..worker import Worker

T = TypeVar("T", covariant=True)

current_docket: ContextVar[Docket] = ContextVar("current_docket")
current_worker: ContextVar[Worker] = ContextVar("current_worker")
current_execution: ContextVar[Execution] = ContextVar("current_execution")


def format_duration(seconds: float) -> str:
"""Format a duration for log output."""
Expand Down Expand Up @@ -45,27 +52,7 @@ def __init__(self, execution: Execution, reason: str = "admission control"):
super().__init__(f"Task {execution.key} blocked by {reason}")


class Dependency(abc.ABC):
"""Base class for all dependencies."""

single: bool = False

docket: ContextVar[Docket] = ContextVar("docket")
worker: ContextVar[Worker] = ContextVar("worker")
execution: ContextVar[Execution] = ContextVar("execution")

@abc.abstractmethod
async def __aenter__(self) -> Any: ... # pragma: no cover

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_value: BaseException | None,
_traceback: TracebackType | None,
) -> bool: ... # pragma: no cover


class Runtime(Dependency):
class Runtime(Dependency[T]):
"""Base class for dependencies that control task execution.

Only one Runtime dependency can be active per task (single=True).
Expand Down Expand Up @@ -93,7 +80,7 @@ async def run(
... # pragma: no cover


class FailureHandler(Dependency):
class FailureHandler(Dependency[T]):
"""Base class for dependencies that control what happens when a task fails.

Called on exceptions. If handle_failure() returns True, the handler
Expand All @@ -120,7 +107,7 @@ async def handle_failure(self, execution: Execution, outcome: TaskOutcome) -> bo
... # pragma: no cover


class CompletionHandler(Dependency):
class CompletionHandler(Dependency[T]):
"""Base class for dependencies that control what happens after task completion.

Called after execution is truly done (success, or failure with no retry).
Expand Down
26 changes: 16 additions & 10 deletions src/docket/dependencies/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from typing import TYPE_CHECKING

from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
from ._base import AdmissionBlocked, Dependency
from ._base import (
AdmissionBlocked,
Dependency,
current_docket,
current_execution,
current_worker,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +44,7 @@ def __init__(self, execution: Execution, concurrency_key: str, max_concurrent: i
super().__init__(execution, reason=reason)


class ConcurrencyLimit(Dependency):
class ConcurrencyLimit(Dependency["ConcurrencyLimit"]):
"""Configures concurrency limits for task execution.

Can limit concurrency globally for a task, or per specific argument value.
Expand Down Expand Up @@ -94,9 +100,9 @@ def __init__(
async def __aenter__(self) -> ConcurrencyLimit:
from ._functional import _Depends

execution = self.execution.get()
docket = self.docket.get()
worker = self.worker.get()
execution = current_execution.get()
docket = current_docket.get()
worker = current_worker.get()

# Build concurrency key based on argument_name (if provided) or function name
scope = self.scope or docket.name
Expand Down Expand Up @@ -151,9 +157,9 @@ async def __aenter__(self) -> ConcurrencyLimit:

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_value: BaseException | None,
_traceback: type[BaseException] | None,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: type[BaseException] | None,
) -> None:
# No-op: The original instance (used as default argument) has no state.
# Actual cleanup is handled by _cleanup() on the per-task instance,
Expand Down Expand Up @@ -256,7 +262,7 @@ async def _release_slot(self) -> None:
# Note: only registered as callback for instances with valid keys
assert self._concurrency_key and self._task_key

docket = self.docket.get()
docket = current_docket.get()
async with docket.redis() as redis:
# Remove this task from the sorted set and delete the key if empty
# KEYS[1]: concurrency_key, ARGV[1]: task_key
Expand All @@ -272,7 +278,7 @@ async def _release_slot(self) -> None:

async def _renew_lease_loop(self, redelivery_timeout: timedelta) -> None:
"""Periodically refresh slot timestamp to prevent expiration."""
docket = self.docket.get()
docket = current_docket.get()
renewal_interval = redelivery_timeout.total_seconds() / LEASE_RENEWAL_FACTOR
key_ttl = max(
MINIMUM_TTL_SECONDS,
Expand Down
30 changes: 15 additions & 15 deletions src/docket/dependencies/_contextual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
import logging
from typing import TYPE_CHECKING, Any, cast

from ._base import Dependency
from ._base import Dependency, current_docket, current_execution, current_worker

if TYPE_CHECKING: # pragma: no cover
from ..docket import Docket
from ..execution import Execution
from ..worker import Worker


class _CurrentWorker(Dependency):
class _CurrentWorker(Dependency["Worker"]):
async def __aenter__(self) -> Worker:
return self.worker.get()
return current_worker.get()


def CurrentWorker() -> Worker:
Expand All @@ -32,9 +32,9 @@ async def my_task(worker: Worker = CurrentWorker()) -> None:
return cast("Worker", _CurrentWorker())


class _CurrentDocket(Dependency):
class _CurrentDocket(Dependency["Docket"]):
async def __aenter__(self) -> Docket:
return self.docket.get()
return current_docket.get()


def CurrentDocket() -> Docket:
Expand All @@ -51,9 +51,9 @@ async def my_task(docket: Docket = CurrentDocket()) -> None:
return cast("Docket", _CurrentDocket())


class _CurrentExecution(Dependency):
class _CurrentExecution(Dependency["Execution"]):
async def __aenter__(self) -> Execution:
return self.execution.get()
return current_execution.get()


def CurrentExecution() -> Execution:
Expand All @@ -70,9 +70,9 @@ async def my_task(execution: Execution = CurrentExecution()) -> None:
return cast("Execution", _CurrentExecution())


class _TaskKey(Dependency):
class _TaskKey(Dependency[str]):
async def __aenter__(self) -> str:
return self.execution.get().key
return current_execution.get().key


def TaskKey() -> str:
Expand All @@ -89,7 +89,7 @@ async def my_task(key: str = TaskKey()) -> None:
return cast(str, _TaskKey())


class _TaskArgument(Dependency):
class _TaskArgument(Dependency[Any]):
parameter: str | None
optional: bool

Expand All @@ -99,7 +99,7 @@ def __init__(self, parameter: str | None = None, optional: bool = False) -> None

async def __aenter__(self) -> Any:
assert self.parameter is not None
execution = self.execution.get()
execution = current_execution.get()
try:
return execution.get_argument(self.parameter)
except KeyError:
Expand Down Expand Up @@ -128,15 +128,15 @@ async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -
return cast(Any, _TaskArgument(parameter, optional))


class _TaskLogger(Dependency):
class _TaskLogger(Dependency["logging.LoggerAdapter[logging.Logger]"]):
async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
execution = self.execution.get()
execution = current_execution.get()
logger = logging.getLogger(f"docket.task.{execution.function_name}")
return logging.LoggerAdapter(
logger,
{
**self.docket.get().labels(),
**self.worker.get().labels(),
**current_docket.get().labels(),
**current_worker.get().labels(),
**execution.specific_labels(),
},
)
Expand Down
3 changes: 2 additions & 1 deletion src/docket/dependencies/_cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from croniter import croniter

from ._base import current_execution
from ._perpetual import Perpetual

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(
self._croniter = croniter(self.expression, datetime.now(self.tz), datetime)

async def __aenter__(self) -> Cron:
execution = self.execution.get()
execution = current_execution.get()
cron = Cron(expression=self.expression, automatic=self.automatic, tz=self.tz)
cron.args = execution.args
cron.kwargs = execution.kwargs
Expand Down
Loading