Skip to content
This repository was archived by the owner on Jun 19, 2025. It is now read-only.

Commit 99890bb

Browse files
authored
fix(cancellation): TT-168 add graceful cancellation (#227)
* TT-168 add graceful cancellation * adding psutil * addressing comments * adding docstring
1 parent d8cdd2f commit 99890bb

File tree

3 files changed

+62
-17
lines changed

3 files changed

+62
-17
lines changed

aqueductcore/backend/services/task_executor.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Celery task execution."""
22

33
import os
4+
import signal
45
import subprocess
56
from asyncio import sleep
67
from datetime import datetime, timezone
78
from pathlib import Path
89
from typing import List, Optional, Tuple
910
from uuid import UUID
1011

12+
import psutil
1113
from celery import Celery
1214
from celery.backends.base import TaskRevokedError
1315
from celery.result import AsyncResult
@@ -19,7 +21,8 @@
1921
from aqueductcore.backend.context import UserInfo
2022
from aqueductcore.backend.errors import AQDDBTaskNonExisting, AQDPermission
2123
from aqueductcore.backend.models import orm
22-
from aqueductcore.backend.models.task import TaskProcessExecutionResult, TaskRead
24+
from aqueductcore.backend.models.task import (TaskProcessExecutionResult,
25+
TaskRead)
2326
from aqueductcore.backend.services.utils import task_orm_to_model
2427
from aqueductcore.backend.settings import settings
2528

@@ -37,11 +40,21 @@
3740

3841

3942
celery_app.conf.update(result_extended=True)
43+
extension_process = None # pylint: disable=invalid-name
44+
45+
def worker_signal_handler(signo, _):
    """Forward a received signal to the extension process and its descendants.

    The signal is delivered to every descendant of the extension process
    first and to the extension process itself last. Processes that have
    already exited are skipped, so one vanished process cannot stop the
    signal from reaching the rest of the tree.

    Args:
        signo: Number of the signal being handled; the same signal is forwarded.
        _: Current stack frame supplied by the ``signal`` module (unused).
    """
    # Read-only access to the module-level handle; no ``global`` needed.
    proc = extension_process
    if proc is None:
        return

    try:
        descendants = psutil.Process(proc.pid).children(recursive=True)
    except psutil.Error:
        # The extension process is gone (or cannot be inspected); there is
        # nothing left to enumerate, but still fall through to the final
        # send_signal below.
        descendants = []

    for descendant in descendants:
        try:
            descendant.send_signal(signo)
        except psutil.Error:
            # The descendant exited between listing and signalling (or
            # signalling it was denied); keep going so the others are
            # still notified.
            continue

    proc.send_signal(signo)
4053

4154

4255
@celery_app.task(bind=True)
43-
def run_executable(
44-
self, # pylint: disable=unused-argument
56+
def run_executable( # pylint: disable=unused-argument
57+
self,
4558
extension_directory_name: str,
4659
shell_script: str,
4760
**kwargs,
@@ -59,6 +72,8 @@ def run_executable(
5972
Returns:
6073
Tuple[int, str, str]: result code, std out, std error.
6174
"""
75+
global extension_process # pylint: disable=global-statement
76+
signal.signal(signal.SIGINT, worker_signal_handler)
6277
extensions_dir = os.environ.get("EXTENSIONS_DIR_PATH", "")
6378
if not extensions_dir:
6479
raise FileNotFoundError("EXTENSIONS_DIR_PATH environment variable should be set.")
@@ -78,6 +93,7 @@ def run_executable(
7893
env=myenv,
7994
cwd=workdir,
8095
) as proc:
96+
extension_process = proc
8197
out, err = proc.communicate(timeout=None)
8298
code = proc.returncode
8399
return (
@@ -103,15 +119,18 @@ async def _update_task_info(task_id: str, wait=False) -> TaskProcessExecutionRes
103119
task_result = task.result
104120

105121
if task_result is not None:
106-
known_errors = (FileNotFoundError, TaskRevokedError)
107-
if isinstance(task_result, known_errors):
108-
err = str(task_result)
109-
task_info.std_err = err
122+
known_exceptions = (FileNotFoundError, KeyboardInterrupt, TaskRevokedError, Exception)
123+
if isinstance(task_result, known_exceptions):
124+
task_info.std_err = str(task_result)
110125
elif task.ready():
111-
code, out, err = task_result
112-
task_info.result_code = code
113-
task_info.std_out = out
114-
task_info.std_err = err
126+
# in case the result format is incorrect
127+
if len(task_result) == 3:
128+
code, out, err = task_result
129+
task_info.result_code = code
130+
task_info.std_out = out
131+
task_info.std_err = err
132+
else:
133+
task_info.std_err = str(task_result)
115134
task_info.ended_at = task.date_done
116135
task_info.kwargs = task.kwargs
117136

@@ -171,11 +190,7 @@ async def revoke_task(
171190
if not user_info.can_cancel_task_owned_by(task_user):
172191
raise AQDPermission("User has no permission to cancel tasks of this user.")
173192

174-
# note: SIGINT does not lead to task abort. If you send
175-
# KeyboardInterrupt (SIGINT), it will not stop, and the
176-
# exception does not propagate.
177-
AsyncResult(db_task.task_id).revoke(terminate=terminate, signal="SIGTERM")
178-
193+
AsyncResult(db_task.task_id).revoke(terminate=terminate, signal="SIGINT")
179194
task_info = await _update_task_info(task_id=db_task.task_id, wait=False)
180195

181196
username = db_task.created_by_user.username

poetry.lock

Lines changed: 30 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ typer = "^0.12"
3333
psycopg2-binary = "^2.9"
3434
celery = "^5.4"
3535
flower = "^2.0"
36+
psutil = "^6.0.0"
3637

3738
[tool.poetry.scripts]
3839
aqueduct = "aqueductcore.cli:app"

0 commit comments

Comments
 (0)