1616 Mapping ,
1717 NoReturn ,
1818 ParamSpec ,
19+ Protocol ,
1920 Self ,
2021 Sequence ,
2122 TypedDict ,
2728import redis .exceptions
2829from opentelemetry import propagate , trace
2930from redis .asyncio import ConnectionPool , Redis
30- from redis .asyncio .client import Pipeline
3131from uuid_extensions import uuid7
3232
3333from .execution import (
5555tracer : trace .Tracer = trace .get_tracer (__name__ )
5656
5757
58+ class _schedule_task (Protocol ):
59+ async def __call__ (
60+ self , keys : list [str ], args : list [str | float | bytes ]
61+ ) -> str : ... # pragma: no cover
62+
63+
64+ class _cancel_task (Protocol ):
65+ async def __call__ (
66+ self , keys : list [str ], args : list [str ]
67+ ) -> str : ... # pragma: no cover
68+
69+
5870P = ParamSpec ("P" )
5971R = TypeVar ("R" )
6072
@@ -131,6 +143,8 @@ async def my_task(greeting: str, recipient: str) -> None:
131143
132144 _monitor_strikes_task : asyncio .Task [None ]
133145 _connection_pool : ConnectionPool
146+ _schedule_task_script : _schedule_task | None
147+ _cancel_task_script : _cancel_task | None
134148
135149 def __init__ (
136150 self ,
@@ -156,6 +170,8 @@ def __init__(
156170 self .url = url
157171 self .heartbeat_interval = heartbeat_interval
158172 self .missed_heartbeats = missed_heartbeats
173+ self ._schedule_task_script = None
174+ self ._cancel_task_script = None
159175
160176 @property
161177 def worker_group_name (self ) -> str :
@@ -300,9 +316,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
300316 execution = Execution (function , args , kwargs , when , key , attempt = 1 )
301317
302318 async with self .redis () as redis :
303- async with redis .pipeline () as pipeline :
304- await self ._schedule (redis , pipeline , execution , replace = False )
305- await pipeline .execute ()
319+ await self ._schedule (redis , execution , replace = False )
306320
307321 TASKS_ADDED .add (1 , {** self .labels (), ** execution .general_labels ()})
308322 TASKS_SCHEDULED .add (1 , {** self .labels (), ** execution .general_labels ()})
@@ -361,9 +375,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
361375 execution = Execution (function , args , kwargs , when , key , attempt = 1 )
362376
363377 async with self .redis () as redis :
364- async with redis .pipeline () as pipeline :
365- await self ._schedule (redis , pipeline , execution , replace = True )
366- await pipeline .execute ()
378+ await self ._schedule (redis , execution , replace = True )
367379
368380 TASKS_REPLACED .add (1 , {** self .labels (), ** execution .general_labels ()})
369381 TASKS_CANCELLED .add (1 , {** self .labels (), ** execution .general_labels ()})
@@ -383,9 +395,7 @@ async def schedule(self, execution: Execution) -> None:
383395 },
384396 ):
385397 async with self .redis () as redis :
386- async with redis .pipeline () as pipeline :
387- await self ._schedule (redis , pipeline , execution , replace = False )
388- await pipeline .execute ()
398+ await self ._schedule (redis , execution , replace = False )
389399
390400 TASKS_SCHEDULED .add (1 , {** self .labels (), ** execution .general_labels ()})
391401
@@ -400,9 +410,7 @@ async def cancel(self, key: str) -> None:
400410 attributes = {** self .labels (), "docket.key" : key },
401411 ):
402412 async with self .redis () as redis :
403- async with redis .pipeline () as pipeline :
404- await self ._cancel (pipeline , key )
405- await pipeline .execute ()
413+ await self ._cancel (redis , key )
406414
407415 TASKS_CANCELLED .add (1 , self .labels ())
408416
@@ -423,10 +431,17 @@ def parked_task_key(self, key: str) -> str:
423431 async def _schedule (
424432 self ,
425433 redis : Redis ,
426- pipeline : Pipeline ,
427434 execution : Execution ,
428435 replace : bool = False ,
429436 ) -> None :
437+ """Schedule a task atomically.
438+
439+ Handles:
440+ - Checking for task existence
441+ - Cancelling existing tasks when replacing
442+ - Adding tasks to stream (immediate) or queue (future)
443+ - Tracking stream message IDs for later cancellation
444+ """
430445 if self .strike_list .is_stricken (execution ):
431446 logger .warning (
432447 "%r is stricken, skipping schedule of %r" ,
@@ -449,32 +464,133 @@ async def _schedule(
449464 key = execution .key
450465 when = execution .when
451466 known_task_key = self .known_task_key (key )
467+ is_immediate = when <= datetime .now (timezone .utc )
452468
469+ # Lock per task key to prevent race conditions between concurrent operations
453470 async with redis .lock (f"{ known_task_key } :lock" , timeout = 10 ):
454- if replace :
455- await self ._cancel (pipeline , key )
456- else :
457- # if the task is already in the queue or stream, retain it
458- if await redis .exists (known_task_key ):
459- logger .debug (
460- "Task %r is already in the queue or stream, not scheduling" ,
461- key ,
462- extra = self .labels (),
463- )
464- return
471+ if self ._schedule_task_script is None :
472+ self ._schedule_task_script = cast (
473+ _schedule_task ,
474+ redis .register_script (
475+ # KEYS: stream_key, known_key, parked_key, queue_key
476+ # ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
477+ """
478+ local stream_key = KEYS[1]
479+ local known_key = KEYS[2]
480+ local parked_key = KEYS[3]
481+ local queue_key = KEYS[4]
482+
483+ local task_key = ARGV[1]
484+ local when_timestamp = ARGV[2]
485+ local is_immediate = ARGV[3] == '1'
486+ local replace = ARGV[4] == '1'
487+
488+ -- Extract message fields from ARGV[5] onwards
489+ local message = {}
490+ for i = 5, #ARGV, 2 do
491+ message[#message + 1] = ARGV[i] -- field name
492+ message[#message + 1] = ARGV[i + 1] -- field value
493+ end
494+
495+ -- Handle replacement: cancel existing task if needed
496+ if replace then
497+ local existing_message_id = redis.call('HGET', known_key, 'stream_message_id')
498+ if existing_message_id then
499+ redis.call('XDEL', stream_key, existing_message_id)
500+ end
501+ redis.call('DEL', known_key, parked_key)
502+ redis.call('ZREM', queue_key, task_key)
503+ else
504+ -- Check if task already exists
505+ if redis.call('EXISTS', known_key) == 1 then
506+ return 'EXISTS'
507+ end
508+ end
509+
510+ if is_immediate then
511+ -- Add to stream and store message ID for later cancellation
512+ local message_id = redis.call('XADD', stream_key, '*', unpack(message))
513+ redis.call('HSET', known_key, 'when', when_timestamp, 'stream_message_id', message_id)
514+ return message_id
515+ else
516+ -- Add to queue with task data in parked hash
517+ redis.call('HSET', known_key, 'when', when_timestamp)
518+ redis.call('HSET', parked_key, unpack(message))
519+ redis.call('ZADD', queue_key, when_timestamp, task_key)
520+ return 'QUEUED'
521+ end
522+ """
523+ ),
524+ )
525+ schedule_task = self ._schedule_task_script
465526
466- pipeline .set (known_task_key , when .timestamp ())
527+ await schedule_task (
528+ keys = [
529+ self .stream_key ,
530+ known_task_key ,
531+ self .parked_task_key (key ),
532+ self .queue_key ,
533+ ],
534+ args = [
535+ key ,
536+ str (when .timestamp ()),
537+ "1" if is_immediate else "0" ,
538+ "1" if replace else "0" ,
539+ * [
540+ item
541+ for field , value in message .items ()
542+ for item in (field , value )
543+ ],
544+ ],
545+ )
467546
468- if when <= datetime .now (timezone .utc ):
469- pipeline .xadd (self .stream_key , message ) # type: ignore[arg-type]
470- else :
471- pipeline .hset (self .parked_task_key (key ), mapping = message ) # type: ignore[arg-type]
472- pipeline .zadd (self .queue_key , {key : when .timestamp ()})
547+ async def _cancel (self , redis : Redis , key : str ) -> None :
548+ """Cancel a task atomically.
473549
474- async def _cancel (self , pipeline : Pipeline , key : str ) -> None :
475- pipeline .delete (self .known_task_key (key ))
476- pipeline .delete (self .parked_task_key (key ))
477- pipeline .zrem (self .queue_key , key )
550+ Handles cancellation regardless of task location:
551+ - From the stream (using stored message ID)
552+ - From the queue (scheduled tasks)
553+ - Cleans up all associated metadata keys
554+ """
555+ if self ._cancel_task_script is None :
556+ self ._cancel_task_script = cast (
557+ _cancel_task ,
558+ redis .register_script (
559+ # KEYS: stream_key, known_key, parked_key, queue_key
560+ # ARGV: task_key
561+ """
562+ local stream_key = KEYS[1]
563+ local known_key = KEYS[2]
564+ local parked_key = KEYS[3]
565+ local queue_key = KEYS[4]
566+ local task_key = ARGV[1]
567+
568+ -- Delete from stream if message ID exists
569+ local message_id = redis.call('HGET', known_key, 'stream_message_id')
570+ if message_id then
571+ redis.call('XDEL', stream_key, message_id)
572+ end
573+
574+ -- Clean up all task-related keys
575+ redis.call('DEL', known_key, parked_key)
576+ redis.call('ZREM', queue_key, task_key)
577+
578+ return 'OK'
579+ """
580+ ),
581+ )
582+ cancel_task = self ._cancel_task_script
583+
584+ # Execute the cancellation script
585+ await cancel_task (
586+ keys = [
587+ self .stream_key ,
588+ self .known_task_key (key ),
589+ self .parked_task_key (key ),
590+ self .queue_key ,
591+ ],
592+ args = [key ],
593+ )
478594
479595 @property
480596 def strike_key (self ) -> str :
0 commit comments