Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions loq.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ max_lines = 750
# Source files that still need exceptions above 750
[[rules]]
path = "src/docket/worker.py"
max_lines = 1125
max_lines = 1130

[[rules]]
path = "src/docket/cli.py"
max_lines = 945

[[rules]]
path = "src/docket/execution.py"
max_lines = 903
max_lines = 965

[[rules]]
path = "src/docket/docket.py"
Expand Down
8 changes: 8 additions & 0 deletions src/docket/dependencies/_perpetual.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ async def on_complete(self, execution: Execution, outcome: TaskOutcome) -> bool:
await docket._cancel(redis, execution.key)
return False

if await execution.is_superseded():
logger.info(
"↬ [%s] %s (superseded)",
format_duration(outcome.duration.total_seconds()),
execution.call_repr(),
)
return True

docket = self.docket.get()
worker = self.worker.get()

Expand Down
72 changes: 67 additions & 5 deletions src/docket/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
trace_context: opentelemetry.context.Context | None = None,
redelivered: bool = False,
function_name: str | None = None,
generation: int = 0,
) -> None:
# Task definition (immutable)
self._docket = docket
Expand All @@ -125,6 +126,7 @@ def __init__(
self.attempt = attempt
self._trace_context = trace_context
self._redelivered = redelivered
self._generation = generation

# Lifecycle state (mutable)
self.state: ExecutionState = ExecutionState.SCHEDULED
Expand Down Expand Up @@ -182,6 +184,11 @@ def redelivered(self) -> bool:
"""Whether this message was redelivered."""
return self._redelivered

@property
def generation(self) -> int:
    """Scheduling generation counter for supersession detection.

    The stored counter is incremented in Redis (HINCRBY on the runs hash)
    each time the task is scheduled; an execution whose generation is lower
    than the stored value has been superseded by a newer schedule() call.
    A value of 0 means the message predates generation tracking.
    """
    return self._generation

@contextmanager
def _maybe_suppress_instrumentation(self) -> Generator[None, None, None]:
"""Suppress OTel auto-instrumentation for internal Redis operations."""
Expand All @@ -199,6 +206,7 @@ def as_message(self) -> Message:
b"args": cloudpickle.dumps(self.args),
b"kwargs": cloudpickle.dumps(self.kwargs),
b"attempt": str(self.attempt).encode(),
b"generation": str(self.generation).encode(),
}

