Skip to content

Commit e331c29

Browse files
chrisguidryclaude
and committed
Fix RateLimit sorted set member uniqueness for Perpetual tasks
Using `execution.key` as the sorted set member meant Perpetual tasks (which reuse the same key via `replace()`) only ever had one entry — ZADD overwrites the score instead of adding a new member, so ZCARD stays at 1 and the rate limit never fires. The member is now `{execution.key}:{now_ms}`, so each execution attempt gets its own entry. Also cleans up phantom slots in `__aexit__` when a later dependency (like ConcurrencyLimit) blocks — the task never ran but would otherwise consume a rate-limit slot. Caught during PR #356 review. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4e7e7b8 commit e331c29

File tree

2 files changed

+84
-5
lines changed

2 files changed

+84
-5
lines changed

src/docket/dependencies/_ratelimit.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
33
Caps how many times a task (or a per-parameter scope) can execute within a
44
sliding window. Uses a Redis sorted set as a sliding window log: members are
5-
execution keys, scores are millisecond timestamps.
5+
``{execution_key}:{now_ms}`` strings (unique per attempt), scores are
6+
millisecond timestamps.
67
"""
78

89
from __future__ import annotations
@@ -17,7 +18,7 @@
1718
# Lua script for atomic sliding-window rate limit check.
1819
#
1920
# KEYS[1] = sorted set key (one per scope)
20-
# ARGV[1] = execution key (member)
21+
# ARGV[1] = member (execution key + timestamp, unique per attempt)
2122
# ARGV[2] = current time in milliseconds
2223
# ARGV[3] = window size in milliseconds
2324
# ARGV[4] = max allowed count (limit)
@@ -99,6 +100,8 @@ def __init__(
99100
self.scope = scope
100101
self._argument_name: str | None = None
101102
self._argument_value: Any = None
103+
self._ratelimit_key: str | None = None
104+
self._member: str | None = None
102105

103106
def bind_to_parameter(self, name: str, value: Any) -> RateLimit:
104107
bound = RateLimit(self.limit, per=self.per, drop=self.drop, scope=self.scope)
@@ -121,18 +124,21 @@ async def __aenter__(self) -> RateLimit:
121124
window_ms = int(self.per.total_seconds() * 1000)
122125
now_ms = int(time.time() * 1000)
123126
ttl_ms = window_ms * 2
127+
member = f"{execution.key}:{now_ms}"
124128

125129
async with docket.redis() as redis:
126130
script = redis.register_script(_RATELIMIT_LUA)
127131
result: list[int] = await script(
128132
keys=[ratelimit_key],
129-
args=[execution.key, now_ms, window_ms, self.limit, ttl_ms],
133+
args=[member, now_ms, window_ms, self.limit, ttl_ms],
130134
)
131135

132136
action = result[0]
133137
retry_after_ms = result[1]
134138

135139
if action == _ACTION_PROCEED:
140+
self._ratelimit_key = ratelimit_key
141+
self._member = member
136142
return self
137143

138144
reason = f"rate limit ({self.limit}/{self.per}) on {ratelimit_key}"
@@ -152,4 +158,9 @@ async def __aexit__(
152158
exc_value: BaseException | None,
153159
traceback: TracebackType | None,
154160
) -> None:
155-
pass
161+
if exc_type is not None and self._member is not None:
162+
if issubclass(exc_type, AdmissionBlocked):
163+
assert self._ratelimit_key is not None
164+
docket = current_docket.get()
165+
async with docket.redis() as redis:
166+
await redis.zrem(self._ratelimit_key, self._member)

tests/test_ratelimit.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import timedelta
77
from typing import Annotated
88

9-
from docket import ConcurrencyLimit, Docket, Worker
9+
from docket import ConcurrencyLimit, Docket, Perpetual, Worker
1010
from docket.dependencies import RateLimit
1111

1212

@@ -169,3 +169,71 @@ async def rated_task(
169169
)
170170
]
171171
assert ratelimit_keys == []
172+
173+
174+
async def test_rate_limit_slot_kept_on_task_failure(docket: Docket, worker: Worker):
175+
"""A failed task still counts against the rate limit."""
176+
results: list[str] = []
177+
178+
async def failing_task(
179+
rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True),
180+
):
181+
results.append("attempted")
182+
if len(results) == 1:
183+
raise RuntimeError("boom")
184+
185+
await docket.add(failing_task)()
186+
await docket.add(failing_task)()
187+
await docket.add(failing_task)()
188+
189+
await worker.run_until_finished()
190+
191+
assert len(results) == 2
192+
193+
194+
async def test_perpetual_task_counts_each_execution(docket: Docket, worker: Worker):
195+
"""Perpetual re-executions each count against the rate limit."""
196+
results: list[str] = []
197+
198+
async def perpetual_rated(
199+
perpetual: Perpetual = Perpetual(every=timedelta(milliseconds=10)),
200+
rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True),
201+
):
202+
results.append("executed")
203+
204+
execution = await docket.add(perpetual_rated)()
205+
await worker.run_at_most({execution.key: 4})
206+
207+
assert len(results) == 2
208+
209+
210+
async def test_rate_limit_slot_freed_when_another_dep_blocks(
211+
docket: Docket, worker: Worker
212+
):
213+
"""RateLimit slot is freed if a later dependency blocks the task.
214+
215+
RateLimit (default-param, resolved first) proceeds and records a slot,
216+
then ConcurrencyLimit (annotation, resolved second) blocks. Without
217+
__aexit__ cleanup the phantom slot stays in the sorted set.
218+
"""
219+
results: list[str] = []
220+
221+
async def blocker(
222+
customer_id: Annotated[int, ConcurrencyLimit(1)],
223+
rate: RateLimit = RateLimit(2, per=timedelta(seconds=5), drop=True),
224+
):
225+
results.append(f"executed_{customer_id}")
226+
if len(results) == 1:
227+
await asyncio.sleep(0.3)
228+
229+
# Task 1 grabs both the rate-limit slot and the concurrency slot.
230+
# Task 2 passes rate-limit (2/2) but is concurrency-blocked, then
231+
# rescheduled. Without cleanup, the phantom slot means task 2's
232+
# retry hits 2/2 and gets dropped.
233+
await docket.add(blocker)(customer_id=1)
234+
await docket.add(blocker)(customer_id=1)
235+
236+
worker.concurrency = 2
237+
await worker.run_until_finished()
238+
239+
assert len(results) == 2

0 commit comments

Comments (0)