Skip to content

Commit 126edc0

Browse files
hanchchchsanjayprabhugithub-advanced-security[bot]
authored
fix: propagate SIGTERM to agent (#184)
* feat: Allow agent to handle signals for graceful termination * Add tests, and fix a few agent termination edge cases * Fix flaky test * Allow nested event loops * Revert "Allow nested event loops" This reverts commit 42d04ec. * fix: log traceback if raised * fix: handle coroutine function * fix: test * fix: test * fix: run in a thread pool if _run_as_main_thread not set * fix: add signal handler * fix: redundant signal handler * fix: terminate process before closing connection on abort * fix: log sigterm * fix: duplicated PROCESS_SHUTDOWN_TIMEOUT_SECONDS * fix: set _run_as_main_thread on test * Potential fix for code scanning alert no. 4: Binding a socket to all network interfaces Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix: logs test ignore warning * docs: add comments * refactor: rename run_as to run_on * fix: add tests --------- Co-authored-by: Sanjay Raveendran <sanjayprabhu@gmail.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 02d1b8d commit 126edc0

File tree

6 files changed

+645
-36
lines changed

6 files changed

+645
-36
lines changed

src/isolate/connections/grpc/_base.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
import socket
3+
import subprocess
24
from contextlib import contextmanager
35
from dataclasses import dataclass
46
from pathlib import Path
@@ -23,6 +25,11 @@ class AgentError(Exception):
2325
"""An internal problem caused by (most probably) the agent."""
2426

2527

28+
PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float(
29+
os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60")
30+
)
31+
32+
2633
@dataclass
2734
class GRPCExecutionBase(EnvironmentConnection):
2835
"""A customizable gRPC-based execution backend."""
@@ -32,6 +39,9 @@ def start_agent(self) -> ContextManager[Tuple[str, grpc.ChannelCredentials]]:
3239
the required credentials to connect to it."""
3340
raise NotImplementedError
3441

42+
def abort_agent(self) -> None:
43+
raise NotImplementedError
44+
3545
@contextmanager
3646
def _establish_bridge(
3747
self,
@@ -113,6 +123,8 @@ def run(
113123

114124

115125
class LocalPythonGRPC(PythonExecutionBase[str], GRPCExecutionBase):
126+
_process: Union[None, subprocess.Popen] = None
127+
116128
@contextmanager
117129
def start_agent(self) -> Iterator[Tuple[str, grpc.ChannelCredentials]]:
118130
def find_free_port() -> Tuple[str, int]:
@@ -123,14 +135,24 @@ def find_free_port() -> Tuple[str, int]:
123135

124136
host, port = find_free_port()
125137
address = f"{host}:{port}"
126-
process = None
138+
self._process = None
127139
try:
128140
with self.start_process(address) as process:
141+
self._process = process
129142
yield address, grpc.local_channel_credentials()
130143
finally:
131-
if process is not None:
132-
# TODO: should we check the status code here?
133-
process.terminate()
144+
self.abort_agent()
145+
146+
def abort_agent(self) -> None:
147+
if self._process is not None:
148+
try:
149+
print("Terminating the agent process...")
150+
self._process.terminate()
151+
self._process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
152+
print("Agent process shutdown gracefully")
153+
except Exception as exc:
154+
print(f"Failed to shutdown the agent process gracefully: {exc}")
155+
self._process.kill()
134156

135157
def get_python_cmd(
136158
self,

src/isolate/connections/grpc/agent.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@
1010

1111
from __future__ import annotations
1212

13+
import asyncio
1314
import os
15+
import signal
1416
import sys
1517
import traceback
1618
from argparse import ArgumentParser
1719
from concurrent import futures
1820
from dataclasses import dataclass
1921
from typing import (
2022
Any,
23+
AsyncIterator,
2124
Iterable,
22-
Iterator,
2325
)
2426

25-
import grpc
26-
from grpc import ServicerContext, StatusCode
27+
from grpc import StatusCode, aio, local_server_credentials
28+
29+
from isolate.connections.grpc.definitions import PartialRunResult
2730

2831
try:
2932
from isolate import __version__ as agent_version
@@ -48,12 +51,19 @@ def __init__(self, log_fd: int | None = None):
4851

4952
self._run_cache: dict[str, Any] = {}
5053
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")
54+
self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)
55+
56+
def handle_termination(*args):
57+
self.log("Termination signal received, shutting down...")
58+
signal.raise_signal(signal.SIGTERM)
5159

52-
def Run(
60+
signal.signal(signal.SIGINT, handle_termination)
61+
62+
async def Run(
5363
self,
5464
request: definitions.FunctionCall,
55-
context: ServicerContext,
56-
) -> Iterator[definitions.PartialRunResult]:
65+
context: aio.ServicerContext,
66+
) -> AsyncIterator[PartialRunResult]:
5767
self.log(f"A connection has been established: {context.peer()}!")
5868
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
5969
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
@@ -70,7 +80,7 @@ def Run(
7080
result,
7181
was_it_raised,
7282
stringized_tb,
73-
) = self.execute_function(
83+
) = await self.execute_function(
7484
request.setup_func,
7585
"setup",
7686
)
@@ -87,15 +97,16 @@ def Run(
8797
)
8898
raise AbortException("The setup function has thrown an error.")
8999
except AbortException as exc:
90-
return self.abort_with_msg(context, exc.message)
100+
self.abort_with_msg(context, exc.message)
101+
return
91102
else:
92103
assert not was_it_raised
93104
self._run_cache[cache_key] = result
94105

95106
extra_args.append(self._run_cache[cache_key])
96107

97108
try:
98-
result, was_it_raised, stringized_tb = self.execute_function(
109+
result, was_it_raised, stringized_tb = await self.execute_function(
99110
request.function,
100111
"function",
101112
extra_args=extra_args,
@@ -107,9 +118,10 @@ def Run(
107118
stringized_tb,
108119
)
109120
except AbortException as exc:
110-
return self.abort_with_msg(context, exc.message)
121+
self.abort_with_msg(context, exc.message)
122+
return
111123

112-
def execute_function(
124+
async def execute_function(
113125
self,
114126
function: definitions.SerializedObject,
115127
function_kind: str,
@@ -143,14 +155,34 @@ def execute_function(
143155
was_it_raised = False
144156
stringized_tb = None
145157
try:
146-
result = function(*extra_args)
158+
# Newer fal SDK will mark async entrypoints with `_run_on_main_thread` so
159+
# we execute on the main loop and can await the coroutine they return.
160+
# Older fal SDK still call `asyncio.run(...)`.
161+
# To avoid error "asyncio.run() cannot be called from a running event loop"
162+
# and be backward compatible,
163+
# we offload those unflagged functions to a thread pool.
164+
165+
if getattr(function, "_run_on_main_thread", False):
166+
result = function(*extra_args)
167+
else:
168+
result = self._thread_pool.submit(function, *extra_args).result()
169+
170+
if asyncio.iscoroutine(result):
171+
result = await result
172+
147173
except BaseException as exc:
148174
result = exc
149175
was_it_raised = True
150176
num_frames = len(traceback.extract_stack()[:-5])
151177
stringized_tb = "".join(traceback.format_exc(limit=-num_frames))
152178

153-
self.log(f"Completed the execution of the {function_kind} function.")
179+
if not was_it_raised:
180+
self.log(f"Completed the execution of the {function_kind} function.")
181+
else:
182+
self.log(
183+
f"Completed the execution of the {function_kind} function"
184+
f" with an error: {result}\nTraceback:\n{stringized_tb}"
185+
)
154186
return result, was_it_raised, stringized_tb
155187

156188
def send_object(
@@ -195,7 +227,7 @@ def log(self, message: str) -> None:
195227

196228
def abort_with_msg(
197229
self,
198-
context: ServicerContext,
230+
context: aio.ServicerContext,
199231
message: str,
200232
*,
201233
code: StatusCode = StatusCode.INVALID_ARGUMENT,
@@ -205,23 +237,26 @@ def abort_with_msg(
205237
return None
206238

207239

208-
def create_server(address: str) -> grpc.Server:
240+
def create_server(address: str) -> aio.Server:
209241
"""Create a new (temporary) gRPC server listening on the given
210242
address."""
211-
server = grpc.server(
212-
futures.ThreadPoolExecutor(max_workers=1),
213-
maximum_concurrent_rpcs=1,
243+
# Use asyncio server so requests can run in the main thread and intercept signals
244+
# There seems to be a weird bug with grpcio that makes subsequent requests fail with
245+
# concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2
246+
# fixes it, even though in practice, we only run one request at a time.
247+
server = aio.server(
248+
maximum_concurrent_rpcs=2,
214249
options=get_default_options(),
215250
)
216251

217252
# Local server credentials allow us to ensure that the
218253
# connection is established by a local process.
219-
server_credentials = grpc.local_server_credentials()
254+
server_credentials = local_server_credentials()
220255
server.add_secure_port(address, server_credentials)
221256
return server
222257

223258

224-
def run_agent(address: str, log_fd: int | None = None) -> int:
259+
async def run_agent(address: str, log_fd: int | None = None) -> int:
225260
"""Run the agent servicer on the given address."""
226261
server = create_server(address)
227262
servicer = AgentServicer(log_fd=log_fd)
@@ -231,19 +266,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
231266
# not have any global side effects.
232267
definitions.register_agent(servicer, server)
233268

234-
server.start()
235-
server.wait_for_termination()
269+
await server.start()
270+
await server.wait_for_termination()
236271
return 0
237272

238273

239-
def main() -> int:
274+
async def main() -> int:
240275
parser = ArgumentParser()
241276
parser.add_argument("address", type=str)
242277
parser.add_argument("--log-fd", type=int)
243278

244279
options = parser.parse_args()
245-
return run_agent(options.address, log_fd=options.log_fd)
280+
return await run_agent(options.address, log_fd=options.log_fd)
246281

247282

248283
if __name__ == "__main__":
249-
main()
284+
asyncio.run(main())

0 commit comments

Comments
 (0)