5252
5353from streaq import logger
5454from streaq .constants import (
55- HEALTH_CHECK ,
5655 REDIS_ABORT ,
5756 REDIS_CHANNEL ,
5857 REDIS_CRON ,
@@ -200,7 +199,6 @@ class Worker(AsyncContextManagerMixin, Generic[C]):
200199 "_coworkers" ,
201200 "_health_key" ,
202201 "_health_tab" ,
203- "_health_tab_str" ,
204202 "_initialized" ,
205203 "_lib" ,
206204 "_limiter" ,
@@ -308,8 +306,7 @@ def __init__(
308306 self ._cancel_scopes : dict [str , CancelScope ] = {}
309307 self ._running_tasks : dict [str , set [str ]] = defaultdict (set )
310308 self ._limiter = CapacityLimiter (self .sync_concurrency )
311- self ._health_tab = CronTab (health_crontab )
312- self ._health_tab_str = health_crontab
309+ self ._health_tab = health_crontab
313310 self ._initialized = False
314311 # precalculate Redis prefixes
315312 self .prefix = REDIS_PREFIX + self .queue_name
@@ -351,6 +348,8 @@ def running(self) -> int:
351348
352349 def __str__ (self ) -> str :
353350 counters = {k : v for k , v in self .counters .items () if v }
351+ if self ._cancel_scopes :
352+ counters ["running" ] = self .running ()
354353 counters_str = repr (counters ).replace ("'" , "" )
355354 return f"worker { self .id } { counters_str } "
356355
@@ -379,7 +378,11 @@ async def redis_health_check(self) -> None:
379378 f"redis {{memory: { mem_usage } , clients: { clients } , keys: { keys } , "
380379 f"queued: { queued } , scheduled: { scheduled } }}"
381380 )
382- ttl = self ._delay_for (self ._health_tab )
381+ ttl = (
382+ self ._next_datetime (self ._health_tab )
383+ - datetime .now (self .tz )
384+ + timedelta (seconds = 1 )
385+ )
383386 await self .redis .set (self ._health_key + ":redis" , health , ex = ttl )
384387
385388 @asynccontextmanager
@@ -613,7 +616,7 @@ async def run_async(
613616 # run user-defined initialization code
614617 async with self , self .lifespan as context :
615618 # register redis health check
616- self .cron (self ._health_tab_str , silent = True , ttl = 0 , name = HEALTH_CHECK )(
619+ self .cron (self ._health_tab , silent = True , ttl = 0 , name = "redis_health_check" )(
617620 self .redis_health_check
618621 )
619622 token = worker_context .set (context )
@@ -641,27 +644,16 @@ async def run_async(
641644 send , receive = create_memory_object_stream [StreamMessage ](
642645 max_buffer_size = self .prefetch
643646 )
644- limiter = CapacityLimiter (self .concurrency )
645-
646- async def _run_consumers (scope : CancelScope ) -> None :
647- try :
648- async with create_task_group () as tg :
649- for _ in range (self .concurrency ):
650- tg .start_soon (self .consumer , receive .clone (), limiter )
651- finally :
652- # don't cancel renewal task until consumers finish
653- scope .cancel ()
654647
655648 # start tasks
656649 try :
657650 async with create_task_group () as tg :
658651 # register signal handler
659652 tg .start_soon (self .signal_handler , tg .cancel_scope )
660- tg .start_soon (self .health_check )
661- tg .start_soon (self .producer , send , limiter , tg .cancel_scope )
662653 scope = CancelScope (shield = True )
663654 tg .start_soon (self .renew_idle_timeouts , scope )
664- tg .start_soon (_run_consumers , scope )
655+ limiter = await tg .start (self .run_consumers , receive , scope )
656+ tg .start_soon (self .producer , send , limiter , tg .cancel_scope )
665657 task_status .started ()
666658 finally :
667659 run_time = to_ms (current_time () - start_time )
@@ -679,17 +671,38 @@ async def consumer(
679671 async with limiter :
680672 await self .run_task (msg )
681673
674+ async def run_consumers (
675+ self ,
676+ receive : MemoryObjectReceiveStream [StreamMessage ],
677+ scope : CancelScope ,
678+ * ,
679+ task_status : AnyStatus [CapacityLimiter ] = TASK_STATUS_IGNORED ,
680+ ) -> None :
681+ """
682+ Run all consumers in a dedicated task group, finally clean up.
683+ """
684+ limiter = CapacityLimiter (self .concurrency )
685+ try :
686+ async with create_task_group () as tg :
687+ for _ in range (self .concurrency ):
688+ tg .start_soon (self .consumer , receive .clone (), limiter )
689+ task_status .started (limiter )
690+ finally :
691+ # don't cancel renewal task until consumers finish
692+ scope .cancel ()
693+
682694 async def renew_idle_timeouts (self , scope : CancelScope ) -> None :
683695 """
684696 Periodically renew idle timeout for running tasks. This allows the queue to
685- be resilient to sudden shutdowns.
697+ be resilient to sudden shutdowns. Additionally marks worker as healthy.
686698 """
687699 timeout = self .idle_timeout / 1000 * 0.9 # 10% buffer
700+ key = f"{ self ._health_key } :{ self .id } "
688701 # prevent cancellation until consumers finish
689702 with scope :
690703 while True :
691- await sleep (timeout )
692704 async with self .redis .pipeline (transaction = False ) as pipe :
705+ pipe .set (key , str (self ), px = self .idle_timeout )
693706 for priority , tasks in self ._running_tasks .items ():
694707 if tasks :
695708 pipe .xclaim (
@@ -700,6 +713,7 @@ async def renew_idle_timeouts(self, scope: CancelScope) -> None:
700713 tasks ,
701714 justid = True ,
702715 )
716+ await sleep (timeout )
703717
704718 async def producer (
705719 self ,
@@ -1337,9 +1351,6 @@ async def queue_size(self, include_scheduled: bool = True) -> int:
13371351 )
13381352 return sum (await gather (* commands ))
13391353
1340- def _delay_for (self , tab : CronTab ) -> int :
1341- return to_ms (tab .next (now = datetime .now (self .tz )) + 1 ) # type: ignore
1342-
13431354 def _next_datetime (self , tab : str ) -> datetime :
13441355 return CronTab (tab ).next (now = datetime .now (self .tz ), return_datetime = True ) # type: ignore
13451356
@@ -1351,15 +1362,6 @@ def next_run(self, tab: str) -> int:
13511362 """
13521363 return datetime_ms (self ._next_datetime (tab ))
13531364
1354- async def health_check (self ) -> None :
1355- """
1356- Periodically stores info about the worker in Redis.
1357- """
1358- while True :
1359- ttl = self ._delay_for (self ._health_tab )
1360- await self .redis .set (f"{ self ._health_key } :{ self .id } " , str (self ), px = ttl )
1361- await sleep (ttl / 1000 )
1362-
13631365 async def signal_handler (self , scope : CancelScope ) -> None :
13641366 """
13651367 Gracefully shutdown the worker when a signal is received.
0 commit comments