Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions loq.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ max_lines = 750
# Source files that still need exceptions above 750
[[rules]]
path = "src/docket/worker.py"
max_lines = 1125
max_lines = 1130

[[rules]]
path = "src/docket/cli.py"
max_lines = 945

[[rules]]
path = "src/docket/execution.py"
max_lines = 903
max_lines = 950

[[rules]]
path = "src/docket/docket.py"
Expand Down
8 changes: 8 additions & 0 deletions src/docket/dependencies/_perpetual.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool:
await docket._cancel(redis, execution.key)
return False

if await execution.is_superseded():
logger.info(
"↬ [%s] %s (superseded)",
format_duration(outcome.duration.total_seconds()),
execution.call_repr(),
)
return True

docket = self.docket.get()
worker = self.worker.get()

Expand Down
37 changes: 37 additions & 0 deletions src/docket/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
trace_context: opentelemetry.context.Context | None = None,
redelivered: bool = False,
function_name: str | None = None,
generation: int = 0,
) -> None:
# Task definition (immutable)
self._docket = docket
Expand All @@ -125,6 +126,7 @@ def __init__(
self.attempt = attempt
self._trace_context = trace_context
self._redelivered = redelivered
self._generation = generation

# Lifecycle state (mutable)
self.state: ExecutionState = ExecutionState.SCHEDULED
Expand Down Expand Up @@ -182,6 +184,11 @@ def redelivered(self) -> bool:
"""Whether this message was redelivered."""
return self._redelivered

@property
def generation(self) -> int:
    """The schedule generation this execution was created under.

    Each schedule() of the same key bumps the stored counter, so a value
    newer than this one means the execution has been superseded.
    """
    return self._generation

@contextmanager
def _maybe_suppress_instrumentation(self) -> Generator[None, None, None]:
"""Suppress OTel auto-instrumentation for internal Redis operations."""
Expand All @@ -199,6 +206,7 @@ def as_message(self) -> Message:
b"args": cloudpickle.dumps(self.args),
b"kwargs": cloudpickle.dumps(self.kwargs),
b"attempt": str(self.attempt).encode(),
b"generation": str(self.generation).encode(),
}

