Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/docket/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,11 @@ async def schedule(
-- Check if task already exists (check new location first, then legacy)
local known_exists = redis.call('HEXISTS', runs_key, 'known') == 1
if not known_exists then
-- Check if task is currently running (known field deleted at claim time)
local state = redis.call('HGET', runs_key, 'state')
if state == 'running' then
return 'EXISTS'
end
-- TODO: Remove in next breaking release (v0.14.0) - check legacy location
known_exists = redis.call('EXISTS', known_key) == 1
end
Expand Down
12 changes: 8 additions & 4 deletions tests/test_fundamentals.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,8 @@ async def test_self_perpetuating_immediate_tasks(
async def the_task(start: int, iteration: int, key: str = TaskKey()):
calls[key].append(start + iteration)
if iteration < 3:
await docket.add(the_task, key=key)(start, iteration + 1)
# Use replace() for self-perpetuating to allow rescheduling while running
await docket.replace(the_task, now(), key=key)(start, iteration + 1)

await docket.add(the_task, key="first")(10, 1)
await docket.add(the_task, key="second")(20, 1)
Expand All @@ -929,7 +930,8 @@ async def the_task(start: int, iteration: int, key: str = TaskKey()):
calls[key].append(start + iteration)
if iteration < 3:
soon = now() + timedelta(milliseconds=100)
await docket.add(the_task, key=key, when=soon)(start, iteration + 1)
# Use replace() for self-perpetuating to allow rescheduling while running
await docket.replace(the_task, key=key, when=soon)(start, iteration + 1)

await docket.add(the_task, key="first")(10, 1)
await docket.add(the_task, key="second")(20, 1)
Expand All @@ -954,12 +956,14 @@ async def test_infinitely_self_perpetuating_tasks(
async def the_task(start: int, iteration: int, key: str = TaskKey()):
calls[key].append(start + iteration)
soon = now() + timedelta(milliseconds=100)
await docket.add(the_task, key=key, when=soon)(start, iteration + 1)
# Use replace() for self-perpetuating to allow rescheduling while running
await docket.replace(the_task, key=key, when=soon)(start, iteration + 1)

async def unaffected_task(start: int, iteration: int, key: str = TaskKey()):
calls[key].append(start + iteration)
if iteration < 3:
await docket.add(unaffected_task, key=key)(start, iteration + 1)
# Use replace() for self-perpetuating to allow rescheduling while running
await docket.replace(unaffected_task, now(), key=key)(start, iteration + 1)

await docket.add(the_task, key="first")(10, 1)
await docket.add(the_task, key="second")(20, 1)
Expand Down
66 changes: 66 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,72 @@ async def test_rapid_replace_operations(
assert the_task.await_count == 1


@pytest.mark.parametrize(
    "execution_ttl", [None, timedelta(0)], ids=["default_ttl", "zero_ttl"]
)
async def test_duplicate_execution_race_condition_non_perpetual_task(
    redis_url: str, execution_ttl: timedelta | None
):
    """Reproduce race condition where non-perpetual tasks execute multiple times.

    Bug: known_task_key is deleted BEFORE task function runs (worker.py:588),
    allowing duplicate docket.add() calls with the same key to succeed
    while the original task is still executing.

    Timeline:
    1. Task A scheduled with key="task:123" -> known_key set
    2. Worker picks up Task A, _perpetuate_if_requested() returns False
    3. Worker calls _delete_known_task() -> known_key DELETED
    4. Worker starts executing the actual task function (slow task)
    5. Meanwhile, docket.add(key="task:123") checks EXISTS known_key -> 0
    6. Duplicate task scheduled and picked up by concurrent worker
    7. Both tasks execute in parallel

    Tests both default TTL and execution_ttl=0 to ensure fix doesn't depend
    on volatile results keys.
    """
    # Shared counters/signals between the test body and the task closure.
    execution_count = 0
    task_started = asyncio.Event()

    async def slow_task(task_id: str):
        # Count every invocation and signal that execution has begun, then
        # stay "running" long enough for the duplicate add() to race us.
        nonlocal execution_count
        execution_count += 1
        task_started.set()
        await asyncio.sleep(0.3)

    # Build Docket kwargs dynamically so the parametrized None case exercises
    # the library's default execution_ttl rather than passing an explicit one.
    docket_kwargs: dict[str, object] = {
        "name": f"test-race-{uuid4()}",
        "url": redis_url,
    }
    if execution_ttl is not None:
        docket_kwargs["execution_ttl"] = execution_ttl

    async with Docket(**docket_kwargs) as docket:  # type: ignore[arg-type]
        docket.register(slow_task)
        # Unique key per test run so repeated runs never collide in Redis.
        task_key = f"race-test:{uuid4()}"

        # concurrency=2 so a duplicate, if scheduled, can run in parallel
        # with the original instead of queuing behind it.
        async with Worker(docket, concurrency=2) as worker:
            worker_task = asyncio.create_task(worker.run_until_finished())

            # Schedule first task
            await docket.add(slow_task, key=task_key)("first")

            # Wait for task to start (known_key already deleted at this point)
            await asyncio.wait_for(task_started.wait(), timeout=2.0)
            await asyncio.sleep(0.05)  # Small buffer to ensure deletion happened

            # Attempt duplicate - should be rejected but isn't due to bug
            # NOTE(review): add() silently no-ops on rejection here — confirm
            # that callers are not expected to observe a return/exception.
            await docket.add(slow_task, key=task_key)("second")

            # run_until_finished() drains the queue; timeout guards against
            # the worker hanging and stalling the suite.
            await asyncio.wait_for(worker_task, timeout=5.0)

    # BUG: execution_count == 2 (both tasks ran)
    # EXPECTED: execution_count == 1 (duplicate rejected)
    assert execution_count == 1, (
        f"Task executed {execution_count} times, expected 1"
    )


async def test_wrongtype_error_with_legacy_known_task_key(
docket: Docket,
worker: Worker,
Expand Down