Skip to content

Commit 7180b79

Browse files
chrisguidryclaude
andauthored
Support Annotated-style dependency injection (#354)
## Summary Dependencies can now be attached as `Annotated` type-hint metadata instead of only as default parameter values: ```python async def process(customer_id: Annotated[int, ConcurrencyLimit(1)]): ... ``` The parameter keeps its real value; the dependency runs as a side-effect. This is especially nice for ConcurrencyLimit where the old default-param style required a separate dummy parameter just to carry the dependency. Changes: - `resolved_dependencies()` now calls `get_annotation_dependencies()` from uncalled-for and enters each annotation dep via `bind_to_parameter()` - `ConcurrencyLimit` gains `bind_to_parameter()` to auto-infer the argument name from the annotated parameter - `ConcurrencyLimit(1)` shorthand: passing an int as the first positional arg sets `max_concurrent` (convenient for the Annotated style) - `get_single_dependency_parameter_of_type()` now searches annotations too - Bumps uncalled-for to >=0.2.0 for the annotation extraction API Includes contract tests for the uncalled-for behaviors we depend on, and integration tests covering concurrency limits, Depends side-effects, mixed styles, type aliases, and validation. Closes #334, closes #163. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3484ea9 commit 7180b79

File tree

9 files changed

+497
-37
lines changed

9 files changed

+497
-37
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies = [
3737
"typer>=0.15.1",
3838
"typing_extensions>=4.12.0",
3939
"tzdata>=2025.2; sys_platform == 'win32'",
40-
"uncalled-for>=0.1.2",
40+
"uncalled-for>=0.2.0",
4141
]
4242

4343
[project.optional-dependencies]

src/docket/dependencies/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from ._progress import Progress
4242
from ._resolution import (
4343
FailedDependency,
44+
get_annotation_dependencies,
4445
get_single_dependency_of_type,
4546
get_single_dependency_parameter_of_type,
4647
resolved_dependencies,
@@ -72,6 +73,7 @@
7273
"DependencyFunction",
7374
"Shared",
7475
"SharedContext",
76+
"get_annotation_dependencies",
7577
"get_dependency_parameters",
7678
# Retry
7779
"ForcedRetry",

src/docket/dependencies/_concurrency.py

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import logging
77
from datetime import datetime, timedelta, timezone
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any, overload
99

1010
from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
1111
from ._base import (
@@ -49,54 +49,82 @@ class ConcurrencyLimit(Dependency["ConcurrencyLimit"]):
4949
5050
Can limit concurrency globally for a task, or per specific argument value.
5151
52-
Example:
52+
Works both as a default parameter and as ``Annotated`` metadata::
5353
54-
```python
55-
async def expensive_operation(
56-
concurrency: ConcurrencyLimit = ConcurrencyLimit(max_concurrent=3)
57-
) -> None:
58-
# Only 3 instances of this task will run at a time
59-
...
54+
# Default-parameter style
55+
async def process_customer(
56+
customer_id: int,
57+
concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", 1),
58+
) -> None: ...
6059
61-
async def process_customer(
62-
customer_id: int,
63-
concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
64-
) -> None:
65-
# Only one task per customer_id will run at a time
66-
...
60+
# Annotated style (parameter name auto-inferred)
61+
async def process_customer(
62+
customer_id: Annotated[int, ConcurrencyLimit(1)],
63+
) -> None: ...
6764
68-
async def backup_db(
69-
db_name: str,
70-
concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
71-
) -> None:
72-
# Only 3 backup tasks per database name will run at a time
73-
...
74-
```
65+
# Per-task (no argument grouping)
66+
async def expensive(
67+
concurrency: ConcurrencyLimit = ConcurrencyLimit(max_concurrent=3),
68+
) -> None: ...
7569
"""
7670

7771
single: bool = True
7872

73+
@overload
7974
def __init__(
8075
self,
81-
argument_name: str | None = None,
76+
max_concurrent: int,
77+
/,
78+
*,
79+
scope: str | None = None,
80+
) -> None:
81+
"""Annotated style: ``Annotated[int, ConcurrencyLimit(1)]``."""
82+
83+
@overload
84+
def __init__(
85+
self,
86+
argument_name: str,
8287
max_concurrent: int = 1,
8388
scope: str | None = None,
8489
) -> None:
85-
"""
86-
Args:
87-
argument_name: The name of the task argument to use for concurrency grouping.
88-
If None, limits concurrency for the task function itself.
89-
max_concurrent: Maximum number of concurrent tasks
90-
scope: Optional scope prefix for Redis keys (defaults to docket name)
91-
"""
92-
self.argument_name = argument_name
93-
self.max_concurrent = max_concurrent
90+
"""Default-param style with per-argument grouping."""
91+
92+
@overload
93+
def __init__(
94+
self,
95+
*,
96+
max_concurrent: int = 1,
97+
scope: str | None = None,
98+
) -> None:
99+
"""Per-task concurrency (no argument grouping)."""
100+
101+
def __init__(
102+
self,
103+
argument_name: str | int | None = None,
104+
max_concurrent: int = 1,
105+
scope: str | None = None,
106+
) -> None:
107+
if isinstance(argument_name, int):
108+
self.argument_name: str | None = None
109+
self.max_concurrent: int = argument_name
110+
else:
111+
self.argument_name = argument_name
112+
self.max_concurrent = max_concurrent
94113
self.scope = scope
95114
self._concurrency_key: str | None = None
96115
self._initialized: bool = False
97116
self._task_key: str | None = None
98117
self._renewal_task: asyncio.Task[None] | None = None
99118

119+
def bind_to_parameter(self, name: str, value: Any) -> ConcurrencyLimit:
120+
"""Bind to an ``Annotated`` parameter, inferring argument_name if not set."""
121+
argument_name = self.argument_name if self.argument_name is not None else name
122+
return ConcurrencyLimit(
123+
argument_name,
124+
max_concurrent=self.max_concurrent,
125+
scope=self.scope,
126+
)
127+
100128
async def __aenter__(self) -> ConcurrencyLimit:
101129
from ._functional import _Depends
102130

src/docket/dependencies/_functional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
DependencyFactory,
1010
Shared as Shared,
1111
SharedContext as SharedContext,
12-
_Depends as _UncalledForDepends,
12+
)
13+
from uncalled_for.functional import _Depends as _UncalledForDepends
14+
from uncalled_for.introspection import (
1315
_parameter_cache as _parameter_cache,
1416
get_dependency_parameters,
1517
)

src/docket/dependencies/_resolution.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from uncalled_for import (
99
FailedDependency as FailedDependency,
10+
get_annotation_dependencies as get_annotation_dependencies,
1011
validate_dependencies as validate_dependencies,
1112
)
1213

@@ -28,6 +29,10 @@ def get_single_dependency_parameter_of_type(
2829
for _, dependency in get_dependency_parameters(function).items():
2930
if isinstance(dependency, dependency_type):
3031
return dependency
32+
for _, dependencies in get_annotation_dependencies(function).items():
33+
for dependency in dependencies:
34+
if isinstance(dependency, dependency_type):
35+
return dependency # type: ignore[return-value]
3136
return None
3237

3338

@@ -81,6 +86,20 @@ async def resolved_dependencies(
8186
except Exception as error:
8287
arguments[parameter] = FailedDependency(parameter, error)
8388

89+
annotations = get_annotation_dependencies(execution.function)
90+
for parameter_name, dependencies in annotations.items():
91+
value = execution.kwargs.get(
92+
parameter_name, arguments.get(parameter_name)
93+
)
94+
for dependency in dependencies:
95+
bound = dependency.bind_to_parameter(parameter_name, value)
96+
try:
97+
await stack.enter_async_context(bound)
98+
except Exception as error:
99+
arguments[parameter_name] = FailedDependency(
100+
parameter_name, error
101+
)
102+
84103
yield arguments
85104
finally:
86105
_Depends.stack.reset(stack_token)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Tests for ConcurrencyLimit via Annotated-style type hints."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import time
7+
from typing import Annotated
8+
9+
import pytest
10+
11+
from docket import ConcurrencyLimit, Docket, Worker
12+
13+
from tests.concurrency_limits.overlap import assert_some_overlap
14+
15+
16+
async def test_annotated_concurrency_limit(docket: Docket, worker: Worker):
17+
"""Annotated[int, ConcurrencyLimit(1)] limits concurrency per-parameter."""
18+
results: list[str] = []
19+
20+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
21+
results.append(f"start_{customer_id}")
22+
await asyncio.sleep(0.01)
23+
results.append(f"end_{customer_id}")
24+
25+
await docket.add(task)(customer_id=1)
26+
await docket.add(task)(customer_id=1)
27+
28+
await worker.run_until_finished()
29+
30+
assert results == ["start_1", "end_1", "start_1", "end_1"]
31+
32+
33+
async def test_annotated_concurrency_different_values(docket: Docket, worker: Worker):
34+
"""Different argument values get independent concurrency slots."""
35+
execution_intervals: dict[int, tuple[float, float]] = {}
36+
37+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
38+
start = time.monotonic()
39+
await asyncio.sleep(0.1)
40+
end = time.monotonic()
41+
execution_intervals[customer_id] = (start, end)
42+
43+
await docket.add(task)(customer_id=1)
44+
await docket.add(task)(customer_id=2)
45+
await docket.add(task)(customer_id=3)
46+
47+
worker.concurrency = 10
48+
await worker.run_until_finished()
49+
50+
assert len(execution_intervals) == 3
51+
intervals = list(execution_intervals.values())
52+
assert_some_overlap(intervals, "Different customers should run concurrently")
53+
54+
55+
async def test_annotated_concurrency_max_concurrent(docket: Docket, worker: Worker):
56+
"""max_concurrent>1 allows that many concurrent executions per value."""
57+
active_tasks: list[int] = []
58+
max_concurrent_seen = 0
59+
lock = asyncio.Lock()
60+
61+
async def task(
62+
db_name: str,
63+
task_id: Annotated[int, ConcurrencyLimit("db_name", max_concurrent=2)],
64+
):
65+
nonlocal max_concurrent_seen
66+
async with lock:
67+
active_tasks.append(task_id)
68+
max_concurrent_seen = max(max_concurrent_seen, len(active_tasks))
69+
await asyncio.sleep(0.1)
70+
async with lock:
71+
active_tasks.remove(task_id)
72+
73+
for i in range(5):
74+
await docket.add(task)(db_name="postgres", task_id=i)
75+
76+
worker.concurrency = 10
77+
await worker.run_until_finished()
78+
79+
assert max_concurrent_seen <= 2
80+
81+
82+
async def test_two_annotated_concurrency_limits_rejected(docket: Docket):
83+
"""ConcurrencyLimit.single=True prevents two annotations on one function."""
84+
with pytest.raises(ValueError, match="Only one ConcurrencyLimit"):
85+
86+
async def task(
87+
customer_id: Annotated[int, ConcurrencyLimit(1)],
88+
region: Annotated[str, ConcurrencyLimit(2)],
89+
): ... # pragma: no cover
90+
91+
await docket.add(task)(customer_id=1, region="us")
92+
93+
94+
async def test_single_conflict_annotation_and_default(docket: Docket):
95+
"""single=True conflict detected across annotation and default-param styles."""
96+
with pytest.raises(ValueError, match="Only one ConcurrencyLimit"):
97+
98+
async def task(
99+
customer_id: Annotated[int, ConcurrencyLimit(1)],
100+
concurrency: ConcurrencyLimit = ConcurrencyLimit(max_concurrent=2),
101+
): ... # pragma: no cover
102+
103+
await docket.add(task)(customer_id=1)
104+
105+
106+
async def test_annotated_concurrency_keys_cleaned_up(docket: Docket, worker: Worker):
107+
"""Concurrency keys from annotated deps are properly cleaned up."""
108+
109+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
110+
pass
111+
112+
await docket.add(task)(customer_id=42)
113+
await worker.run_until_finished()
114+
115+
async with docket.redis() as redis:
116+
key = f"{docket.name}:concurrency:customer_id:42"
117+
assert await redis.exists(key) == 0

0 commit comments

Comments
 (0)