Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 151 additions & 35 deletions src/docket/docket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Mapping,
NoReturn,
ParamSpec,
Protocol,
Self,
Sequence,
TypedDict,
Expand All @@ -27,7 +28,6 @@
import redis.exceptions
from opentelemetry import propagate, trace
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import Pipeline
from uuid_extensions import uuid7

from .execution import (
Expand Down Expand Up @@ -55,6 +55,18 @@
tracer: trace.Tracer = trace.get_tracer(__name__)


class _schedule_task(Protocol):
async def __call__(
self, keys: list[str], args: list[str | float | bytes]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ... is a cool idea

) -> str: ... # pragma: no cover


class _cancel_task(Protocol):
async def __call__(
self, keys: list[str], args: list[str]
) -> str: ... # pragma: no cover


P = ParamSpec("P")
R = TypeVar("R")

Expand Down Expand Up @@ -131,6 +143,8 @@ async def my_task(greeting: str, recipient: str) -> None:

_monitor_strikes_task: asyncio.Task[None]
_connection_pool: ConnectionPool
_schedule_task_script: _schedule_task | None
_cancel_task_script: _cancel_task | None

def __init__(
self,
Expand All @@ -156,6 +170,8 @@ def __init__(
self.url = url
self.heartbeat_interval = heartbeat_interval
self.missed_heartbeats = missed_heartbeats
self._schedule_task_script = None
self._cancel_task_script = None

@property
def worker_group_name(self) -> str:
Expand Down Expand Up @@ -300,9 +316,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
execution = Execution(function, args, kwargs, when, key, attempt=1)

async with self.redis() as redis:
async with redis.pipeline() as pipeline:
await self._schedule(redis, pipeline, execution, replace=False)
await pipeline.execute()
await self._schedule(redis, execution, replace=False)

TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
Expand Down Expand Up @@ -361,9 +375,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
execution = Execution(function, args, kwargs, when, key, attempt=1)

async with self.redis() as redis:
async with redis.pipeline() as pipeline:
await self._schedule(redis, pipeline, execution, replace=True)
await pipeline.execute()
await self._schedule(redis, execution, replace=True)

TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
Expand All @@ -383,9 +395,7 @@ async def schedule(self, execution: Execution) -> None:
},
):
async with self.redis() as redis:
async with redis.pipeline() as pipeline:
await self._schedule(redis, pipeline, execution, replace=False)
await pipeline.execute()
await self._schedule(redis, execution, replace=False)

TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})

Expand All @@ -400,9 +410,7 @@ async def cancel(self, key: str) -> None:
attributes={**self.labels(), "docket.key": key},
):
async with self.redis() as redis:
async with redis.pipeline() as pipeline:
await self._cancel(pipeline, key)
await pipeline.execute()
await self._cancel(redis, key)

TASKS_CANCELLED.add(1, self.labels())