@classmethod
Expand Down Expand Up @@ -228,6 +236,7 @@ async def from_message(
trace_context=propagate.extract(message, getter=message_getter),
redelivered=redelivered,
function_name=function_name,
generation=int(message.get(b"generation", b"0")),
)
await instance.sync()
return instance
Expand Down Expand Up @@ -339,6 +348,7 @@ async def schedule(
local function_name = nil
local args_data = nil
local kwargs_data = nil
local generation_index = nil

for i = 7, #ARGV, 2 do
local field_name = ARGV[i]
Expand All @@ -353,6 +363,8 @@ async def schedule(
args_data = field_value
elseif field_name == 'kwargs' then
kwargs_data = field_value
elseif field_name == 'generation' then
generation_index = #message
end
end

Expand All @@ -364,6 +376,12 @@ async def schedule(
redis.call('XACK', stream_key, worker_group_name, reschedule_message_id)
redis.call('XDEL', stream_key, reschedule_message_id)

-- Increment generation counter
local new_gen = redis.call('HINCRBY', runs_key, 'generation', 1)
if generation_index then
message[generation_index] = tostring(new_gen)
end

-- Park task data for future execution
redis.call('HSET', parked_key, unpack(message))

Expand Down Expand Up @@ -421,6 +439,12 @@ async def schedule(
end
end

-- Increment generation counter
local new_gen = redis.call('HINCRBY', runs_key, 'generation', 1)
if generation_index then
message[generation_index] = tostring(new_gen)
end

if is_immediate then
-- Add to stream for immediate execution
local message_id = redis.call('XADD', stream_key, '*', unpack(message))
Expand Down Expand Up @@ -802,6 +826,19 @@ async def sync(self) -> None:
# Sync progress data
await self.progress.sync()

async def is_superseded(self) -> bool:
    """Check whether a newer schedule has superseded this execution.

    A schedule() call for the same key atomically bumps the ``generation``
    field in the runs hash; when that stored counter is strictly greater
    than the generation this execution was created with, a newer schedule
    has taken over and this execution should stand down.
    """
    # Suppress OTel auto-instrumentation: this is an internal bookkeeping
    # read, not part of the user-visible task trace.
    with self._maybe_suppress_instrumentation():
        async with self.docket.redis() as redis:
            stored = await redis.hget(self._redis_key, "generation")
            latest_generation = int(stored) if stored is not None else 0
            return latest_generation > self._generation

async def _publish_state(self, data: dict) -> None:
"""Publish state change to Redis pub/sub channel.

Expand Down
7 changes: 6 additions & 1 deletion src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ async def _scheduler_loop(self, redis: Redis) -> None:
'function', task['function'],
'args', task['args'],
'kwargs', task['kwargs'],
'attempt', task['attempt']
'attempt', task['attempt'],
'generation', task['generation'] or '0'
)
redis.call('DEL', hash_key)

Expand Down Expand Up @@ -798,6 +799,10 @@ async def _execute(self, execution: Execution) -> None:
TASKS_STRICKEN.add(1, counter_labels | {"docket.where": "worker"})
return

if await execution.is_superseded():
logger.info("↬ %s (superseded)", call, extra=log_context)
return

if execution.key in self._execution_counts:
self._execution_counts[execution.key] += 1

Expand Down
157 changes: 157 additions & 0 deletions tests/test_perpetual_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Tests for the Perpetual rescheduling race condition documented in RACE.md.

When a Perpetual task is running and an external caller uses docket.replace() to
force immediate re-execution of the same key, two executions run concurrently.
Both call Perpetual.on_complete() → docket.replace() on completion, and the last
one to finish wins — potentially with stale timing data.

The fix uses a generation counter: each schedule() atomically increments the
generation in the runs hash. Perpetual.on_complete() checks whether the execution
has been superseded before rescheduling.
"""

import asyncio
from datetime import datetime, timedelta, timezone

from docket import CurrentExecution, Docket, Perpetual, Worker
from docket.execution import Execution

TASK_KEY = "perpetual-race-test"

STALE_INTERVAL = timedelta(seconds=2)
CORRECT_INTERVAL = timedelta(milliseconds=500)


async def test_stale_perpetual_on_complete_overwrites_correct_successor(
    docket: Docket, worker: Worker
):
    """When a running Perpetual task is externally replaced and finishes after
    the replacement, its on_complete overwrites the correctly-timed successor."""

    # Mutable holder so the test body can swap the interval while a task is
    # mid-flight; each task reads the value current at its own start.
    config: dict[str, timedelta] = {"interval": STALE_INTERVAL}

    # Rendezvous events let the test dictate completion order (B before A),
    # which is the precise interleaving that triggers the race.
    task_a_started = asyncio.Event()
    let_a_finish = asyncio.Event()
    task_b_started = asyncio.Event()
    let_b_finish = asyncio.Event()

    executions: list[Execution] = []

    async def racing_task(
        perpetual: Perpetual = Perpetual(),
        execution: Execution = CurrentExecution(),
    ):
        # Snapshot the interval at start: A captures the stale 2s value,
        # B and C capture the corrected 500ms value.
        my_interval = config["interval"]
        call_number = len(executions) + 1
        executions.append(execution)

        if call_number == 1:
            # Task A: signal start, block until released
            task_a_started.set()
            await asyncio.wait_for(let_a_finish.wait(), timeout=10)
        elif call_number == 2:
            # Task B: signal start, block until released
            task_b_started.set()
            await asyncio.wait_for(let_b_finish.wait(), timeout=10)
        # Task C (call 3): the successor — just runs

        perpetual.after(my_interval)

    # Schedule the initial task (A)
    await docket.add(racing_task, key=TASK_KEY)()

    # Run the worker in the background (allow 3 executions of this key)
    worker_task = asyncio.create_task(worker.run_at_most({TASK_KEY: 3}))

    # Wait for task A to start executing
    await asyncio.wait_for(task_a_started.wait(), timeout=10)

    # Simulate user changing config to a shorter interval
    config["interval"] = CORRECT_INTERVAL

    # External replace: force immediate re-execution (creates task B)
    replace_time = datetime.now(timezone.utc)
    await docket.replace(racing_task, replace_time, TASK_KEY)()

    # Wait for task B to start
    await asyncio.wait_for(task_b_started.wait(), timeout=10)

    # Let B finish first — B's on_complete schedules successor at B_time + 500ms
    let_b_finish.set()
    # Give B's on_complete time to complete
    await asyncio.sleep(0.05)

    # Now let A finish — A's on_complete overwrites B's successor with
    # A_start + 2s (stale), pushing the successor much further out
    let_a_finish.set()

    # Wait for all 3 executions to complete
    await asyncio.wait_for(worker_task, timeout=15)

    assert len(executions) == 3

    # The third execution's `when` tells us which on_complete won the race.
    #
    # If B's on_complete won (correct): when ≈ B_completion + 500ms (< 1s from replace)
    # If A's on_complete won (stale): when ≈ A_completion + 2s (> 1s from replace)
    successor = executions[2]
    gap = successor.when - replace_time

    assert gap < timedelta(seconds=1), (
        f"Stale execution won the race: successor scheduled "
        f"{gap.total_seconds():.2f}s after the correct replacement, "
        f"expected < 1s (correct interval is {CORRECT_INTERVAL})"
    )


# Module-level counter shared by the generation tests below; each test
# resets it before use.
call_count: int = 0


async def counting_task():
    """Bump the shared counter so tests can tell whether the task ran."""
    global call_count
    call_count = call_count + 1


async def test_is_superseded_after_replace(docket: Docket):
    """An execution becomes superseded when the same key is rescheduled."""
    await docket.add(counting_task, key="gen-test")()

    # Rebuild the execution from the raw stream entry so we can observe
    # the generation it was scheduled under.
    async with docket.redis() as redis:
        entries = await redis.xrange(docket.stream_key, count=1)
        _, fields = entries[0]
        first_run = await Execution.from_message(docket, fields)

    assert first_run.generation == 1
    assert not await first_run.is_superseded()

    # replace() re-runs the schedule script, bumping the stored generation
    # in the runs hash past the one captured above.
    await docket.replace(counting_task, datetime.now(timezone.utc), "gen-test")()

    assert await first_run.is_superseded()


async def test_superseded_message_skipped_before_execution(
    docket: Docket, worker: Worker
):
    """A stale message in the stream is skipped without running the function.

    This covers the case where a message was already pending (e.g. after a
    worker crash and redelivery) when the task was replaced. The runs hash
    has a newer generation so the worker bails before claim().
    """
    global call_count
    call_count = 0

    await docket.add(counting_task, key="head-check")()

    # Advance the stored generation directly, leaving the original stream
    # message untouched — exactly the state a replace() leaves behind when
    # the old message is still pending in the consumer group.
    hash_key = docket.key("runs:head-check")
    async with docket.redis() as redis:
        await redis.hincrby(hash_key, "generation", 1)  # type: ignore[misc]

    await worker.run_until_finished()

    assert call_count == 0, "superseded task should not have executed"
Loading