Commit af17042

feat: signal handler to wait for task completion before shutting down (#345)
* feat: k8s signal handler to wait for task completion before shutting down
* refactor: signal handler to be more generic
* refactor: rename to job_completion_wait, split signal handler to a new function
* tests: add test for new signal handler flow
* fix: use asyncio.TimeoutError to support python3.7
* Update arq/worker.py

Co-authored-by: Samuel Colvin <[email protected]>

1 parent 2d45f53 commit af17042
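
A minimal usage sketch, assuming the standard arq WorkerSettings pattern (func and run_worker are existing arq APIs; only job_completion_wait is introduced by this commit):

    from arq.worker import run_worker

    async def process(ctx, n):
        ...  # long-running job body

    class WorkerSettings:
        functions = [process]
        # after SIGINT/SIGTERM, wait up to 30s for in-flight jobs to finish
        # before cancelling them; pair with a kubernetes
        # terminationGracePeriodSeconds comfortably above this value
        job_completion_wait = 30

    if __name__ == '__main__':
        run_worker(WorkerSettings)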

2 files changed: +105 -16 lines

arq/worker.py

+77 -16

@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import inspect
 import logging
 import signal
@@ -153,6 +154,10 @@ class Worker:
     :param after_job_end: coroutine function to run after job has ended and results have been recorded
     :param handle_signals: default true, register signal handlers,
       set to false when running inside other async framework
+    :param job_completion_wait: time to wait before cancelling tasks after a signal.
+      Useful together with ``terminationGracePeriodSeconds`` in kubernetes,
+      when you want to make the pod complete jobs before shutting down.
+      The worker will not pick new tasks while waiting for shutdown.
     :param max_jobs: maximum number of jobs to run at a time
     :param job_timeout: default job timeout (max run time)
     :param keep_result: default duration to keep job results for
@@ -192,6 +197,7 @@ def __init__(
         on_job_end: Optional['StartupShutdown'] = None,
         after_job_end: Optional['StartupShutdown'] = None,
         handle_signals: bool = True,
+        job_completion_wait: int = 0,
         max_jobs: int = 10,
         job_timeout: 'SecondsTimedelta' = 300,
         keep_result: 'SecondsTimedelta' = 3600,
@@ -263,13 +269,19 @@ def __init__(
         self._last_health_check: float = 0
         self._last_health_check_log: Optional[str] = None
         self._handle_signals = handle_signals
+        self._job_completion_wait = job_completion_wait
         if self._handle_signals:
-            self._add_signal_handler(signal.SIGINT, self.handle_sig)
-            self._add_signal_handler(signal.SIGTERM, self.handle_sig)
+            if self._job_completion_wait:
+                self._add_signal_handler(signal.SIGINT, self.handle_sig_wait_for_completion)
+                self._add_signal_handler(signal.SIGTERM, self.handle_sig_wait_for_completion)
+            else:
+                self._add_signal_handler(signal.SIGINT, self.handle_sig)
+                self._add_signal_handler(signal.SIGTERM, self.handle_sig)
         self.on_stop: Optional[Callable[[Signals], None]] = None
         # whether or not to retry jobs on Retry and CancelledError
         self.retry_jobs = retry_jobs
         self.allow_abort_jobs = allow_abort_jobs
+        self.allow_pick_jobs: bool = True
         self.aborting_tasks: Set[str] = set()
         self.max_burst_jobs = max_burst_jobs
         self.job_serializer = job_serializer
@@ -361,23 +373,23 @@ async def _poll_iteration(self) -> None:
             if burst_jobs_remaining < 1:
                 return
             count = min(burst_jobs_remaining, count)
+        if self.allow_pick_jobs:
+            async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
+                now = timestamp_ms()
+                job_ids = await self.pool.zrangebyscore(
+                    self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
+                )

-        async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
-            now = timestamp_ms()
-            job_ids = await self.pool.zrangebyscore(
-                self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
-            )
-
-        await self.start_jobs(job_ids)
+            await self.start_jobs(job_ids)

-        if self.allow_abort_jobs:
-            await self._cancel_aborted_jobs()
+        if self.allow_abort_jobs:
+            await self._cancel_aborted_jobs()

-        for job_id, t in list(self.tasks.items()):
-            if t.done():
-                del self.tasks[job_id]
-                # required to make sure errors in run_job get propagated
-                t.result()
+        for job_id, t in list(self.tasks.items()):
+            if t.done():
+                del self.tasks[job_id]
+                # required to make sure errors in run_job get propagated
+                t.result()

         await self.heart_beat()

@@ -757,6 +769,55 @@ def _add_signal_handler(self, signum: Signals, handler: Callable[[Signals], None]) -> None:
     def _jobs_started(self) -> int:
         return self.jobs_complete + self.jobs_retried + self.jobs_failed + len(self.tasks)

+    async def _sleep_until_tasks_complete(self) -> None:
+        """
+        Sleeps until all tasks are done. Used together with asyncio.wait_for().
+        """
+        while len(self.tasks):
+            await asyncio.sleep(0.1)
+
+    async def _wait_for_tasks_to_complete(self, signum: Signals) -> None:
+        """
+        Wait for tasks to complete, until `job_completion_wait` has been reached.
+        """
+        with contextlib.suppress(asyncio.TimeoutError):
+            await asyncio.wait_for(
+                self._sleep_until_tasks_complete(),
+                self._job_completion_wait,
+            )
+        logger.info(
+            'shutdown on %s, wait complete ◆ %d jobs complete ◆ %d failed ◆ %d retries ◆ %d ongoing to cancel',
+            signum.name,
+            self.jobs_complete,
+            self.jobs_failed,
+            self.jobs_retried,
+            sum(not t.done() for t in self.tasks.values()),
+        )
+        for t in self.tasks.values():
+            if not t.done():
+                t.cancel()
+        self.main_task and self.main_task.cancel()
+        self.on_stop and self.on_stop(signum)
+
+    def handle_sig_wait_for_completion(self, signum: Signals) -> None:
+        """
+        Alternative signal handler that allows tasks to complete within a given time before shutting down the worker.
+        The time can be configured using `job_completion_wait`.
+        The worker will stop picking new jobs once the signal has been received.
+        """
+        sig = Signals(signum)
+        logger.info('Setting allow_pick_jobs to `False`')
+        self.allow_pick_jobs = False
+        logger.info(
+            'shutdown on %s ◆ %d jobs complete ◆ %d failed ◆ %d retries ◆ %d to be completed',
+            sig.name,
+            self.jobs_complete,
+            self.jobs_failed,
+            self.jobs_retried,
+            len(self.tasks),
+        )
+        self.loop.create_task(self._wait_for_tasks_to_complete(signum=sig))
+
     def handle_sig(self, signum: Signals) -> None:
         sig = Signals(signum)
         logger.info(
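
The new shutdown path boils down to a bounded drain: poll until the worker's task map empties, cap that wait with asyncio.wait_for, then cancel whatever is still running. A standalone sketch of the same pattern in plain asyncio (names here are illustrative, not arq's):

    import asyncio
    import contextlib

    async def drain_then_cancel(tasks, wait):
        async def drain():
            # poll rather than awaiting the tasks directly, mirroring the
            # worker's loop over its mutable self.tasks dict
            while any(not t.done() for t in tasks):
                await asyncio.sleep(0.1)

        # on python3.7, wait_for raises asyncio.TimeoutError rather than the
        # builtin TimeoutError, which is why the commit suppresses that type
        with contextlib.suppress(asyncio.TimeoutError):
            await asyncio.wait_for(drain(), wait)

        # anything that outlived the wait window gets cancelled
        for t in tasks:
            if not t.done():
                t.cancel()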

tests/test_worker.py

+28
@@ -110,6 +110,34 @@ async def test_handle_no_sig(caplog):
     assert worker.tasks[1].cancel.call_count == 1


+async def test_worker_signal_completes_job_before_shutting_down(caplog, arq_redis: ArqRedis, worker):
+    caplog.set_level(logging.INFO)
+
+    async def sleep_job(ctx, time):
+        await asyncio.sleep(time)
+
+    await arq_redis.enqueue_job('sleep_job', 0.2, _job_id='short_sleep')  # should complete within the wait
+    await arq_redis.enqueue_job('sleep_job', 5, _job_id='long_sleep')  # should be cancelled
+    worker = worker(
+        functions=[func(sleep_job, name='sleep_job', max_tries=1)],
+        job_completion_wait=0.4,
+        job_timeout=10,
+    )
+    assert worker.jobs_complete == 0
+    asyncio.create_task(worker.main())
+    await asyncio.sleep(0.1)
+    worker.handle_sig_wait_for_completion(signal.SIGINT)
+    assert worker.allow_pick_jobs is False
+    await asyncio.sleep(0.5)
+    logs = [rec.message for rec in caplog.records]
+    assert 'shutdown on SIGINT ◆ 0 jobs complete ◆ 0 failed ◆ 0 retries ◆ 2 to be completed' in logs
+    assert 'shutdown on SIGINT, wait complete ◆ 1 jobs complete ◆ 0 failed ◆ 0 retries ◆ 1 ongoing to cancel' in logs
+    assert 'long_sleep:sleep_job cancelled, will be run again' in logs[-1]
+    assert worker.jobs_complete == 1
+    assert worker.jobs_retried == 1
+    assert worker.jobs_failed == 0
+
+
 async def test_job_successful(arq_redis: ArqRedis, worker, caplog):
     caplog.set_level(logging.INFO)
     await arq_redis.enqueue_job('foobar', _job_id='testing')
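
As the new docstring notes, job_completion_wait is meant to be paired with kubernetes' terminationGracePeriodSeconds. A rule of thumb assumed here, not enforced anywhere in this commit: give the pod a grace period that outlives the drain window plus shutdown overhead, otherwise kubernetes will SIGKILL the worker while jobs are still completing.

    # hypothetical helper, not part of arq
    def grace_period_seconds(job_completion_wait: float, overhead: float = 10.0) -> int:
        # terminationGracePeriodSeconds must exceed the worker's drain window
        return int(job_completion_wait + overhead)

    assert grace_period_seconds(30) == 40  # e.g. terminationGracePeriodSeconds: 40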
