Skip to content

Commit c5e78b8

Browse files
authored
Reverting fixes (#155)
These fixes seem to get at least one of my systems into a state where it's infinitely looping on perpetual tasks. This may have been due to some of the intermediate problems in 0.9.0 and 0.9.1, but I don't want to take that chance. We'll come back and revisit this.

This reverts commit 85f293c.

Revert "Fix WRONGTYPE error and add memory leak detection tests (#153)"
This reverts commit c2b8187.

Revert "Fix race condition in task replacement where duplicate executions occur (#151)"
This reverts commit 9dc405c.
1 parent 85f293c commit c5e78b8

File tree

4 files changed

+57
-529
lines changed

4 files changed

+57
-529
lines changed

src/docket/docket.py

Lines changed: 35 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Mapping,
1717
NoReturn,
1818
ParamSpec,
19-
Protocol,
2019
Self,
2120
Sequence,
2221
TypedDict,
@@ -28,6 +27,7 @@
2827
import redis.exceptions
2928
from opentelemetry import propagate, trace
3029
from redis.asyncio import ConnectionPool, Redis
30+
from redis.asyncio.client import Pipeline
3131
from uuid_extensions import uuid7
3232

3333
from .execution import (
@@ -55,18 +55,6 @@
5555
tracer: trace.Tracer = trace.get_tracer(__name__)
5656

5757

58-
class _schedule_task(Protocol):
59-
async def __call__(
60-
self, keys: list[str], args: list[str | float | bytes]
61-
) -> str: ... # pragma: no cover
62-
63-
64-
class _cancel_task(Protocol):
65-
async def __call__(
66-
self, keys: list[str], args: list[str]
67-
) -> str: ... # pragma: no cover
68-
69-
7058
P = ParamSpec("P")
7159
R = TypeVar("R")
7260

@@ -143,8 +131,6 @@ async def my_task(greeting: str, recipient: str) -> None:
143131

144132
_monitor_strikes_task: asyncio.Task[None]
145133
_connection_pool: ConnectionPool
146-
_schedule_task_script: _schedule_task | None
147-
_cancel_task_script: _cancel_task | None
148134

149135
def __init__(
150136
self,
@@ -170,8 +156,6 @@ def __init__(
170156
self.url = url
171157
self.heartbeat_interval = heartbeat_interval
172158
self.missed_heartbeats = missed_heartbeats
173-
self._schedule_task_script = None
174-
self._cancel_task_script = None
175159

176160
@property
177161
def worker_group_name(self) -> str:
@@ -316,7 +300,9 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
316300
execution = Execution(function, args, kwargs, when, key, attempt=1)
317301

318302
async with self.redis() as redis:
319-
await self._schedule(redis, execution, replace=False)
303+
async with redis.pipeline() as pipeline:
304+
await self._schedule(redis, pipeline, execution, replace=False)
305+
await pipeline.execute()
320306

321307
TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
322308
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
@@ -375,7 +361,9 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
375361
execution = Execution(function, args, kwargs, when, key, attempt=1)
376362

377363
async with self.redis() as redis:
378-
await self._schedule(redis, execution, replace=True)
364+
async with redis.pipeline() as pipeline:
365+
await self._schedule(redis, pipeline, execution, replace=True)
366+
await pipeline.execute()
379367

380368
TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
381369
TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
@@ -395,7 +383,9 @@ async def schedule(self, execution: Execution) -> None:
395383
},
396384
):
397385
async with self.redis() as redis:
398-
await self._schedule(redis, execution, replace=False)
386+
async with redis.pipeline() as pipeline:
387+
await self._schedule(redis, pipeline, execution, replace=False)
388+
await pipeline.execute()
399389

400390
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
401391

@@ -410,7 +400,9 @@ async def cancel(self, key: str) -> None:
410400
attributes={**self.labels(), "docket.key": key},
411401
):
412402
async with self.redis() as redis:
413-
await self._cancel(redis, key)
403+
async with redis.pipeline() as pipeline:
404+
await self._cancel(pipeline, key)
405+
await pipeline.execute()
414406

415407
TASKS_CANCELLED.add(1, self.labels())
416408

@@ -428,23 +420,13 @@ def known_task_key(self, key: str) -> str:
428420
def parked_task_key(self, key: str) -> str:
429421
return f"{self.name}:{key}"
430422

431-
def stream_id_key(self, key: str) -> str:
432-
return f"{self.name}:stream-id:{key}"
433-
434423
async def _schedule(
435424
self,
436425
redis: Redis,
426+
pipeline: Pipeline,
437427
execution: Execution,
438428
replace: bool = False,
439429
) -> None:
440-
"""Schedule a task atomically.
441-
442-
Handles:
443-
- Checking for task existence
444-
- Cancelling existing tasks when replacing
445-
- Adding tasks to stream (immediate) or queue (future)
446-
- Tracking stream message IDs for later cancellation
447-
"""
448430
if self.strike_list.is_stricken(execution):
449431
logger.warning(
450432
"%r is stricken, skipping schedule of %r",
@@ -467,138 +449,32 @@ async def _schedule(
467449
key = execution.key
468450
when = execution.when
469451
known_task_key = self.known_task_key(key)
470-
is_immediate = when <= datetime.now(timezone.utc)
471452

472-
# Lock per task key to prevent race conditions between concurrent operations
473453
async with redis.lock(f"{known_task_key}:lock", timeout=10):
474-
if self._schedule_task_script is None:
475-
self._schedule_task_script = cast(
476-
_schedule_task,
477-
redis.register_script(
478-
# KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key
479-
# ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
480-
"""
481-
local stream_key = KEYS[1]
482-
local known_key = KEYS[2]
483-
local parked_key = KEYS[3]
484-
local queue_key = KEYS[4]
485-
local stream_id_key = KEYS[5]
486-
487-
local task_key = ARGV[1]
488-
local when_timestamp = ARGV[2]
489-
local is_immediate = ARGV[3] == '1'
490-
local replace = ARGV[4] == '1'
491-
492-
-- Extract message fields from ARGV[5] onwards
493-
local message = {}
494-
for i = 5, #ARGV, 2 do
495-
message[#message + 1] = ARGV[i] -- field name
496-
message[#message + 1] = ARGV[i + 1] -- field value
497-
end
498-
499-
-- Handle replacement: cancel existing task if needed
500-
if replace then
501-
local existing_message_id = redis.call('GET', stream_id_key)
502-
if existing_message_id then
503-
redis.call('XDEL', stream_key, existing_message_id)
504-
end
505-
redis.call('DEL', known_key, parked_key, stream_id_key)
506-
redis.call('ZREM', queue_key, task_key)
507-
else
508-
-- Check if task already exists
509-
if redis.call('EXISTS', known_key) == 1 then
510-
return 'EXISTS'
511-
end
512-
end
513-
514-
if is_immediate then
515-
-- Add to stream and store message ID for later cancellation
516-
local message_id = redis.call('XADD', stream_key, '*', unpack(message))
517-
redis.call('SET', known_key, when_timestamp)
518-
redis.call('SET', stream_id_key, message_id)
519-
return message_id
520-
else
521-
-- Add to queue with task data in parked hash
522-
redis.call('SET', known_key, when_timestamp)
523-
redis.call('HSET', parked_key, unpack(message))
524-
redis.call('ZADD', queue_key, when_timestamp, task_key)
525-
return 'QUEUED'
526-
end
527-
"""
528-
),
529-
)
530-
schedule_task = self._schedule_task_script
454+
if replace:
455+
await self._cancel(pipeline, key)
456+
else:
457+
# if the task is already in the queue or stream, retain it
458+
if await redis.exists(known_task_key):
459+
logger.debug(
460+
"Task %r is already in the queue or stream, not scheduling",
461+
key,
462+
extra=self.labels(),
463+
)
464+
return
531465

532-
await schedule_task(
533-
keys=[
534-
self.stream_key,
535-
known_task_key,
536-
self.parked_task_key(key),
537-
self.queue_key,
538-
self.stream_id_key(key),
539-
],
540-
args=[
541-
key,
542-
str(when.timestamp()),
543-
"1" if is_immediate else "0",
544-
"1" if replace else "0",
545-
*[
546-
item
547-
for field, value in message.items()
548-
for item in (field, value)
549-
],
550-
],
551-
)
466+
pipeline.set(known_task_key, when.timestamp())
552467

553-
async def _cancel(self, redis: Redis, key: str) -> None:
554-
"""Cancel a task atomically.
468+
if when <= datetime.now(timezone.utc):
469+
pipeline.xadd(self.stream_key, message) # type: ignore[arg-type]
470+
else:
471+
pipeline.hset(self.parked_task_key(key), mapping=message) # type: ignore[arg-type]
472+
pipeline.zadd(self.queue_key, {key: when.timestamp()})
555473

556-
Handles cancellation regardless of task location:
557-
- From the stream (using stored message ID)
558-
- From the queue (scheduled tasks)
559-
- Cleans up all associated metadata keys
560-
"""
561-
if self._cancel_task_script is None:
562-
self._cancel_task_script = cast(
563-
_cancel_task,
564-
redis.register_script(
565-
# KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key
566-
# ARGV: task_key
567-
"""
568-
local stream_key = KEYS[1]
569-
local known_key = KEYS[2]
570-
local parked_key = KEYS[3]
571-
local queue_key = KEYS[4]
572-
local stream_id_key = KEYS[5]
573-
local task_key = ARGV[1]
574-
575-
-- Delete from stream if message ID exists
576-
local message_id = redis.call('GET', stream_id_key)
577-
if message_id then
578-
redis.call('XDEL', stream_key, message_id)
579-
end
580-
581-
-- Clean up all task-related keys
582-
redis.call('DEL', known_key, parked_key, stream_id_key)
583-
redis.call('ZREM', queue_key, task_key)
584-
585-
return 'OK'
586-
"""
587-
),
588-
)
589-
cancel_task = self._cancel_task_script
590-
591-
# Execute the cancellation script
592-
await cancel_task(
593-
keys=[
594-
self.stream_key,
595-
self.known_task_key(key),
596-
self.parked_task_key(key),
597-
self.queue_key,
598-
self.stream_id_key(key),
599-
],
600-
args=[key],
601-
)
474+
async def _cancel(self, pipeline: Pipeline, key: str) -> None:
475+
pipeline.delete(self.known_task_key(key))
476+
pipeline.delete(self.parked_task_key(key))
477+
pipeline.zrem(self.queue_key, key)
602478

603479
@property
604480
def strike_key(self) -> str:
@@ -905,7 +781,6 @@ async def clear(self) -> int:
905781
key = key_bytes.decode()
906782
pipeline.delete(self.parked_task_key(key))
907783
pipeline.delete(self.known_task_key(key))
908-
pipeline.delete(self.stream_id_key(key))
909784

910785
await pipeline.execute()
911786

src/docket/worker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,7 @@ async def _delete_known_task(
495495

496496
logger.debug("Deleting known task", extra=self._log_context())
497497
known_task_key = self.docket.known_task_key(key)
498-
stream_id_key = self.docket.stream_id_key(key)
499-
await redis.delete(known_task_key, stream_id_key)
498+
await redis.delete(known_task_key)
500499

501500
async def _execute(self, execution: Execution) -> None:
502501
log_context = {**self._log_context(), **execution.specific_labels()}

tests/test_fundamentals.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,13 @@ async def test_adding_is_idempotent(
104104
assert soon <= now() < later
105105

106106

107+
@pytest.mark.skip(
108+
"Temporarily skipping due to test flake for task rescheduling. "
109+
"See https://github.com/chrisguidry/docket/issues/149"
110+
)
107111
async def test_rescheduling_later(
108112
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
109-
):
113+
): # pragma: no cover
110114
"""docket should allow for rescheduling a task for later"""
111115

112116
key = f"my-cool-task:{uuid4()}"
@@ -250,39 +254,18 @@ async def test_cancelling_future_task(
250254
the_task.assert_not_called()
251255

252256

253-
async def test_cancelling_immediate_task(
257+
async def test_cancelling_current_task_not_supported(
254258
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
255259
):
256-
"""docket can cancel a task that is scheduled immediately"""
260+
"""docket does not allow cancelling a task that is scheduled for now"""
257261

258262
execution = await docket.add(the_task, now())("a", "b", c="c")
259263

260264
await docket.cancel(execution.key)
261265

262266
await worker.run_until_finished()
263267

264-
the_task.assert_not_called()
265-
266-
267-
async def test_cancellation_is_idempotent(
268-
docket: Docket, worker: Worker, the_task: AsyncMock, now: Callable[[], datetime]
269-
):
270-
"""Test that canceling the same task twice doesn't error."""
271-
key = f"test-task:{uuid4()}"
272-
273-
# Schedule a task
274-
later = now() + timedelta(seconds=1)
275-
await docket.add(the_task, later, key=key)("test")
276-
277-
# Cancel it twice - both should succeed without error
278-
await docket.cancel(key)
279-
await docket.cancel(key) # Should be idempotent
280-
281-
# Run worker to ensure the task was actually cancelled
282-
await worker.run_until_finished()
283-
284-
# Task should not have been executed since it was cancelled
285-
the_task.assert_not_called()
268+
the_task.assert_awaited_once_with("a", "b", c="c")
286269

287270

288271
async def test_errors_are_logged(

0 commit comments

Comments
 (0)