|
1 | 1 | import asyncio
|
| 2 | +import contextlib |
2 | 3 | import inspect
|
3 | 4 | import logging
|
4 | 5 | import signal
|
@@ -153,6 +154,10 @@ class Worker:
|
153 | 154 | :param after_job_end: coroutine function to run after job has ended and results have been recorded
|
154 | 155 | :param handle_signals: default true, register signal handlers,
|
155 | 156 | set to false when running inside other async framework
|
| 157 | + :param job_completion_wait: time to wait before cancelling tasks after a signal. |
| 158 | + Useful together with ``terminationGracePeriodSeconds`` in kubernetes, |
| 159 | + when you want to make the pod complete jobs before shutting down. |
| 160 | + The worker will not pick new tasks while waiting for shut down. |
156 | 161 | :param max_jobs: maximum number of jobs to run at a time
|
157 | 162 | :param job_timeout: default job timeout (max run time)
|
158 | 163 | :param keep_result: default duration to keep job results for
|
@@ -192,6 +197,7 @@ def __init__(
|
192 | 197 | on_job_end: Optional['StartupShutdown'] = None,
|
193 | 198 | after_job_end: Optional['StartupShutdown'] = None,
|
194 | 199 | handle_signals: bool = True,
|
| 200 | + job_completion_wait: int = 0, |
195 | 201 | max_jobs: int = 10,
|
196 | 202 | job_timeout: 'SecondsTimedelta' = 300,
|
197 | 203 | keep_result: 'SecondsTimedelta' = 3600,
|
@@ -263,13 +269,19 @@ def __init__(
|
263 | 269 | self._last_health_check: float = 0
|
264 | 270 | self._last_health_check_log: Optional[str] = None
|
265 | 271 | self._handle_signals = handle_signals
|
| 272 | + self._job_completion_wait = job_completion_wait |
266 | 273 | if self._handle_signals:
|
267 |
| - self._add_signal_handler(signal.SIGINT, self.handle_sig) |
268 |
| - self._add_signal_handler(signal.SIGTERM, self.handle_sig) |
| 274 | + if self._job_completion_wait: |
| 275 | + self._add_signal_handler(signal.SIGINT, self.handle_sig_wait_for_completion) |
| 276 | + self._add_signal_handler(signal.SIGTERM, self.handle_sig_wait_for_completion) |
| 277 | + else: |
| 278 | + self._add_signal_handler(signal.SIGINT, self.handle_sig) |
| 279 | + self._add_signal_handler(signal.SIGTERM, self.handle_sig) |
269 | 280 | self.on_stop: Optional[Callable[[Signals], None]] = None
|
270 | 281 | # whether or not to retry jobs on Retry and CancelledError
|
271 | 282 | self.retry_jobs = retry_jobs
|
272 | 283 | self.allow_abort_jobs = allow_abort_jobs
|
| 284 | + self.allow_pick_jobs: bool = True |
273 | 285 | self.aborting_tasks: Set[str] = set()
|
274 | 286 | self.max_burst_jobs = max_burst_jobs
|
275 | 287 | self.job_serializer = job_serializer
|
@@ -361,23 +373,23 @@ async def _poll_iteration(self) -> None:
|
361 | 373 | if burst_jobs_remaining < 1:
|
362 | 374 | return
|
363 | 375 | count = min(burst_jobs_remaining, count)
|
| 376 | + if self.allow_pick_jobs: |
| 377 | + async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs |
| 378 | + now = timestamp_ms() |
| 379 | + job_ids = await self.pool.zrangebyscore( |
| 380 | + self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now |
| 381 | + ) |
364 | 382 |
|
365 |
| - async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs |
366 |
| - now = timestamp_ms() |
367 |
| - job_ids = await self.pool.zrangebyscore( |
368 |
| - self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now |
369 |
| - ) |
370 |
| - |
371 |
| - await self.start_jobs(job_ids) |
| 383 | + await self.start_jobs(job_ids) |
372 | 384 |
|
373 |
| - if self.allow_abort_jobs: |
374 |
| - await self._cancel_aborted_jobs() |
| 385 | + if self.allow_abort_jobs: |
| 386 | + await self._cancel_aborted_jobs() |
375 | 387 |
|
376 |
| - for job_id, t in list(self.tasks.items()): |
377 |
| - if t.done(): |
378 |
| - del self.tasks[job_id] |
379 |
| - # required to make sure errors in run_job get propagated |
380 |
| - t.result() |
| 388 | + for job_id, t in list(self.tasks.items()): |
| 389 | + if t.done(): |
| 390 | + del self.tasks[job_id] |
| 391 | + # required to make sure errors in run_job get propagated |
| 392 | + t.result() |
381 | 393 |
|
382 | 394 | await self.heart_beat()
|
383 | 395 |
|
@@ -757,6 +769,55 @@ def _add_signal_handler(self, signum: Signals, handler: Callable[[Signals], None
|
757 | 769 | def _jobs_started(self) -> int:
|
758 | 770 | return self.jobs_complete + self.jobs_retried + self.jobs_failed + len(self.tasks)
|
759 | 771 |
|
| 772 | + async def _sleep_until_tasks_complete(self) -> None: |
| 773 | + """ |
| 774 | + Sleeps until all tasks are done. Used together with asyncio.wait_for() |
| 775 | + """ |
| 776 | + while len(self.tasks): |
| 777 | + await asyncio.sleep(0.1) |
| 778 | + |
| 779 | + async def _wait_for_tasks_to_complete(self, signum: Signals) -> None: |
| 780 | + """ |
| 781 | + Wait for tasks to complete, until `wait_for_job_completion_on_signal_second` has been reached. |
| 782 | + """ |
| 783 | + with contextlib.suppress(asyncio.TimeoutError): |
| 784 | + await asyncio.wait_for( |
| 785 | + self._sleep_until_tasks_complete(), |
| 786 | + self._job_completion_wait, |
| 787 | + ) |
| 788 | + logger.info( |
| 789 | + 'shutdown on %s, wait complete ◆ %d jobs complete ◆ %d failed ◆ %d retries ◆ %d ongoing to cancel', |
| 790 | + signum.name, |
| 791 | + self.jobs_complete, |
| 792 | + self.jobs_failed, |
| 793 | + self.jobs_retried, |
| 794 | + sum(not t.done() for t in self.tasks.values()), |
| 795 | + ) |
| 796 | + for t in self.tasks.values(): |
| 797 | + if not t.done(): |
| 798 | + t.cancel() |
| 799 | + self.main_task and self.main_task.cancel() |
| 800 | + self.on_stop and self.on_stop(signum) |
| 801 | + |
| 802 | + def handle_sig_wait_for_completion(self, signum: Signals) -> None: |
| 803 | + """ |
| 804 | + Alternative signal handler that allow tasks to complete within a given time before shutting down the worker. |
| 805 | + Time can be configured using `wait_for_job_completion_on_signal_second`. |
| 806 | + The worker will stop picking jobs when signal has been received. |
| 807 | + """ |
| 808 | + sig = Signals(signum) |
| 809 | + logger.info('Setting allow_pick_jobs to `False`') |
| 810 | + self.allow_pick_jobs = False |
| 811 | + logger.info( |
| 812 | + 'shutdown on %s ◆ %d jobs complete ◆ %d failed ◆ %d retries ◆ %d to be completed', |
| 813 | + sig.name, |
| 814 | + self.jobs_complete, |
| 815 | + self.jobs_failed, |
| 816 | + self.jobs_retried, |
| 817 | + len(self.tasks), |
| 818 | + ) |
| 819 | + self.loop.create_task(self._wait_for_tasks_to_complete(signum=sig)) |
| 820 | + |
760 | 821 | def handle_sig(self, signum: Signals) -> None:
|
761 | 822 | sig = Signals(signum)
|
762 | 823 | logger.info(
|
|
0 commit comments