Skip to content

Commit c0a04a8

Browse files
authored
feat: add agent idle timeout (#189)
* feat: add agent idle timeout * feat: handle SIGCHLD in parent server * fix: use events and default idle_timeout 0 * fix: add more tests * fix: prevent blocking event loop * fix: test * fix: test fd * fix: flaky test_server_single_use_submit
1 parent c43d3c2 commit c0a04a8

File tree

5 files changed

+254
-13
lines changed

5 files changed

+254
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ test = [
4848
"cloudpickle>=2.2.0",
4949
"dill>=0.3.5.1",
5050
"flaky",
51+
"psutil>=7.2.1",
5152
]
5253
dev = [
5354
"isolate[test]",

src/isolate/connections/grpc/_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ def abort_agent(self) -> None:
153153
except Exception as exc:
154154
print(f"Failed to shutdown the agent process gracefully: {exc}")
155155
self._process.kill()
156+
self._process = None
157+
158+
def is_alive(self) -> bool:
159+
if self._process is None:
160+
return False
161+
return self._process.poll() is None
156162

157163
def get_python_cmd(
158164
self,

src/isolate/connections/grpc/agent.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from isolate.connections.grpc.configuration import get_default_options
4040
from isolate.connections.grpc.interface import from_grpc
4141

42+
IDLE_TIMEOUT_SECONDS = int(os.getenv("ISOLATE_AGENT_IDLE_TIMEOUT_SECONDS", "0"))
43+
4244

4345
@dataclass
4446
class AbortException(Exception):
@@ -52,17 +54,64 @@ def __init__(self, log_fd: int | None = None):
5254
self._run_cache: dict[str, Any] = {}
5355
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")
5456
self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)
57+
self._idle_timeout_seconds = IDLE_TIMEOUT_SECONDS
58+
self._is_running = asyncio.Event()
59+
self._is_idle = asyncio.Event()
60+
self._is_idle.set()
5561

56-
def handle_termination(*args):
57-
self.log("Termination signal received, shutting down...")
62+
def handle_sigint(*args):
63+
self.log("SIGINT signal received, shutting down...")
5864
signal.raise_signal(signal.SIGTERM)
5965

60-
signal.signal(signal.SIGINT, handle_termination)
66+
signal.signal(signal.SIGINT, handle_sigint)
67+
68+
async def wait_for_idle_timeout(self) -> None:
69+
while True:
70+
# print(f"Hello, world! {self._idle_timeout_seconds}")
71+
# wait for the agent to be idle
72+
await self._is_idle.wait()
73+
74+
# idle timeout disabled
75+
if self._idle_timeout_seconds <= 0:
76+
# prevent blocking the event loop
77+
await asyncio.sleep(0.1)
78+
continue
79+
80+
try:
81+
# wait for the agent to be running
82+
await asyncio.wait_for(
83+
self._is_running.wait(), timeout=self._idle_timeout_seconds
84+
)
85+
except asyncio.TimeoutError:
86+
self.log(
87+
f"Idle timeout {self._idle_timeout_seconds} seconds exceeded, shutting down..." # noqa: E501
88+
)
89+
# This kills the agent itself, however it will remain as a zombie state
90+
# unless the parent process (server) properly handles the SIGCHLD.
91+
signal.raise_signal(signal.SIGTERM)
92+
break
93+
except asyncio.CancelledError:
94+
# Cancelled when the server is shutting down
95+
break
6196

6297
async def Run(
6398
self,
6499
request: definitions.FunctionCall,
65100
context: aio.ServicerContext,
101+
) -> AsyncIterator[PartialRunResult]:
102+
self._is_idle.clear()
103+
self._is_running.set()
104+
try:
105+
async for result in self._Run(request, context):
106+
yield result
107+
finally:
108+
self._is_running.clear()
109+
self._is_idle.set()
110+
111+
async def _Run(
112+
self,
113+
request: definitions.FunctionCall,
114+
context: aio.ServicerContext,
66115
) -> AsyncIterator[PartialRunResult]:
67116
self.log(f"A connection has been established: {context.peer()}!")
68117
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
@@ -267,7 +316,18 @@ async def run_agent(address: str, log_fd: int | None = None) -> int:
267316
definitions.register_agent(servicer, server)
268317

269318
await server.start()
270-
await server.wait_for_termination()
319+
320+
_, pending = await asyncio.wait(
321+
[
322+
asyncio.create_task(server.wait_for_termination()),
323+
asyncio.create_task(servicer.wait_for_idle_timeout()),
324+
],
325+
return_when=asyncio.FIRST_COMPLETED,
326+
)
327+
for task in pending:
328+
print(f"Cancelling task: {task}")
329+
task.cancel()
330+
271331
return 0
272332

273333

src/isolate/server/server.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,14 @@ def __exit__(self, *exc_info: Any) -> None:
182182
for agent in agents:
183183
agent.terminate()
184184

185+
def abort_unreachable_agents(self) -> None:
186+
for agents in self._agents.values():
187+
for agent in agents:
188+
connection = agent._connection
189+
if connection is not None and not connection.is_alive():
190+
connection.abort_agent()
191+
# maybe restart the agent?
192+
185193

186194
@dataclass
187195
class RunTask:
@@ -650,7 +658,7 @@ def termination() -> None:
650658

651659
def _stop(*args):
652660
# Small sleep to make sure the cancellation is processed
653-
time.sleep(0.1)
661+
time.sleep(0.3)
654662
print("Stopping server since the task is finished")
655663
self.servicer.shutdown()
656664
self.server.stop(grace=0.1)
@@ -720,8 +728,13 @@ def handle_termination(*args):
720728
servicer.shutdown()
721729
server.stop(grace=0.1)
722730

731+
def handle_child_termination(*args):
732+
print("Child termination signal received, aborting unreachable agents...")
733+
bridge_manager.abort_unreachable_agents()
734+
723735
signal.signal(signal.SIGINT, handle_termination)
724736
signal.signal(signal.SIGTERM, handle_termination)
737+
signal.signal(signal.SIGCHLD, handle_child_termination)
725738

726739
server.add_insecure_port(f"[::]:{options.port}")
727740
print(f"Started listening at {options.host}:{options.port}")

tests/test_shutdown.py

Lines changed: 169 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
import sys
88
import threading
99
import time
10+
from typing import Iterator, Tuple
1011
from unittest.mock import Mock
1112

1213
import grpc
14+
import psutil
1315
import pytest
16+
from isolate.connections.grpc.definitions import FunctionCall
17+
from isolate.connections.grpc.definitions.agent_pb2_grpc import AgentStub
1418
from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition
1519
from isolate.server.definitions.server_pb2_grpc import IsolateStub
1620
from isolate.server.interface import to_serialized_object
@@ -44,13 +48,23 @@ def servicer():
4448

4549

4650
@pytest.fixture
47-
def isolate_server_subprocess(monkeypatch):
51+
def single_use():
52+
return True
53+
54+
55+
@pytest.fixture
56+
def idle_timeout_seconds():
57+
return 0
58+
59+
60+
@pytest.fixture
61+
def isolate_server_subprocess(
62+
single_use: bool, idle_timeout_seconds: int
63+
) -> Iterator[Tuple[subprocess.Popen, int]]:
4864
"""Set up a gRPC server with the IsolateServicer for testing."""
4965
# Find a free port
5066
import socket
5167

52-
monkeypatch.setenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "2")
53-
5468
# Bind only to the loopback interface to avoid exposing the socket on all interfaces
5569
with socket.socket() as s:
5670
s.bind(("127.0.0.1", 0))
@@ -61,10 +75,14 @@ def isolate_server_subprocess(monkeypatch):
6175
sys.executable,
6276
"-m",
6377
"isolate.server.server",
64-
"--single-use",
78+
*(["--single-use"] if single_use else []),
6579
"--port",
6680
str(port),
67-
]
81+
],
82+
env={
83+
"ISOLATE_SHUTDOWN_GRACE_PERIOD": "2",
84+
"ISOLATE_AGENT_IDLE_TIMEOUT_SECONDS": str(idle_timeout_seconds),
85+
},
6886
)
6987

