Skip to content

Commit e90156c

Browse files
committed
fix(client): monitor proxier death for specific servers
Signed-off-by: JerryLee <223425819+Jerry2003826@users.noreply.github.com>
1 parent 6d83f4a commit e90156c

4 files changed

Lines changed: 234 additions & 17 deletions

File tree

python/ray/_private/services.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,7 @@ def extract_ip_port(bootstrap_address: str):
750750
ip_port = parse_address(bootstrap_address)
751751
if ip_port is None:
752752
raise ValueError(
753-
f"Malformed address {bootstrap_address}. " f"Expected '<host>:<port>'."
753+
f"Malformed address {bootstrap_address}. Expected '<host>:<port>'."
754754
)
755755
ip, port = ip_port
756756
try:
@@ -759,8 +759,7 @@ def extract_ip_port(bootstrap_address: str):
759759
raise ValueError(f"Malformed address port {port}. Must be an integer.")
760760
if port < 1024 or port > 65535:
761761
raise ValueError(
762-
f"Invalid address port {port}. Must be between 1024 "
763-
"and 65535 (inclusive)."
762+
f"Invalid address port {port}. Must be between 1024 and 65535 (inclusive)."
764763
)
765764
return ip, port
766765

@@ -1053,21 +1052,35 @@ def preexec_fn():
10531052
total_chrs = sum([len(x) for x in command])
10541053
if total_chrs > 31766:
10551054
raise ValueError(
1056-
f"command is limited to a total of 31767 characters, "
1057-
f"got {total_chrs}"
1055+
f"command is limited to a total of 31767 characters, got {total_chrs}"
10581056
)
10591057

