Skip to content

Commit e6bcb73

Browse files
committed
Small refactors
* Remove little-used HEALTH_CHECK constant
* Remove `Worker._health_tab` and `Worker._delay_for`
* Add running tasks to worker string representation
* Collapse worker health check into `renew_idle_timeouts`
* Start producer task after consumers
* Fix bug in redis health ttl
1 parent d5265e3 commit e6bcb73

File tree

4 files changed

+39
-38
lines changed

4 files changed

+39
-38
lines changed

streaq/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
VERSION = "6.1.0"
3+
VERSION = "6.2.0"
44
__version__ = VERSION
55

66
logger = logging.getLogger(__name__)

streaq/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
HEALTH_CHECK = "redis_health_check"
21
REDIS_ABORT = ":aborted"
32
REDIS_CHANNEL = ":channels:"
43
REDIS_CRON = ":cron:"

streaq/worker.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252

5353
from streaq import logger
5454
from streaq.constants import (
55-
HEALTH_CHECK,
5655
REDIS_ABORT,
5756
REDIS_CHANNEL,
5857
REDIS_CRON,
@@ -200,7 +199,6 @@ class Worker(AsyncContextManagerMixin, Generic[C]):
200199
"_coworkers",
201200
"_health_key",
202201
"_health_tab",
203-
"_health_tab_str",
204202
"_initialized",
205203
"_lib",
206204
"_limiter",
@@ -308,8 +306,7 @@ def __init__(
308306
self._cancel_scopes: dict[str, CancelScope] = {}
309307
self._running_tasks: dict[str, set[str]] = defaultdict(set)
310308
self._limiter = CapacityLimiter(self.sync_concurrency)
311-
self._health_tab = CronTab(health_crontab)
312-
self._health_tab_str = health_crontab
309+
self._health_tab = health_crontab
313310
self._initialized = False
314311
# precalculate Redis prefixes
315312
self.prefix = REDIS_PREFIX + self.queue_name
@@ -351,6 +348,8 @@ def running(self) -> int:
351348

352349
def __str__(self) -> str:
353350
counters = {k: v for k, v in self.counters.items() if v}
351+
if self._cancel_scopes:
352+
counters["running"] = self.running()
354353
counters_str = repr(counters).replace("'", "")
355354
return f"worker {self.id} {counters_str}"
356355

@@ -379,7 +378,11 @@ async def redis_health_check(self) -> None:
379378
f"redis {{memory: {mem_usage}, clients: {clients}, keys: {keys}, "
380379
f"queued: {queued}, scheduled: {scheduled}}}"
381380
)
382-
ttl = self._delay_for(self._health_tab)
381+
ttl = (
382+
self._next_datetime(self._health_tab)
383+
- datetime.now(self.tz)
384+
+ timedelta(seconds=1)
385+
)
383386
await self.redis.set(self._health_key + ":redis", health, ex=ttl)
384387

385388
@asynccontextmanager
@@ -613,7 +616,7 @@ async def run_async(
613616
# run user-defined initialization code
614617
async with self, self.lifespan as context:
615618
# register redis health check
616-
self.cron(self._health_tab_str, silent=True, ttl=0, name=HEALTH_CHECK)(
619+
self.cron(self._health_tab, silent=True, ttl=0, name="redis_health_check")(
617620
self.redis_health_check
618621
)
619622
token = worker_context.set(context)
@@ -641,27 +644,16 @@ async def run_async(
641644
send, receive = create_memory_object_stream[StreamMessage](
642645
max_buffer_size=self.prefetch
643646
)
644-
limiter = CapacityLimiter(self.concurrency)
645-
646-
async def _run_consumers(scope: CancelScope) -> None:
647-
try:
648-
async with create_task_group() as tg:
649-
for _ in range(self.concurrency):
650-
tg.start_soon(self.consumer, receive.clone(), limiter)
651-
finally:
652-
# don't cancel renewal task until consumers finish
653-
scope.cancel()
654647

655648
# start tasks
656649
try:
657650
async with create_task_group() as tg:
658651
# register signal handler
659652
tg.start_soon(self.signal_handler, tg.cancel_scope)
660-
tg.start_soon(self.health_check)
661-
tg.start_soon(self.producer, send, limiter, tg.cancel_scope)
662653
scope = CancelScope(shield=True)
663654
tg.start_soon(self.renew_idle_timeouts, scope)
664-
tg.start_soon(_run_consumers, scope)
655+
limiter = await tg.start(self.run_consumers, receive, scope)
656+
tg.start_soon(self.producer, send, limiter, tg.cancel_scope)
665657
task_status.started()
666658
finally:
667659
run_time = to_ms(current_time() - start_time)
@@ -679,17 +671,38 @@ async def consumer(
679671
async with limiter:
680672
await self.run_task(msg)
681673

674+
async def run_consumers(
675+
self,
676+
receive: MemoryObjectReceiveStream[StreamMessage],
677+
scope: CancelScope,
678+
*,
679+
task_status: AnyStatus[CapacityLimiter] = TASK_STATUS_IGNORED,
680+
) -> None:
681+
"""
682+
Run all consumers in a dedicated task group, finally clean up.
683+
"""
684+
limiter = CapacityLimiter(self.concurrency)
685+
try:
686+
async with create_task_group() as tg:
687+
for _ in range(self.concurrency):
688+
tg.start_soon(self.consumer, receive.clone(), limiter)
689+
task_status.started(limiter)
690+
finally:
691+
# don't cancel renewal task until consumers finish
692+
scope.cancel()
693+
682694
async def renew_idle_timeouts(self, scope: CancelScope) -> None:
683695
"""
684696
Periodically renew idle timeout for running tasks. This allows the queue to
685-
be resilient to sudden shutdowns.
697+
be resilient to sudden shutdowns. Additionally marks worker as healthy.
686698
"""
687699
timeout = self.idle_timeout / 1000 * 0.9 # 10% buffer
700+
key = f"{self._health_key}:{self.id}"
688701
# prevent cancellation until consumers finish
689702
with scope:
690703
while True:
691-
await sleep(timeout)
692704
async with self.redis.pipeline(transaction=False) as pipe:
705+
pipe.set(key, str(self), px=self.idle_timeout)
693706
for priority, tasks in self._running_tasks.items():
694707
if tasks:
695708
pipe.xclaim(
@@ -700,6 +713,7 @@ async def renew_idle_timeouts(self, scope: CancelScope) -> None:
700713
tasks,
701714
justid=True,
702715
)
716+
await sleep(timeout)
703717

704718
async def producer(
705719
self,
@@ -1337,9 +1351,6 @@ async def queue_size(self, include_scheduled: bool = True) -> int:
13371351
)
13381352
return sum(await gather(*commands))
13391353

1340-
def _delay_for(self, tab: CronTab) -> int:
1341-
return to_ms(tab.next(now=datetime.now(self.tz)) + 1) # type: ignore
1342-
13431354
def _next_datetime(self, tab: str) -> datetime:
13441355
return CronTab(tab).next(now=datetime.now(self.tz), return_datetime=True) # type: ignore
13451356

@@ -1351,15 +1362,6 @@ def next_run(self, tab: str) -> int:
13511362
"""
13521363
return datetime_ms(self._next_datetime(tab))
13531364

1354-
async def health_check(self) -> None:
1355-
"""
1356-
Periodically stores info about the worker in Redis.
1357-
"""
1358-
while True:
1359-
ttl = self._delay_for(self._health_tab)
1360-
await self.redis.set(f"{self._health_key}:{self.id}", str(self), px=ttl)
1361-
await sleep(ttl / 1000)
1362-
13631365
async def signal_handler(self, scope: CancelScope) -> None:
13641366
"""
13651367
Gracefully shutdown the worker when a signal is received.

tests/test_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,15 @@ async def foobar() -> None:
418418
await sleep(3)
419419

420420
worker.grace_period = 5
421+
task = foobar.enqueue().start(delay=1)
421422
async with create_task_group() as tg:
422423
await tg.start(worker.run_async)
423424
await foobar.enqueue()
424425
await sleep(1)
426+
await task
425427
os.kill(os.getpid(), signal.SIGINT)
426428
async with worker:
427-
task = await foobar.enqueue()
428-
await sleep(1)
429-
assert await task.status() == TaskStatus.QUEUED
429+
assert await task.status() == TaskStatus.SCHEDULED
430430

431431

432432
async def test_get_tasks_by_status_scheduled(worker: Worker):

0 commit comments

Comments (0)