7088
time.sleep(5) # Wait for server to start
@@ -76,7 +94,50 @@ def isolate_server_subprocess(monkeypatch):
7694
process.wait(timeout=10)
7795

7896

79-
def consume_responses(responses):
97+
@pytest.fixture
98+
def isolate_agent_subprocess(
99+
idle_timeout_seconds: int,
100+
) -> Iterator[Tuple[subprocess.Popen, int]]:
101+
"""Set up a gRPC server with the IsolateServicer for testing."""
102+
# Find a free port
103+
import socket
104+
105+
# Bind only to the loopback interface to avoid exposing the socket on all interfaces
106+
with socket.socket() as s:
107+
s.bind(("127.0.0.1", 0))
108+
port = s.getsockname()[1]
109+
110+
# Use /dev/null for log output since pytest may capture stdout
111+
# (making fileno() fail with "Bad file descriptor")
112+
log_file = open(os.devnull, "w")
113+
114+
process = subprocess.Popen(
115+
[
116+
sys.executable,
117+
"-m",
118+
"isolate.connections.grpc.agent",
119+
f"localhost:{port}",
120+
"--log-fd",
121+
str(log_file.fileno()),
122+
],
123+
env={
124+
"ISOLATE_AGENT_IDLE_TIMEOUT_SECONDS": str(idle_timeout_seconds),
125+
},
126+
pass_fds=(log_file.fileno(),),
127+
)
128+
129+
time.sleep(1) # Wait for server to start
130+
try:
131+
yield process, port
132+
finally:
133+
# Cleanup
134+
if process.poll() is None:
135+
process.terminate()
136+
process.wait(timeout=10)
137+
log_file.close()
138+
139+
140+
def consume_responses(responses: Iterator, wait: bool = False) -> None:
80141
def _consume():
81142
try:
82143
for response in responses:
@@ -87,6 +148,8 @@ def _consume():
87148

