1616 Mapping ,
1717 NoReturn ,
1818 ParamSpec ,
19- Protocol ,
2019 Self ,
2120 Sequence ,
2221 TypedDict ,
2827import redis .exceptions
2928from opentelemetry import propagate , trace
3029from redis .asyncio import ConnectionPool , Redis
30+ from redis .asyncio .client import Pipeline
3131from uuid_extensions import uuid7
3232
3333from .execution import (
5555tracer : trace .Tracer = trace .get_tracer (__name__ )
5656
5757
58- class _schedule_task (Protocol ):
59- async def __call__ (
60- self , keys : list [str ], args : list [str | float | bytes ]
61- ) -> str : ... # pragma: no cover
62-
63-
64- class _cancel_task (Protocol ):
65- async def __call__ (
66- self , keys : list [str ], args : list [str ]
67- ) -> str : ... # pragma: no cover
68-
69-
7058P = ParamSpec ("P" )
7159R = TypeVar ("R" )
7260
@@ -143,8 +131,6 @@ async def my_task(greeting: str, recipient: str) -> None:
143131
144132 _monitor_strikes_task : asyncio .Task [None ]
145133 _connection_pool : ConnectionPool
146- _schedule_task_script : _schedule_task | None
147- _cancel_task_script : _cancel_task | None
148134
149135 def __init__ (
150136 self ,
@@ -170,8 +156,6 @@ def __init__(
170156 self .url = url
171157 self .heartbeat_interval = heartbeat_interval
172158 self .missed_heartbeats = missed_heartbeats
173- self ._schedule_task_script = None
174- self ._cancel_task_script = None
175159
176160 @property
177161 def worker_group_name (self ) -> str :
@@ -316,7 +300,9 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
316300 execution = Execution (function , args , kwargs , when , key , attempt = 1 )
317301
318302 async with self .redis () as redis :
319- await self ._schedule (redis , execution , replace = False )
303+ async with redis .pipeline () as pipeline :
304+ await self ._schedule (redis , pipeline , execution , replace = False )
305+ await pipeline .execute ()
320306
321307 TASKS_ADDED .add (1 , {** self .labels (), ** execution .general_labels ()})
322308 TASKS_SCHEDULED .add (1 , {** self .labels (), ** execution .general_labels ()})
@@ -375,7 +361,9 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
375361 execution = Execution (function , args , kwargs , when , key , attempt = 1 )
376362
377363 async with self .redis () as redis :
378- await self ._schedule (redis , execution , replace = True )
364+ async with redis .pipeline () as pipeline :
365+ await self ._schedule (redis , pipeline , execution , replace = True )
366+ await pipeline .execute ()
379367
380368 TASKS_REPLACED .add (1 , {** self .labels (), ** execution .general_labels ()})
381369 TASKS_CANCELLED .add (1 , {** self .labels (), ** execution .general_labels ()})
@@ -395,7 +383,9 @@ async def schedule(self, execution: Execution) -> None:
395383 },
396384 ):
397385 async with self .redis () as redis :
398- await self ._schedule (redis , execution , replace = False )
386+ async with redis .pipeline () as pipeline :
387+ await self ._schedule (redis , pipeline , execution , replace = False )
388+ await pipeline .execute ()
399389
400390 TASKS_SCHEDULED .add (1 , {** self .labels (), ** execution .general_labels ()})
401391
@@ -410,7 +400,9 @@ async def cancel(self, key: str) -> None:
410400 attributes = {** self .labels (), "docket.key" : key },
411401 ):
412402 async with self .redis () as redis :
413- await self ._cancel (redis , key )
403+ async with redis .pipeline () as pipeline :
404+ await self ._cancel (pipeline , key )
405+ await pipeline .execute ()
414406
415407 TASKS_CANCELLED .add (1 , self .labels ())
416408
@@ -428,23 +420,13 @@ def known_task_key(self, key: str) -> str:
428420 def parked_task_key (self , key : str ) -> str :
429421 return f"{ self .name } :{ key } "
430422
431- def stream_id_key (self , key : str ) -> str :
432- return f"{ self .name } :stream-id:{ key } "
433-
434423 async def _schedule (
435424 self ,
436425 redis : Redis ,
426+ pipeline : Pipeline ,
437427 execution : Execution ,
438428 replace : bool = False ,
439429 ) -> None :
440- """Schedule a task atomically.
441-
442- Handles:
443- - Checking for task existence
444- - Cancelling existing tasks when replacing
445- - Adding tasks to stream (immediate) or queue (future)
446- - Tracking stream message IDs for later cancellation
447- """
448430 if self .strike_list .is_stricken (execution ):
449431 logger .warning (
450432 "%r is stricken, skipping schedule of %r" ,
@@ -467,138 +449,32 @@ async def _schedule(
467449 key = execution .key
468450 when = execution .when
469451 known_task_key = self .known_task_key (key )
470- is_immediate = when <= datetime .now (timezone .utc )
471452
472- # Lock per task key to prevent race conditions between concurrent operations
473453 async with redis .lock (f"{ known_task_key } :lock" , timeout = 10 ):
474- if self ._schedule_task_script is None :
475- self ._schedule_task_script = cast (
476- _schedule_task ,
477- redis .register_script (
478- # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key
479- # ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
480- """
481- local stream_key = KEYS[1]
482- local known_key = KEYS[2]
483- local parked_key = KEYS[3]
484- local queue_key = KEYS[4]
485- local stream_id_key = KEYS[5]
486-
487- local task_key = ARGV[1]
488- local when_timestamp = ARGV[2]
489- local is_immediate = ARGV[3] == '1'
490- local replace = ARGV[4] == '1'
491-
492- -- Extract message fields from ARGV[5] onwards
493- local message = {}
494- for i = 5, #ARGV, 2 do
495- message[#message + 1] = ARGV[i] -- field name
496- message[#message + 1] = ARGV[i + 1] -- field value
497- end
498-
499- -- Handle replacement: cancel existing task if needed
500- if replace then
501- local existing_message_id = redis.call('GET', stream_id_key)
502- if existing_message_id then
503- redis.call('XDEL', stream_key, existing_message_id)
504- end
505- redis.call('DEL', known_key, parked_key, stream_id_key)
506- redis.call('ZREM', queue_key, task_key)
507- else
508- -- Check if task already exists
509- if redis.call('EXISTS', known_key) == 1 then
510- return 'EXISTS'
511- end
512- end
513-
514- if is_immediate then
515- -- Add to stream and store message ID for later cancellation
516- local message_id = redis.call('XADD', stream_key, '*', unpack(message))
517- redis.call('SET', known_key, when_timestamp)
518- redis.call('SET', stream_id_key, message_id)
519- return message_id
520- else
521- -- Add to queue with task data in parked hash
522- redis.call('SET', known_key, when_timestamp)
523- redis.call('HSET', parked_key, unpack(message))
524- redis.call('ZADD', queue_key, when_timestamp, task_key)
525- return 'QUEUED'
526- end
527- """
528- ),
529- )
530- schedule_task = self ._schedule_task_script
454+ if replace :
455+ await self ._cancel (pipeline , key )
456+ else :
457+ # if the task is already in the queue or stream, retain it
458+ if await redis .exists (known_task_key ):
459+ logger .debug (
460+ "Task %r is already in the queue or stream, not scheduling" ,
461+ key ,
462+ extra = self .labels (),
463+ )
464+ return
531465
532- await schedule_task (
533- keys = [
534- self .stream_key ,
535- known_task_key ,
536- self .parked_task_key (key ),
537- self .queue_key ,
538- self .stream_id_key (key ),
539- ],
540- args = [
541- key ,
542- str (when .timestamp ()),
543- "1" if is_immediate else "0" ,
544- "1" if replace else "0" ,
545- * [
546- item
547- for field , value in message .items ()
548- for item in (field , value )
549- ],
550- ],
551- )
466+ pipeline .set (known_task_key , when .timestamp ())
552467
553- async def _cancel (self , redis : Redis , key : str ) -> None :
554- """Cancel a task atomically.
468+ if when <= datetime .now (timezone .utc ):
469+ pipeline .xadd (self .stream_key , message ) # type: ignore[arg-type]
470+ else :
471+ pipeline .hset (self .parked_task_key (key ), mapping = message ) # type: ignore[arg-type]
472+ pipeline .zadd (self .queue_key , {key : when .timestamp ()})
555473
556- Handles cancellation regardless of task location:
557- - From the stream (using stored message ID)
558- - From the queue (scheduled tasks)
559- - Cleans up all associated metadata keys
560- """
561- if self ._cancel_task_script is None :
562- self ._cancel_task_script = cast (
563- _cancel_task ,
564- redis .register_script (
565- # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key
566- # ARGV: task_key
567- """
568- local stream_key = KEYS[1]
569- local known_key = KEYS[2]
570- local parked_key = KEYS[3]
571- local queue_key = KEYS[4]
572- local stream_id_key = KEYS[5]
573- local task_key = ARGV[1]
574-
575- -- Delete from stream if message ID exists
576- local message_id = redis.call('GET', stream_id_key)
577- if message_id then
578- redis.call('XDEL', stream_key, message_id)
579- end
580-
581- -- Clean up all task-related keys
582- redis.call('DEL', known_key, parked_key, stream_id_key)
583- redis.call('ZREM', queue_key, task_key)
584-
585- return 'OK'
586- """
587- ),
588- )
589- cancel_task = self ._cancel_task_script
590-
591- # Execute the cancellation script
592- await cancel_task (
593- keys = [
594- self .stream_key ,
595- self .known_task_key (key ),
596- self .parked_task_key (key ),
597- self .queue_key ,
598- self .stream_id_key (key ),
599- ],
600- args = [key ],
601- )
async def _cancel(self, pipeline: Pipeline, key: str) -> None:
    """Queue the Redis commands that erase task *key* onto *pipeline*.

    Issues one DEL for the known-task marker key, one DEL for the parked
    payload hash, and a ZREM to drop the key from the scheduling queue.
    Nothing is sent to Redis here — the caller is responsible for running
    ``pipeline.execute()``.
    """
    # Resolve both per-task keys up front; the three commands are
    # independent of one another and buffered in pipeline order.
    known = self.known_task_key(key)
    parked = self.parked_task_key(key)
    pipeline.delete(known)
    pipeline.delete(parked)
    pipeline.zrem(self.queue_key, key)
602478
603479 @property
604480 def strike_key (self ) -> str :
@@ -905,7 +781,6 @@ async def clear(self) -> int:
905781 key = key_bytes .decode ()
906782 pipeline .delete (self .parked_task_key (key ))
907783 pipeline .delete (self .known_task_key (key ))
908- pipeline .delete (self .stream_id_key (key ))
909784
910785 await pipeline .execute ()
911786
0 commit comments