@classmethod
Expand Down Expand Up @@ -228,6 +236,7 @@ async def from_message(
trace_context=propagate.extract(message, getter=message_getter),
redelivered=redelivered,
function_name=function_name,
generation=int(message.get(b"generation", b"0")),
)
await instance.sync()
return instance
Expand Down Expand Up @@ -339,6 +348,7 @@ async def schedule(
local function_name = nil
local args_data = nil
local kwargs_data = nil
local generation_index = nil

for i = 7, #ARGV, 2 do
local field_name = ARGV[i]
Expand All @@ -353,6 +363,8 @@ async def schedule(
args_data = field_value
elseif field_name == 'kwargs' then
kwargs_data = field_value
elseif field_name == 'generation' then
generation_index = #message
end
end

Expand All @@ -364,6 +376,12 @@ async def schedule(
redis.call('XACK', stream_key, worker_group_name, reschedule_message_id)
redis.call('XDEL', stream_key, reschedule_message_id)

-- Increment generation counter
local new_gen = redis.call('HINCRBY', runs_key, 'generation', 1)
if generation_index then
message[generation_index] = tostring(new_gen)
end

-- Park task data for future execution
redis.call('HSET', parked_key, unpack(message))

Expand Down Expand Up @@ -421,6 +439,12 @@ async def schedule(
end
end

-- Increment generation counter
local new_gen = redis.call('HINCRBY', runs_key, 'generation', 1)
if generation_index then
message[generation_index] = tostring(new_gen)
end

if is_immediate then
-- Add to stream for immediate execution
local message_id = redis.call('XADD', stream_key, '*', unpack(message))
Expand Down Expand Up @@ -500,18 +524,22 @@ async def schedule(
{"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()}
)

async def claim(self, worker: str) -> None:
"""Atomically claim task and transition to RUNNING state.
async def claim(self, worker: str) -> bool:
"""Atomically check supersession and claim task in a single round-trip.

This consolidates worker operations when claiming a task into a single
atomic Lua script that:
- Checks if the task has been superseded by a newer generation
- Sets state to RUNNING with worker name and timestamp
- Initializes progress tracking (current=0, total=100)
- Deletes known/stream_id fields to allow task rescheduling
- Cleans up legacy keys for backwards compatibility

Args:
worker: Name of the worker claiming the task

Returns:
True if the task was claimed, False if it was superseded.
"""
started_at = datetime.now(timezone.utc)
started_at_iso = started_at.isoformat()
Expand All @@ -520,7 +548,7 @@ async def claim(self, worker: str) -> None:
async with self.docket.redis() as redis:
claim_script = redis.register_script(
# KEYS: runs_key, progress_key, known_key, stream_id_key
# ARGV: worker, started_at_iso
# ARGV: worker, started_at_iso, generation
"""
local runs_key = KEYS[1]
local progress_key = KEYS[2]
Expand All @@ -530,6 +558,15 @@ async def claim(self, worker: str) -> None:

local worker = ARGV[1]
local started_at = ARGV[2]
local generation = tonumber(ARGV[3])

-- Check supersession: generation > 0 means tracking is active
if generation > 0 then
local current = redis.call('HGET', runs_key, 'generation')
if current and tonumber(current) > generation then
return 'SUPERSEDED'
end
end

-- Update execution state to running
redis.call('HSET', runs_key,
Expand All @@ -554,16 +591,19 @@ async def claim(self, worker: str) -> None:
"""
)

await claim_script(
result = await claim_script(
keys=[
self._redis_key, # runs_key
self.progress._redis_key, # progress_key
self.docket.known_task_key(self.key), # legacy known_key
self.docket.stream_id_key(self.key), # legacy stream_id_key
],
args=[worker, started_at_iso],
args=[worker, started_at_iso, str(self._generation)],
)

if result == b"SUPERSEDED":
return False

# Update local state
self.state = ExecutionState.RUNNING
self.worker = worker
Expand All @@ -580,6 +620,8 @@ async def claim(self, worker: str) -> None:
}
)

return True

async def _mark_as_terminal(
self,
state: ExecutionState,
Expand Down Expand Up @@ -802,6 +844,26 @@ async def sync(self) -> None:
# Sync progress data
await self.progress.sync()

async def is_superseded(self) -> bool:
    """Report whether a newer schedule() call has superseded this execution.

    Reads the ``generation`` field from the runs hash and compares it with
    the generation this execution was delivered with. A strictly greater
    stored value means another schedule() has since bumped the counter, so
    this delivery is stale.

    Executions carrying generation 0 predate generation tracking (e.g. the
    message was moved from queue to stream by an older worker that does not
    forward the generation field); those can never be judged stale, so they
    are always reported as not superseded.
    """
    # No tracking information on this message — cannot compare, assume live.
    if self._generation == 0:
        return False

    with self._maybe_suppress_instrumentation():
        async with self.docket.redis() as redis:
            stored = await redis.hget(self._redis_key, "generation")

    stored_generation = 0 if stored is None else int(stored)
    return stored_generation > self._generation

async def _publish_state(self, data: dict) -> None:
"""Publish state change to Redis pub/sub channel.

Expand Down
10 changes: 6 additions & 4 deletions src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ async def _scheduler_loop(self, redis: Redis) -> None:
'function', task['function'],
'args', task['args'],
'kwargs', task['kwargs'],
'attempt', task['attempt']
'attempt', task['attempt'],
'generation', task['generation'] or '0'
)
redis.call('DEL', hash_key)

Expand Down Expand Up @@ -817,9 +818,10 @@ async def _execute(self, execution: Execution) -> None:
"%s [%s] %s", arrow, format_duration(punctuality), call, extra=log_context
)

# Atomically claim task and transition to running state
# This also initializes progress and cleans up known/stream_id to allow rescheduling
await execution.claim(self.name)
# Atomically check supersession and claim task in a single round-trip
if not await execution.claim(self.name):
logger.info("↬ %s (superseded)", call, extra=log_context)
return

dependencies: dict[str, Dependency] = {}

Expand Down
Loading