Skip to content

Commit 9dc405c

Browse files
chrisguidryclaude
and authored
Fix race condition in task replacement where duplicate executions occur (#151)
Fix race condition in task replacement where duplicate executions occur Previously, `docket.replace()` could not cancel tasks that had already been moved from the queue to the stream by the scheduler, causing duplicate execution when the old task ran alongside the replacement. This happened because: 1. Scheduler moves due tasks from queue → stream, assigning Redis message IDs 2. `replace()` only checked the queue, missing tasks already in the stream 3. Both old and new tasks would execute This implements atomic task scheduling and cancellation using Lua scripts that track stream message IDs. The solution: - **Atomic scheduling**: New Lua script handles task existence checks, replacement cancellation, and scheduling in a single operation - **Stream message ID tracking**: When tasks move to stream, their message IDs are stored in the known task metadata - **Atomic cancellation**: Can delete tasks from stream using stored message IDs or from queue for scheduled tasks - **Race-free replacement**: Uses Redis locks per task key and atomic operations Comprehensive test coverage includes race condition scenarios and idempotent operations. Closes #149 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude <noreply@anthropic.com>
1 parent 7a36609 commit 9dc405c

File tree

4 files changed

+257
-40
lines changed

4 files changed

+257
-40
lines changed

src/docket/docket.py

Lines changed: 151 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Mapping,
1717
NoReturn,
1818
ParamSpec,
19+
Protocol,
1920
Self,
2021
Sequence,
2122
TypedDict,
@@ -27,7 +28,6 @@
2728
import redis.exceptions
2829
from opentelemetry import propagate, trace
2930
from redis.asyncio import ConnectionPool, Redis
30-
from redis.asyncio.client import Pipeline
3131
from uuid_extensions import uuid7
3232

3333
from .execution import (
@@ -55,6 +55,18 @@
5555
tracer: trace.Tracer = trace.get_tracer(__name__)
5656

5757

58+
class _schedule_task(Protocol):
59+
async def __call__(
60+
self, keys: list[str], args: list[str | float | bytes]
61+
) -> str: ... # pragma: no cover
62+
63+
64+
class _cancel_task(Protocol):
65+
async def __call__(
66+
self, keys: list[str], args: list[str]
67+
) -> str: ... # pragma: no cover
68+
69+
5870
P = ParamSpec("P")
5971
R = TypeVar("R")
6072

@@ -131,6 +143,8 @@ async def my_task(greeting: str, recipient: str) -> None:
131143

132144
_monitor_strikes_task: asyncio.Task[None]
133145
_connection_pool: ConnectionPool
146+
_schedule_task_script: _schedule_task | None
147+
_cancel_task_script: _cancel_task | None
134148

135149
def __init__(
136150
self,
@@ -156,6 +170,8 @@ def __init__(
156170
self.url = url
157171
self.heartbeat_interval = heartbeat_interval
158172
self.missed_heartbeats = missed_heartbeats
173+
self._schedule_task_script = None
174+
self._cancel_task_script = None
159175

160176
@property
161177
def worker_group_name(self) -> str:
@@ -300,9 +316,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
300316
execution = Execution(function, args, kwargs, when, key, attempt=1)
301317

302318
async with self.redis() as redis:
303-
async with redis.pipeline() as pipeline:
304-
await self._schedule(redis, pipeline, execution, replace=False)
305-
await pipeline.execute()
319+
await self._schedule(redis, execution, replace=False)
306320

307321
TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
308322
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
@@ -361,9 +375,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
361375
execution = Execution(function, args, kwargs, when, key, attempt=1)
362376

363377
async with self.redis() as redis:
364-
async with redis.pipeline() as pipeline:
365-
await self._schedule(redis, pipeline, execution, replace=True)
366-
await pipeline.execute()
378+
await self._schedule(redis, execution, replace=True)
367379

368380
TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
369381
TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
@@ -383,9 +395,7 @@ async def schedule(self, execution: Execution) -> None:
383395
},
384396
):
385397
async with self.redis() as redis:
386-
async with redis.pipeline() as pipeline:
387-
await self._schedule(redis, pipeline, execution, replace=False)
388-
await pipeline.execute()
398+
await self._schedule(redis, execution, replace=False)
389399

390400
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
391401

@@ -400,9 +410,7 @@ async def cancel(self, key: str) -> None:
400410
attributes={**self.labels(), "docket.key": key},
401411
):
402412
async with self.redis() as redis:
403-
async with redis.pipeline() as pipeline:
404-
await self._cancel(pipeline, key)
405-
await pipeline.execute()
413+
await self._cancel(redis, key)
406414

407415
TASKS_CANCELLED.add(1, self.labels())
408416

@@ -423,10 +431,17 @@ def parked_task_key(self, key: str) -> str:
423431
async def _schedule(
424432
self,
425433
redis: Redis,
426-
pipeline: Pipeline,
427434
execution: Execution,
428435
replace: bool = False,
429436
) -> None:
437+
"""Schedule a task atomically.
438+
439+
Handles:
440+
- Checking for task existence
441+
- Cancelling existing tasks when replacing
442+
- Adding tasks to stream (immediate) or queue (future)
443+
- Tracking stream message IDs for later cancellation
444+
"""
430445
if self.strike_list.is_stricken(execution):
431446
logger.warning(
432447
"%r is stricken, skipping schedule of %r",
@@ -449,32 +464,133 @@ async def _schedule(
449464
key = execution.key
450465
when = execution.when
451466
known_task_key = self.known_task_key(key)
467+
is_immediate = when <= datetime.now(timezone.utc)
452468

469+
# Lock per task key to prevent race conditions between concurrent operations
453470
async with redis.lock(f"{known_task_key}:lock", timeout=10):
454-
if replace:
455-
await self._cancel(pipeline, key)
456-
else:
457-
# if the task is already in the queue or stream, retain it
458-
if await redis.exists(known_task_key):
459-
logger.debug(
460-
"Task %r is already in the queue or stream, not scheduling",
461-
key,
462-
extra=self.labels(),
463-
)
464-
return
471+
if self._schedule_task_script is None:
472+
self._schedule_task_script = cast(
473+
_schedule_task,
474+
redis.register_script(
475+
# KEYS: stream_key, known_key, parked_key, queue_key
476+
# ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
477+
"""
478+
local stream_key = KEYS[1]
479+
local known_key = KEYS[2]
480+
local parked_key = KEYS[3]
481+
local queue_key = KEYS[4]
482+
483+
local task_key = ARGV[1]
484+
local when_timestamp = ARGV[2]
485+
local is_immediate = ARGV[3] == '1'
486+
local replace = ARGV[4] == '1'
487+
488+
-- Extract message fields from ARGV[5] onwards
489+
local message = {}
490+
for i = 5, #ARGV, 2 do
491+
message[#message + 1] = ARGV[i] -- field name
492+
message[#message + 1] = ARGV[i + 1] -- field value
493+
end
494+
495+
-- Handle replacement: cancel existing task if needed
496+
if replace then
497+
local existing_message_id = redis.call('HGET', known_key, 'stream_message_id')
498+
if existing_message_id then
499+
redis.call('XDEL', stream_key, existing_message_id)
500+
end
501+
redis.call('DEL', known_key, parked_key)
502+
redis.call('ZREM', queue_key, task_key)
503+
else
504+
-- Check if task already exists
505+
if redis.call('EXISTS', known_key) == 1 then
506+
return 'EXISTS'
507+
end
508+
end
509+
510+
if is_immediate then
511+
-- Add to stream and store message ID for later cancellation
512+
local message_id = redis.call('XADD', stream_key, '*', unpack(message))
513+
redis.call('HSET', known_key, 'when', when_timestamp, 'stream_message_id', message_id)
514+
return message_id
515+
else
516+
-- Add to queue with task data in parked hash
517+
redis.call('HSET', known_key, 'when', when_timestamp)
518+
redis.call('HSET', parked_key, unpack(message))
519+
redis.call('ZADD', queue_key, when_timestamp, task_key)
520+
return 'QUEUED'
521+
end
522+
"""
523+
),
524+
)
525+
schedule_task = self._schedule_task_script
465526

466-
pipeline.set(known_task_key, when.timestamp())
527+
await schedule_task(
528+
keys=[
529+
self.stream_key,
530+
known_task_key,
531+
self.parked_task_key(key),
532+
self.queue_key,
533+
],
534+
args=[
535+
key,
536+
str(when.timestamp()),
537+
"1" if is_immediate else "0",
538+
"1" if replace else "0",
539+
*[
540+
item
541+
for field, value in message.items()
542+
for item in (field, value)
543+
],
544+
],
545+
)
467546

468-
if when <= datetime.now(timezone.utc):
469-
pipeline.xadd(self.stream_key, message) # type: ignore[arg-type]
470-
else:
471-
pipeline.hset(self.parked_task_key(key), mapping=message) # type: ignore[arg-type]
472-
pipeline.zadd(self.queue_key, {key: when.timestamp()})
547+
async def _cancel(self, redis: Redis, key: str) -> None:
548+
"""Cancel a task atomically.
473549
474-
async def _cancel(self, pipeline: Pipeline, key: str) -> None:
475-
pipeline.delete(self.known_task_key(key))
476-
pipeline.delete(self.parked_task_key(key))
477-
pipeline.zrem(self.queue_key, key)
550+
Handles cancellation regardless of task location:
551+
- From the stream (using stored message ID)
552+
- From the queue (scheduled tasks)
553+
- Cleans up all associated metadata keys
554+
"""
555+
if self._cancel_task_script is None:
556+
self._cancel_task_script = cast(
557+
_cancel_task,
558+
redis.register_script(
559+
# KEYS: stream_key, known_key, parked_key, queue_key
560+
# ARGV: task_key
561+
"""
562+
local stream_key = KEYS[1]
563+
local known_key = KEYS[2]
564+
local parked_key = KEYS[3]
565+
local queue_key = KEYS[4]
566+
local task_key = ARGV[1]
567+
568+
-- Delete from stream if message ID exists
569+
local message_id = redis.call('HGET', known_key, 'stream_message_id')
570+
if message_id then
571+
redis.call('XDEL', stream_key, message_id)
572+
end
573+
574+
-- Clean up all task-related keys
575+
redis.call('DEL', known_key, parked_key)
576+
redis.call('ZREM', queue_key, task_key)
577+
578+
return 'OK'
579+
"""
580+
),
581+
)
582+
cancel_task = self._cancel_task_script
583+
584+
# Execute the cancellation script
585+
await cancel_task(
586+
keys=[
587+
self.stream_key,
588+
self.known_task_key(key),
589+
self.parked_task_key(key),
590+
self.queue_key,
591+
],
592+
args=[key],
593+
)
478594

479595
@property
480596
def strike_key(self) -> str:

src/docket/worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,17 @@ async def _scheduler_loop(
406406
task[task_data[j]] = task_data[j+1]
407407
end
408408
409-
redis.call('XADD', KEYS[2], '*',
409+
local message_id = redis.call('XADD', KEYS[2], '*',
410410
'key', task['key'],
411411
'when', task['when'],
412412
'function', task['function'],
413413
'args', task['args'],
414414
'kwargs', task['kwargs'],
415415
'attempt', task['attempt']
416416
)
417+
-- Store the message ID in the known task key
418+
local known_key = ARGV[2] .. ":known:" .. key
419+
redis.call('HSET', known_key, 'stream_message_id', message_id)
417420
redis.call('DEL', hash_key)
418421
due_work = due_work + 1
419422
end

tests/test_fundamentals.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,18 +250,39 @@ async def test_cancelling_future_task(
250250
the_task.assert_not_called()
251251

252252

253-
async def test_cancelling_current_task_not_supported(
253+
async def test_cancelling_immediate_task(
254254
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
255255
):
256-
"""docket does not allow cancelling a task that is schedule now"""
256+
"""docket can cancel a task that is scheduled immediately"""
257257

258258
execution = await docket.add(the_task, now())("a", "b", c="c")
259259

260260
await docket.cancel(execution.key)
261261

262262
await worker.run_until_finished()
263263

264-
the_task.assert_awaited_once_with("a", "b", c="c")
264+
the_task.assert_not_called()
265+
266+
267+
async def test_cancellation_is_idempotent(
268+
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
269+
):
270+
"""Test that canceling the same task twice doesn't error."""
271+
key = f"test-task:{uuid4()}"
272+
273+
# Schedule a task
274+
later = now() + timedelta(seconds=1)
275+
await docket.add(the_task, later, key=key)("test")
276+
277+
# Cancel it twice - both should succeed without error
278+
await docket.cancel(key)
279+
await docket.cancel(key) # Should be idempotent
280+
281+
# Run worker to ensure the task was actually cancelled
282+
await worker.run_until_finished()
283+
284+
# Task should not have been executed since it was cancelled
285+
the_task.assert_not_called()
265286

266287

267288
async def test_errors_are_logged(

0 commit comments

Comments
 (0)