1060-
process = ConsolePopen(
1061-
command,
1062-
env=modified_env,
1063-
cwd=cwd,
1064-
stdout=stdout_file,
1065-
stderr=stderr_file,
1066-
stdin=subprocess.PIPE if pipe_stdin else None,
1067-
preexec_fn=(None if sys.platform == "win32" or use_posix_spawn else preexec_fn),
1068-
close_fds=close_fds,
1069-
creationflags=CREATE_SUSPENDED if win32_fate_sharing else 0,
1058+
previous_sigmask = None
1059+
should_block_sigint_for_spawn = (
1060+
use_posix_spawn
1061+
and hasattr(signal, "pthread_sigmask")
1062+
and hasattr(signal, "SIG_BLOCK")
1063+
and hasattr(signal, "SIG_SETMASK")
10701064
)
1065+
if should_block_sigint_for_spawn:
1066+
previous_sigmask = signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
1067+
try:
1068+
process = ConsolePopen(
1069+
command,
1070+
env=modified_env,
1071+
cwd=cwd,
1072+
stdout=stdout_file,
1073+
stderr=stderr_file,
1074+
stdin=subprocess.PIPE if pipe_stdin else None,
1075+
preexec_fn=(
1076+
None if sys.platform == "win32" or use_posix_spawn else preexec_fn
1077+
),
1078+
close_fds=close_fds,
1079+
creationflags=CREATE_SUSPENDED if win32_fate_sharing else 0,
1080+
)
1081+
finally:
1082+
if previous_sigmask is not None:
1083+
signal.pthread_sigmask(signal.SIG_SETMASK, previous_sigmask)
10711084

10721085
if win32_fate_sharing:
10731086
try:
@@ -2495,9 +2508,12 @@ def start_ray_client_server(
24952508
# multi-threaded gRPC server. Avoid a fork+preexec path there: gRPC may have
24962509
# active poller threads and can skip fork handlers, leaving the child to
24972510
# crash before it opens its channel. Specific servers self-terminate after
2498-
# being idle and are also cleaned up by the proxier, so they can trade
2499-
# kernel fate sharing and the preexec SIGINT mask for a fork-safe spawn path.
2511+
# being idle, monitor stdin EOF from setup_worker for abnormal parent death,
2512+
# and inherit a temporarily-blocked SIGINT mask from the spawning thread, so
2513+
# they can trade kernel fate sharing for a fork-safe spawn path.
25002514
process_fate_share = False if use_posix_spawn else fate_share
2515+
if use_posix_spawn:
2516+
command.append("--monitor-parent-pipe")
25012517

25022518
process_info = start_ray_process(
25032519
command,
@@ -2506,6 +2522,7 @@ def start_ray_client_server(
25062522
stderr_file=stderr_file,
25072523
fate_share=process_fate_share,
25082524
env_updates=env_updates,
2525+
pipe_stdin=use_posix_spawn,
25092526
use_posix_spawn=use_posix_spawn,
25102527
)
25112528
return process_info

python/ray/_private/workers/setup_worker.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import argparse
22
import logging
3+
import os
4+
import sys
5+
import threading
36

47
from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL
58
from ray._private.ray_logging import setup_logger
@@ -20,10 +23,43 @@
2023

2124
parser.add_argument("--language", type=str, help="the language type of the worker")
2225

26+
# Internal switch: when present, a background thread watches stdin and the
# process exits once that stream reports EOF.
parser.add_argument(
    "--monitor-parent-pipe",
    action="store_true",
    required=False,
    help="Internal: exit when inherited stdin pipe reaches EOF.",
)
32+
33+
34+
def _start_parent_pipe_monitor(enabled: bool, exit_func=os._exit):
    """Start a daemon thread that calls ``exit_func(0)`` once stdin hits EOF.

    Args:
        enabled: When false, no thread is started and None is returned.
        exit_func: Callable invoked with status 0 after the read loop ends;
            defaults to ``os._exit``.

    Returns:
        The started ``threading.Thread``, or None when monitoring is disabled.
    """
    if not enabled:
        return None

    # Use the binary layer of stdin when it exposes one.
    pipe = getattr(sys.stdin, "buffer", sys.stdin)

    def _wait_for_eof():
        try:
            # Discard anything that arrives; an empty read signals EOF.
            chunk = pipe.read(1)
            while chunk:
                chunk = pipe.read(1)
        except Exception:
            # A failed read is logged and then treated the same as EOF.
            logger.exception("Failed while monitoring parent pipe.")
        logger.info("Parent pipe closed; exiting worker setup process.")
        exit_func(0)

    monitor = threading.Thread(
        name="ray-worker-parent-pipe-monitor",
        target=_wait_for_eof,
        daemon=True,
    )
    monitor.start()
    return monitor
55+
2356

2457
if __name__ == "__main__":
2558
setup_logger(LOGGER_LEVEL, LOGGER_FORMAT)
2659
args, remaining_args = parser.parse_known_args()
60+
_start_parent_pipe_monitor(args.monitor_parent_pipe)
61+
if args.monitor_parent_pipe:
62+
remaining_args.append("--monitor-parent-pipe")
2763
# NOTE(edoakes): args.serialized_runtime_env_context is only None when
2864
# we're starting the main Ray client proxy server. That case should
2965
# probably not even go through this codepath.

python/ray/tests/test_debug_tools.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import signal
23
import subprocess
34
import sys
45
from pathlib import Path
@@ -8,6 +9,7 @@
89
import ray
910
import ray._private.ray_constants as ray_constants
1011
import ray._private.services as services
12+
import ray.util.client.server.server as ray_client_server
1113
from ray._common.test_utils import wait_for_condition
1214

1315

@@ -230,10 +232,91 @@ def fake_start_ray_process(command, process_type, **kwargs):
230232
assert process_info is expected_process_info
231233
assert captured["process_type"] == ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER
232234
assert "--mode=specific-server" in captured["command"]
235+
assert "--monitor-parent-pipe" in captured["command"]
233236
assert captured["kwargs"]["fate_share"] is False
237+
assert captured["kwargs"]["pipe_stdin"] is True
234238
assert captured["kwargs"]["use_posix_spawn"] is True
235239

236240

241+
def test_setup_worker_parent_pipe_monitor_exits_on_eof(monkeypatch):
    """The setup_worker pipe monitor calls its exit hook with 0 at EOF."""
    from ray._private.workers import setup_worker

    class EmptyStdin:
        """Stand-in for stdin whose reads immediately report EOF."""

        def read(self, size):
            return ""

    recorded = []
    on_exit = recorded.append
    with monkeypatch.context() as patcher:
        patcher.setattr(setup_worker.sys, "stdin", EmptyStdin())
        monitor = setup_worker._start_parent_pipe_monitor(True, on_exit)
        assert monitor is not None
        monitor.join(timeout=1)

    assert recorded == [0]

    # With monitoring disabled no thread is created.
    disabled = setup_worker._start_parent_pipe_monitor(False, on_exit)
    assert disabled is None
257+
258+
259+
def test_ray_client_specific_server_parent_pipe_monitor_exits_on_eof(monkeypatch):
    """The client-server pipe monitor calls its exit hook with 0 at EOF."""

    class EmptyStdin:
        """Stand-in for stdin whose reads immediately report EOF."""

        def read(self, size):
            return ""

    recorded = []
    on_exit = recorded.append
    with monkeypatch.context() as patcher:
        patcher.setattr(ray_client_server.sys, "stdin", EmptyStdin())
        monitor = ray_client_server._start_parent_pipe_monitor(True, on_exit)
        assert monitor is not None
        monitor.join(timeout=1)

    assert recorded == [0]

    # With monitoring disabled no thread is created.
    disabled = ray_client_server._start_parent_pipe_monitor(False, on_exit)
    assert disabled is None
275+
276+
277+
def test_ray_client_specific_server_blocks_sigint(monkeypatch):
    """The helper asks pthread_sigmask to block exactly SIGINT."""
    seen = []
    block_marker = object()

    def recording_sigmask(how, signals):
        seen.append((how, signals))

    signal_module = ray_client_server.signal
    with monkeypatch.context() as patcher:
        patcher.setattr(
            signal_module,
            "pthread_sigmask",
            recording_sigmask,
            raising=False,
        )
        patcher.setattr(signal_module, "SIG_BLOCK", block_marker, raising=False)

        ray_client_server._block_sigint_for_specific_server()

    expected = [(block_marker, {signal.SIGINT})]
    assert seen == expected
296+
297+
298+
def test_ray_client_specific_server_sigint_block_noops_without_posix_signal_support(
    monkeypatch,
):
    """Without signal.SIG_BLOCK the helper never calls pthread_sigmask."""
    seen = []

    def recording_sigmask(how, signals):
        seen.append((how, signals))

    signal_module = ray_client_server.signal
    with monkeypatch.context() as patcher:
        patcher.setattr(
            signal_module,
            "pthread_sigmask",
            recording_sigmask,
            raising=False,
        )
        # Simulate a platform whose signal module has no SIG_BLOCK constant.
        patcher.delattr(signal_module, "SIG_BLOCK", raising=False)

        ray_client_server._block_sigint_for_specific_server()

    assert seen == []
318+
319+
237320
def test_start_ray_process_posix_spawn_close_fds_when_supported(monkeypatch):
238321
captured = {}
239322
expected_process = object()
@@ -260,6 +343,46 @@ def fake_console_popen(command, **kwargs):
260343
assert captured["kwargs"]["close_fds"] is True
261344

262345

346+
def test_start_ray_process_posix_spawn_blocks_sigint_for_child(monkeypatch):
    """SIGINT is blocked around the spawn and the prior mask is restored."""
    popen_record = {}
    sentinel_process = object()
    sigmask_calls = []
    prior_mask = {signal.SIGTERM}
    block_marker = object()
    setmask_marker = object()

    def recording_popen(command, **kwargs):
        popen_record["command"] = command
        popen_record["kwargs"] = kwargs
        return sentinel_process

    def recording_sigmask(how, signals):
        sigmask_calls.append((how, signals))
        return prior_mask

    # (target, attribute, replacement, raising) applied in this order.
    patches = [
        (services.sys, "platform", "linux", True),
        (services.os, "POSIX_SPAWN_CLOSEFROM", object(), False),
        (services.signal, "pthread_sigmask", recording_sigmask, True),
        (services.signal, "SIG_BLOCK", block_marker, False),
        (services.signal, "SIG_SETMASK", setmask_marker, False),
        (services, "ConsolePopen", recording_popen, True),
    ]

    with monkeypatch.context() as patcher:
        for target, attribute, replacement, must_exist in patches:
            patcher.setattr(target, attribute, replacement, raising=must_exist)

        process_info = services.start_ray_process(
            [sys.executable],
            ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER,
            fate_share=False,
            use_posix_spawn=True,
        )

    assert process_info.process is sentinel_process
    assert popen_record["kwargs"]["preexec_fn"] is None
    # First call blocks SIGINT, second restores the mask returned by the first.
    expected_calls = [
        (block_marker, {signal.SIGINT}),
        (setmask_marker, prior_mask),
    ]
    assert sigmask_calls == expected_calls
384+
385+
263386
def test_start_ray_process_posix_spawn_leaves_fds_open_for_older_runtime(
264387
monkeypatch,
265388
):

python/ray/util/client/server/server.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import os
99
import pickle
1010
import queue
11+
import signal
12+
import sys
1113
import threading
1214
import time
1315
from collections import defaultdict
@@ -851,6 +853,34 @@ def ray_connect_handler(job_config: JobConfig = None, **ray_init_kwargs):
851853
return ray_connect_handler
852854

853855

856+
def _block_sigint_for_specific_server() -> None:
    """Add SIGINT to the calling thread's blocked-signal mask when possible.

    Does nothing if the ``signal`` module lacks ``pthread_sigmask`` or
    ``SIG_BLOCK``.
    """
    supported = hasattr(signal, "pthread_sigmask") and hasattr(signal, "SIG_BLOCK")
    if not supported:
        return
    signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
859+
860+
861+
def _start_parent_pipe_monitor(enabled: bool, exit_func=os._exit):
    """Start a daemon thread that calls ``exit_func(0)`` once stdin hits EOF.

    Args:
        enabled: When false, no thread is started and None is returned.
        exit_func: Callable invoked with status 0 after the read loop ends;
            defaults to ``os._exit``.

    Returns:
        The started ``threading.Thread``, or None when monitoring is disabled.
    """
    if not enabled:
        return None

    # Use the binary layer of stdin when it exposes one.
    pipe = getattr(sys.stdin, "buffer", sys.stdin)

    def _wait_for_eof():
        try:
            # Discard anything that arrives; an empty read signals EOF.
            chunk = pipe.read(1)
            while chunk:
                chunk = pipe.read(1)
        except Exception:
            # A failed read is logged and then treated the same as EOF.
            logger.exception("Failed while monitoring specific-server parent pipe.")
        logger.info("Specific Ray Client server parent pipe closed; shutting down.")
        exit_func(0)

    monitor = threading.Thread(
        name="ray-client-specific-server-parent-monitor",
        target=_wait_for_eof,
        daemon=True,
    )
    monitor.start()
    return monitor
882+
883+
854884
def try_create_gcs_client(address: Optional[str]) -> Optional[GcsClient]:
855885
"""
856886
Try to create a gcs client based on the command line args or by
@@ -900,9 +930,20 @@ def main():
900930
default=None,
901931
help="The hex ID of this node.",
902932
)
933+
parser.add_argument(
934+
"--monitor-parent-pipe",
935+
required=False,
936+
action="store_true",
937+
help="Internal: exit when inherited stdin pipe reaches EOF.",
938+
)
903939
args, _ = parser.parse_known_args()
904940
redis_password = os.environ.get(ray_constants.RAY_REDIS_PASSWORD_ENV)
905941
setup_logger(ray_constants.LOGGER_LEVEL, ray_constants.LOGGER_FORMAT)
942+
_start_parent_pipe_monitor(
943+
args.mode == "specific-server" and args.monitor_parent_pipe
944+
)
945+
if args.mode == "specific-server":
946+
_block_sigint_for_specific_server()
906947

907948
ray_connect_handler = create_ray_handler(
908949
args.address, redis_password, args.redis_username

0 commit comments

Comments
 (0)