11"""Celery task execution."""
22
33import os
4+ import signal
45import subprocess
56from asyncio import sleep
67from datetime import datetime , timezone
78from pathlib import Path
89from typing import List , Optional , Tuple
910from uuid import UUID
1011
12+ import psutil
1113from celery import Celery
1214from celery .backends .base import TaskRevokedError
1315from celery .result import AsyncResult
1921from aqueductcore .backend .context import UserInfo
2022from aqueductcore .backend .errors import AQDDBTaskNonExisting , AQDPermission
2123from aqueductcore .backend .models import orm
22- from aqueductcore .backend .models .task import TaskProcessExecutionResult , TaskRead
24+ from aqueductcore .backend .models .task import (TaskProcessExecutionResult ,
25+ TaskRead )
2326from aqueductcore .backend .services .utils import task_orm_to_model
2427from aqueductcore .backend .settings import settings
2528
3740
3841
3942celery_app .conf .update (result_extended = True )
43+ extension_process = None # pylint: disable=invalid-name
44+
45+ def worker_signal_handler (signo , _ ):
46+ """ Handle SIGINT signal and propagate it to child and grandchild processes. """
47+ global extension_process # pylint: disable=global-statement,global-variable-not-assigned
48+ if extension_process is not None :
49+ psutil_child_process = psutil .Process (extension_process .pid )
50+ for grand_child in psutil_child_process .children (recursive = True ):
51+ grand_child .send_signal (signo )
52+ extension_process .send_signal (signo )
4053
4154
4255@celery_app .task (bind = True )
43- def run_executable (
44- self , # pylint: disable=unused-argument
56+ def run_executable ( # pylint: disable=unused-argument
57+ self ,
4558 extension_directory_name : str ,
4659 shell_script : str ,
4760 ** kwargs ,
@@ -59,6 +72,8 @@ def run_executable(
5972 Returns:
6073 Tuple[int, str, str]: result code, std out, std error.
6174 """
75+ global extension_process # pylint: disable=global-statement
76+ signal .signal (signal .SIGINT , worker_signal_handler )
6277 extensions_dir = os .environ .get ("EXTENSIONS_DIR_PATH" , "" )
6378 if not extensions_dir :
6479 raise FileNotFoundError ("EXTENSIONS_DIR_PATH environment variable should be set." )
@@ -78,6 +93,7 @@ def run_executable(
7893 env = myenv ,
7994 cwd = workdir ,
8095 ) as proc :
96+ extension_process = proc
8197 out , err = proc .communicate (timeout = None )
8298 code = proc .returncode
8399 return (
@@ -103,15 +119,18 @@ async def _update_task_info(task_id: str, wait=False) -> TaskProcessExecutionRes
103119 task_result = task .result
104120
105121 if task_result is not None :
106- known_errors = (FileNotFoundError , TaskRevokedError )
107- if isinstance (task_result , known_errors ):
108- err = str (task_result )
109- task_info .std_err = err
122+ known_exceptions = (FileNotFoundError , KeyboardInterrupt , TaskRevokedError , Exception )
123+ if isinstance (task_result , known_exceptions ):
124+ task_info .std_err = str (task_result )
110125 elif task .ready ():
111- code , out , err = task_result
112- task_info .result_code = code
113- task_info .std_out = out
114- task_info .std_err = err
126+ # in case the result format is incorrect
127+ if len (task_result ) == 3 :
128+ code , out , err = task_result
129+ task_info .result_code = code
130+ task_info .std_out = out
131+ task_info .std_err = err
132+ else :
133+ task_info .std_err = str (task_result )
115134 task_info .ended_at = task .date_done
116135 task_info .kwargs = task .kwargs
117136
@@ -171,11 +190,7 @@ async def revoke_task(
171190 if not user_info .can_cancel_task_owned_by (task_user ):
172191 raise AQDPermission ("User has no permission to cancel tasks of this user." )
173192
174- # note: SIGINT does not lead to task abort. If you send
175- # KeyboardInterupt (SIGINT), it will not stop, and the
176- # exception does not propagate.
177- AsyncResult (db_task .task_id ).revoke (terminate = terminate , signal = "SIGTERM" )
178-
193+ AsyncResult (db_task .task_id ).revoke (terminate = terminate , signal = "SIGINT" )
179194 task_info = await _update_task_info (task_id = db_task .task_id , wait = False )
180195
181196 username = db_task .created_by_user .username
0 commit comments