|
1 | 1 | import asyncio |
2 | 2 | import base64 |
| 3 | +import functools |
3 | 4 | import logging |
4 | 5 | import os |
5 | 6 | import signal |
|
8 | 9 | import time |
9 | 10 | from datetime import datetime, timedelta, timezone |
10 | 11 | from types import TracebackType |
11 | | -from typing import Any, Coroutine, Mapping, Protocol, cast |
| 12 | +from typing import Any, Callable, Coroutine, Mapping, Protocol, TypeVar, cast |
12 | 13 |
|
13 | 14 | import cloudpickle # type: ignore[import] |
14 | 15 |
|
@@ -78,6 +79,40 @@ def __init__(self, execution: Execution): |
78 | 79 | logger: logging.Logger = logging.getLogger(__name__) |
79 | 80 | tracer: Tracer = trace.get_tracer(__name__) |
80 | 81 |
|
| 82 | +F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, None]]) |
| 83 | + |
| 84 | + |
| 85 | +def _with_sigterm_handler(func: F) -> F: # pragma: no cover |
| 86 | + """Decorator that installs a SIGTERM handler for graceful shutdown. |
| 87 | +
|
| 88 | + On SIGTERM, cancels the wrapped coroutine, allowing in-flight tasks to |
| 89 | + complete before the worker exits. Tested via subprocess in |
| 90 | + test_sigterm_gracefully_drains_inflight_tasks. |
| 91 | + """ |
| 92 | + if not hasattr(signal, "SIGTERM"): |
| 93 | + return func |
| 94 | + |
| 95 | + @functools.wraps(func) |
| 96 | + async def wrapper(*args: Any, **kwargs: Any) -> None: |
| 97 | + loop = asyncio.get_running_loop() |
| 98 | + task: asyncio.Task[None] | None = None |
| 99 | + |
| 100 | + def handle_sigterm() -> None: |
| 101 | + logger.info("Received SIGTERM, initiating graceful shutdown...") |
| 102 | + if task and not task.done(): |
| 103 | + task.cancel() |
| 104 | + |
| 105 | + loop.add_signal_handler(signal.SIGTERM, handle_sigterm) |
| 106 | + try: |
| 107 | + task = asyncio.create_task(func(*args, **kwargs)) |
| 108 | + await task |
| 109 | + except asyncio.CancelledError: |
| 110 | + pass # Expected from signal handler |
| 111 | + finally: |
| 112 | + loop.remove_signal_handler(signal.SIGTERM) |
| 113 | + |
| 114 | + return cast(F, wrapper) |
| 115 | + |
81 | 116 |
|
82 | 117 | class _stream_due_tasks(Protocol): |
83 | 118 | async def __call__( |
@@ -199,35 +234,21 @@ async def run( |
199 | 234 | ): |
200 | 235 | if until_finished: |
201 | 236 | await worker.run_until_finished() |
202 | | - else: # pragma: no cover |
203 | | - loop = asyncio.get_running_loop() |
204 | | - current_task = asyncio.current_task() |
205 | | - |
206 | | - def handle_sigterm() -> None: |
207 | | - logger.info( |
208 | | - "Received SIGTERM, initiating graceful shutdown..." |
209 | | - ) |
210 | | - if current_task and not current_task.done(): |
211 | | - current_task.cancel() |
212 | | - |
213 | | - if hasattr(signal, "SIGTERM"): |
214 | | - loop.add_signal_handler(signal.SIGTERM, handle_sigterm) |
215 | | - |
216 | | - try: |
217 | | - await worker.run_forever() |
218 | | - except asyncio.CancelledError: |
219 | | - pass # Expected from signal handler |
220 | | - finally: |
221 | | - if hasattr(signal, "SIGTERM"): |
222 | | - loop.remove_signal_handler(signal.SIGTERM) |
| 237 | + else: # pragma: no cover - tested via subprocess |
| 238 | + await worker.run_forever() |
223 | 239 |
|
224 | 240 | async def run_until_finished(self) -> None: |
225 | 241 | """Run the worker until there are no more tasks to process.""" |
226 | 242 | return await self._run(forever=False) |
227 | 243 |
|
228 | | - async def run_forever(self) -> None: |
229 | | - """Run the worker indefinitely.""" |
230 | | - return await self._run(forever=True) # pragma: no cover |
| 244 | + @_with_sigterm_handler |
| 245 | + async def run_forever(self) -> None: # pragma: no cover - tested via subprocess |
| 246 | + """Run the worker indefinitely. |
| 247 | +
|
| 248 | + Installs a SIGTERM handler that initiates graceful shutdown, allowing |
| 249 | + in-flight tasks to complete before the worker exits. |
| 250 | + """ |
| 251 | + await self._run(forever=True) |
231 | 252 |
|
232 | 253 | _execution_counts: dict[str, int] |
233 | 254 |
|
|
0 commit comments