88149
response_thread = threading.Thread(target=_consume, daemon=True)
89150
response_thread.start()
151+
if wait:
152+
response_thread.join()
90153

91154

92155
def test_shutdown_with_terminate(servicer):
@@ -180,5 +243,103 @@ def handle_term(signum, frame):
180243
), "Function should have received SIGTERM and created the file"
181244

182245

183-
if __name__ == "__main__":
184-
pytest.main([__file__, "-v"])
246+
@pytest.mark.parametrize(
247+
"idle_timeout_seconds",
248+
[0, 2],
249+
)
250+
def test_idle_timeout_no_request(isolate_agent_subprocess, idle_timeout_seconds):
251+
process, port = isolate_agent_subprocess
252+
253+
p = psutil.Process(process.pid)
254+
for _ in range(10):
255+
if p.is_running():
256+
break
257+
time.sleep(1)
258+
else:
259+
assert False, "Agent should be running"
260+
261+
# Wait for the idle timeout to trigger
262+
try:
263+
p.wait(timeout=5)
264+
terminated = True
265+
except psutil.TimeoutExpired:
266+
terminated = False
267+
268+
if idle_timeout_seconds == 0:
269+
assert not terminated, "Agent should not have terminated"
270+
else:
271+
assert terminated, "Agent should have terminated after idle timeout"
272+
273+
274+
@pytest.mark.parametrize(
275+
"idle_timeout_seconds",
276+
[0, 2],
277+
)
278+
def test_idle_timeout(isolate_agent_subprocess, idle_timeout_seconds):
279+
process, port = isolate_agent_subprocess
280+
281+
p = psutil.Process(process.pid)
282+
for _ in range(10):
283+
if p.is_running():
284+
break
285+
time.sleep(1)
286+
else:
287+
assert False, "Agent should be running"
288+
289+
channel = grpc.insecure_channel(f"localhost:{port}")
290+
stub = AgentStub(channel)
291+
292+
def fn():
293+
import time
294+
295+
time.sleep(3) # longer than the idle timeout
296+
print("Function finished")
297+
298+
responses = stub.Run(FunctionCall(function=to_serialized_object(fn, method="dill")))
299+
consume_responses(responses, wait=True)
300+
301+
# Wait for the idle timeout to trigger
302+
try:
303+
p.wait(timeout=5)
304+
terminated = True
305+
except psutil.TimeoutExpired:
306+
terminated = False
307+
308+
if idle_timeout_seconds == 0:
309+
assert not terminated, "Agent should not have terminated"
310+
else:
311+
assert terminated, "Agent should have terminated after idle timeout"
312+
313+
314+
@pytest.mark.parametrize(
315+
["single_use", "idle_timeout_seconds"],
316+
[(False, 2)], # to prevent the server from shutting down automatically
317+
)
318+
def test_idle_timeout_server_handle(isolate_server_subprocess):
319+
process, port = isolate_server_subprocess
320+
channel = grpc.insecure_channel(f"localhost:{port}")
321+
stub = IsolateStub(channel)
322+
323+
def fn():
324+
import time
325+
326+
time.sleep(5) # longer than the idle timeout
327+
print("Function finished")
328+
329+
responses = stub.Run(create_run_request(fn))
330+
consume_responses(responses, wait=True)
331+
332+
# Send the first request to start the agent
333+
p = psutil.Process(process.pid)
334+
assert len(p.children()) == 1, "Server should have one agent process"
335+
336+
# Wait for the idle timeout to trigger
337+
time.sleep(3)
338+
assert (
339+
len(p.children()) == 0
340+
), "Agent process should have terminated after idle timeout"
341+
342+
# Server should be able to handle a new request after the idle timeout
343+
responses = stub.Run(create_run_request(fn))
344+
consume_responses(responses, wait=True)
345+
assert len(p.children()) == 1, "Server should have one agent process"

0 commit comments

Comments
 (0)