Skip to content

Commit 7afac78

Browse files
chrisguidryclaude
andcommitted
Review feedback: overloads, naming, and test organization
- Add @overload signatures to ConcurrencyLimit.__init__ to clarify the three calling conventions (int shorthand, str argument name, keyword-only) - Rename abbreviated variables (deps, param_name, dep) to full names - Rename test helper to my_side_effect for clarity - Move concurrency-specific tests into tests/concurrency_limits/test_annotated.py alongside the existing concurrency test suite - Remove section header comments from tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8af600a commit 7afac78

File tree

4 files changed

+164
-159
lines changed

4 files changed

+164
-159
lines changed

src/docket/dependencies/_concurrency.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import logging
77
from datetime import datetime, timedelta, timezone
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, overload
99

1010
from .._cancellation import CANCEL_MSG_CLEANUP, cancel_task
1111
from ._base import (
@@ -70,21 +70,40 @@ async def expensive(
7070

7171
single: bool = True
7272

73+
@overload
74+
def __init__(
75+
self,
76+
max_concurrent: int,
77+
/,
78+
*,
79+
scope: str | None = None,
80+
) -> None:
81+
"""Annotated style: ``Annotated[int, ConcurrencyLimit(1)]``."""
82+
83+
@overload
84+
def __init__(
85+
self,
86+
argument_name: str,
87+
max_concurrent: int = 1,
88+
scope: str | None = None,
89+
) -> None:
90+
"""Default-param style with per-argument grouping."""
91+
92+
@overload
93+
def __init__(
94+
self,
95+
*,
96+
max_concurrent: int = 1,
97+
scope: str | None = None,
98+
) -> None:
99+
"""Per-task concurrency (no argument grouping)."""
100+
73101
def __init__(
74102
self,
75103
argument_name: str | int | None = None,
76104
max_concurrent: int = 1,
77105
scope: str | None = None,
78106
) -> None:
79-
"""
80-
Args:
81-
argument_name: The name of the task argument to use for concurrency grouping.
82-
If an ``int`` is passed as the first positional arg, it is treated as
83-
*max_concurrent* (convenient for ``Annotated[int, ConcurrencyLimit(1)]``).
84-
If ``None``, limits concurrency for the task function itself.
85-
max_concurrent: Maximum number of concurrent tasks
86-
scope: Optional scope prefix for Redis keys (defaults to docket name)
87-
"""
88107
if isinstance(argument_name, int):
89108
self.argument_name: str | None = None
90109
self.max_concurrent: int = argument_name
@@ -99,10 +118,9 @@ def __init__(
99118

100119
def bind_to_parameter(self, name: str, value: Any) -> ConcurrencyLimit:
101120
"""Bind to an ``Annotated`` parameter, inferring argument_name if not set."""
121+
argument_name = self.argument_name if self.argument_name is not None else name
102122
return ConcurrencyLimit(
103-
argument_name=self.argument_name
104-
if self.argument_name is not None
105-
else name,
123+
argument_name,
106124
max_concurrent=self.max_concurrent,
107125
scope=self.scope,
108126
)

src/docket/dependencies/_resolution.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def get_single_dependency_parameter_of_type(
2929
for _, dependency in get_dependency_parameters(function).items():
3030
if isinstance(dependency, dependency_type):
3131
return dependency
32-
for _, deps in get_annotation_dependencies(function).items():
33-
for dependency in deps:
32+
for _, dependencies in get_annotation_dependencies(function).items():
33+
for dependency in dependencies:
3434
if isinstance(dependency, dependency_type):
3535
return dependency # type: ignore[return-value]
3636
return None
@@ -86,15 +86,19 @@ async def resolved_dependencies(
8686
except Exception as error:
8787
arguments[parameter] = FailedDependency(parameter, error)
8888

89-
annotation_deps = get_annotation_dependencies(execution.function)
90-
for param_name, deps in annotation_deps.items():
91-
value = execution.kwargs.get(param_name, arguments.get(param_name))
92-
for dep in deps:
93-
bound = dep.bind_to_parameter(param_name, value)
89+
annotations = get_annotation_dependencies(execution.function)
90+
for parameter_name, dependencies in annotations.items():
91+
value = execution.kwargs.get(
92+
parameter_name, arguments.get(parameter_name)
93+
)
94+
for dependency in dependencies:
95+
bound = dependency.bind_to_parameter(parameter_name, value)
9496
try:
9597
await stack.enter_async_context(bound)
9698
except Exception as error:
97-
arguments[param_name] = FailedDependency(param_name, error)
99+
arguments[parameter_name] = FailedDependency(
100+
parameter_name, error
101+
)
98102

99103
yield arguments
100104
finally:
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Tests for ConcurrencyLimit via Annotated-style type hints."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import time
7+
from typing import Annotated
8+
9+
import pytest
10+
11+
from docket import ConcurrencyLimit, Docket, Worker
12+
13+
from tests.concurrency_limits.overlap import assert_some_overlap
14+
15+
16+
async def test_annotated_concurrency_limit(docket: Docket, worker: Worker):
17+
"""Annotated[int, ConcurrencyLimit(1)] limits concurrency per-parameter."""
18+
results: list[str] = []
19+
20+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
21+
results.append(f"start_{customer_id}")
22+
await asyncio.sleep(0.01)
23+
results.append(f"end_{customer_id}")
24+
25+
await docket.add(task)(customer_id=1)
26+
await docket.add(task)(customer_id=1)
27+
28+
await worker.run_until_finished()
29+
30+
assert results == ["start_1", "end_1", "start_1", "end_1"]
31+
32+
33+
async def test_annotated_concurrency_different_values(docket: Docket, worker: Worker):
34+
"""Different argument values get independent concurrency slots."""
35+
execution_intervals: dict[int, tuple[float, float]] = {}
36+
37+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
38+
start = time.monotonic()
39+
await asyncio.sleep(0.1)
40+
end = time.monotonic()
41+
execution_intervals[customer_id] = (start, end)
42+
43+
await docket.add(task)(customer_id=1)
44+
await docket.add(task)(customer_id=2)
45+
await docket.add(task)(customer_id=3)
46+
47+
worker.concurrency = 10
48+
await worker.run_until_finished()
49+
50+
assert len(execution_intervals) == 3
51+
intervals = list(execution_intervals.values())
52+
assert_some_overlap(intervals, "Different customers should run concurrently")
53+
54+
55+
async def test_annotated_concurrency_max_concurrent(docket: Docket, worker: Worker):
56+
"""max_concurrent>1 allows that many concurrent executions per value."""
57+
active_tasks: list[int] = []
58+
max_concurrent_seen = 0
59+
lock = asyncio.Lock()
60+
61+
async def task(
62+
db_name: str,
63+
task_id: Annotated[int, ConcurrencyLimit("db_name", max_concurrent=2)],
64+
):
65+
nonlocal max_concurrent_seen
66+
async with lock:
67+
active_tasks.append(task_id)
68+
max_concurrent_seen = max(max_concurrent_seen, len(active_tasks))
69+
await asyncio.sleep(0.1)
70+
async with lock:
71+
active_tasks.remove(task_id)
72+
73+
for i in range(5):
74+
await docket.add(task)(db_name="postgres", task_id=i)
75+
76+
worker.concurrency = 10
77+
await worker.run_until_finished()
78+
79+
assert max_concurrent_seen <= 2
80+
81+
82+
async def test_two_annotated_concurrency_limits_rejected(docket: Docket):
83+
"""ConcurrencyLimit.single=True prevents two annotations on one function."""
84+
with pytest.raises(ValueError, match="Only one ConcurrencyLimit"):
85+
86+
async def task(
87+
customer_id: Annotated[int, ConcurrencyLimit(1)],
88+
region: Annotated[str, ConcurrencyLimit(2)],
89+
): ... # pragma: no cover
90+
91+
await docket.add(task)(customer_id=1, region="us")
92+
93+
94+
async def test_single_conflict_annotation_and_default(docket: Docket):
95+
"""single=True conflict detected across annotation and default-param styles."""
96+
with pytest.raises(ValueError, match="Only one ConcurrencyLimit"):
97+
98+
async def task(
99+
customer_id: Annotated[int, ConcurrencyLimit(1)],
100+
concurrency: ConcurrencyLimit = ConcurrencyLimit(max_concurrent=2),
101+
): ... # pragma: no cover
102+
103+
await docket.add(task)(customer_id=1)
104+
105+
106+
async def test_annotated_concurrency_keys_cleaned_up(docket: Docket, worker: Worker):
107+
"""Concurrency keys from annotated deps are properly cleaned up."""
108+
109+
async def task(customer_id: Annotated[int, ConcurrencyLimit(1)]):
110+
pass
111+
112+
await docket.add(task)(customer_id=42)
113+
await worker.run_until_finished()
114+
115+
async with docket.redis() as redis:
116+
key = f"{docket.name}:concurrency:customer_id:42"
117+
assert await redis.exists(key) == 0

0 commit comments

Comments
 (0)