Expand All @@ -423,10 +431,17 @@ def parked_task_key(self, key: str) -> str:
async def _schedule(
self,
redis: Redis,
pipeline: Pipeline,
execution: Execution,
replace: bool = False,
) -> None:
"""Schedule a task atomically.

Handles:
- Checking for task existence
- Cancelling existing tasks when replacing
- Adding tasks to stream (immediate) or queue (future)
- Tracking stream message IDs for later cancellation
"""
if self.strike_list.is_stricken(execution):
logger.warning(
"%r is stricken, skipping schedule of %r",
Expand All @@ -449,32 +464,133 @@ async def _schedule(
key = execution.key
when = execution.when
known_task_key = self.known_task_key(key)
is_immediate = when <= datetime.now(timezone.utc)

# Lock per task key to prevent race conditions between concurrent operations
async with redis.lock(f"{known_task_key}:lock", timeout=10):
if replace:
await self._cancel(pipeline, key)
else:
# if the task is already in the queue or stream, retain it
if await redis.exists(known_task_key):
logger.debug(
"Task %r is already in the queue or stream, not scheduling",
key,
extra=self.labels(),
)
return
if self._schedule_task_script is None:
self._schedule_task_script = cast(
_schedule_task,
redis.register_script(
# KEYS: stream_key, known_key, parked_key, queue_key
# ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
"""
-- Atomically schedule a task, optionally replacing an existing one.
-- KEYS: stream_key, known_key, parked_key, queue_key
-- ARGV: task_key, when_timestamp, is_immediate ('1'/'0'), replace ('1'/'0'),
--       then alternating message field names and values.
local stream_key, known_key = KEYS[1], KEYS[2]
local parked_key, queue_key = KEYS[3], KEYS[4]

local task_key = ARGV[1]
local when_timestamp = ARGV[2]
local is_immediate = (ARGV[3] == '1')
local replace = (ARGV[4] == '1')

-- Rebuild the flat field/value list for XADD/HSET from ARGV[5..].
local message = {}
for i = 5, #ARGV, 2 do
    table.insert(message, ARGV[i])      -- field name
    table.insert(message, ARGV[i + 1])  -- field value
end

if replace then
    -- Drop any prior incarnation of this task wherever it lives:
    -- stream entry (via the recorded message id), bookkeeping hash,
    -- parked payload, and queue rank.
    local old_id = redis.call('HGET', known_key, 'stream_message_id')
    if old_id then
        redis.call('XDEL', stream_key, old_id)
    end
    redis.call('DEL', known_key, parked_key)
    redis.call('ZREM', queue_key, task_key)
elseif redis.call('EXISTS', known_key) == 1 then
    -- Without replace, an already-known task is left untouched.
    return 'EXISTS'
end

if is_immediate then
    -- Due now: publish straight to the stream, remembering the message
    -- id so a later cancellation can XDEL it.
    local message_id = redis.call('XADD', stream_key, '*', unpack(message))
    redis.call('HSET', known_key, 'when', when_timestamp, 'stream_message_id', message_id)
    return message_id
end

-- Due later: park the payload and rank the key by its due time.
redis.call('HSET', known_key, 'when', when_timestamp)
redis.call('HSET', parked_key, unpack(message))
redis.call('ZADD', queue_key, when_timestamp, task_key)
return 'QUEUED'
"""
),
)
schedule_task = self._schedule_task_script

pipeline.set(known_task_key, when.timestamp())
await schedule_task(
keys=[
self.stream_key,
known_task_key,
self.parked_task_key(key),
self.queue_key,
],
args=[
key,
str(when.timestamp()),
"1" if is_immediate else "0",
"1" if replace else "0",
*[
item
for field, value in message.items()
for item in (field, value)
],
],
)

if when <= datetime.now(timezone.utc):
pipeline.xadd(self.stream_key, message) # type: ignore[arg-type]
else:
pipeline.hset(self.parked_task_key(key), mapping=message) # type: ignore[arg-type]
pipeline.zadd(self.queue_key, {key: when.timestamp()})
async def _cancel(self, redis: Redis, key: str) -> None:
"""Cancel a task atomically.

async def _cancel(self, pipeline: Pipeline, key: str) -> None:
pipeline.delete(self.known_task_key(key))
pipeline.delete(self.parked_task_key(key))
pipeline.zrem(self.queue_key, key)
Handles cancellation regardless of task location:
- From the stream (using stored message ID)
- From the queue (scheduled tasks)
- Cleans up all associated metadata keys
"""
if self._cancel_task_script is None:
self._cancel_task_script = cast(
_cancel_task,
redis.register_script(
# KEYS: stream_key, known_key, parked_key, queue_key
# ARGV: task_key
"""
-- Atomically cancel a task regardless of where it currently lives.
-- KEYS: stream_key, known_key, parked_key, queue_key; ARGV: task_key
local stream_key, known_key = KEYS[1], KEYS[2]
local parked_key, queue_key = KEYS[3], KEYS[4]
local task_key = ARGV[1]

-- If the task already made it into the stream, remove that entry via
-- the message id recorded at schedule time (HGET yields false if absent).
local recorded_id = redis.call('HGET', known_key, 'stream_message_id')
if recorded_id then
    redis.call('XDEL', stream_key, recorded_id)
end

-- Clear the bookkeeping hash, the parked payload, and the queue rank.
-- Each command is a no-op when the key/member is gone, so repeated
-- cancellation is harmless.
redis.call('DEL', known_key, parked_key)
redis.call('ZREM', queue_key, task_key)

return 'OK'
"""
),
)
cancel_task = self._cancel_task_script

# Execute the cancellation script
await cancel_task(
keys=[
self.stream_key,
self.known_task_key(key),
self.parked_task_key(key),
self.queue_key,
],
args=[key],
)

@property
def strike_key(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,17 @@ async def _scheduler_loop(
task[task_data[j]] = task_data[j+1]
end

redis.call('XADD', KEYS[2], '*',
local message_id = redis.call('XADD', KEYS[2], '*',
'key', task['key'],
'when', task['when'],
'function', task['function'],
'args', task['args'],
'kwargs', task['kwargs'],
'attempt', task['attempt']
)
-- Store the message ID in the known task key
local known_key = ARGV[2] .. ":known:" .. key
redis.call('HSET', known_key, 'stream_message_id', message_id)
redis.call('DEL', hash_key)
due_work = due_work + 1
end
Expand Down
27 changes: 24 additions & 3 deletions tests/test_fundamentals.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,39 @@ async def test_cancelling_future_task(
the_task.assert_not_called()


async def test_cancelling_current_task_not_supported(
async def test_cancelling_immediate_task(
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
):
"""docket does not allow cancelling a task that is scheduled now"""
"""docket can cancel a task that is scheduled immediately"""

execution = await docket.add(the_task, now())("a", "b", c="c")

await docket.cancel(execution.key)

await worker.run_until_finished()

the_task.assert_awaited_once_with("a", "b", c="c")
the_task.assert_not_called()


async def test_cancellation_is_idempotent(
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
):
"""Test that canceling the same task twice doesn't error."""
key = f"test-task:{uuid4()}"

# Schedule a task
later = now() + timedelta(seconds=1)
await docket.add(the_task, later, key=key)("test")

# Cancel it twice - both should succeed without error
await docket.cancel(key)
await docket.cancel(key) # Should be idempotent

# Run worker to ensure the task was actually cancelled
await worker.run_until_finished()

# Task should not have been executed since it was cancelled
the_task.assert_not_called()


async def test_errors_are_logged(
Expand Down
Loading