diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 4935a910031f..4521da06f2a8 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -116,6 +116,11 @@ ) +def _posix_spawn_can_close_fds() -> bool: + """Return whether CPython can use posix_spawn with close_fds=True.""" + return hasattr(os, "POSIX_SPAWN_CLOSEFROM") + + def _site_flags() -> List[str]: """Detect whether flags related to site packages are enabled for the current interpreter. To run Ray in hermetic build environments, it helps to pass these flags @@ -745,7 +750,7 @@ def extract_ip_port(bootstrap_address: str): ip_port = parse_address(bootstrap_address) if ip_port is None: raise ValueError( - f"Malformed address {bootstrap_address}. " f"Expected ':'." + f"Malformed address {bootstrap_address}. Expected ':'." ) ip, port = ip_port try: @@ -754,8 +759,7 @@ def extract_ip_port(bootstrap_address: str): raise ValueError(f"Malformed address port {port}. Must be an integer.") if port < 1024 or port > 65535: raise ValueError( - f"Invalid address port {port}. Must be between 1024 " - "and 65535 (inclusive)." + f"Invalid address port {port}. Must be between 1024 and 65535 (inclusive)." ) return ip, port @@ -867,6 +871,7 @@ def start_ray_process( stdout_file: Optional[IO[AnyStr]] = None, stderr_file: Optional[IO[AnyStr]] = None, pipe_stdin: bool = False, + use_posix_spawn: bool = False, ): """Start one of the Ray processes. @@ -898,6 +903,12 @@ def start_ray_process( no redirection should happen, then this should be None. pipe_stdin: If true, subprocess.PIPE will be passed to the process as stdin. + use_posix_spawn: If true on POSIX, avoid preexec_fn so CPython can use + its posix_spawn fast path. On runtimes that support closing file + descriptors from posix_spawn, keep close_fds=True. Older runtimes + need close_fds=False to stay off the fork path. This also skips + Ray's SIGINT-masking preexec hook, so it is only safe for + subprocesses that do not need fate sharing or that signal mask. Returns: Information about the process that was started including a handle to @@ -963,6 +974,7 @@ def start_ray_process( env_updates = {} if not isinstance(env_updates, dict): raise ValueError("The 'env_updates' argument must be a dictionary.") + use_posix_spawn = use_posix_spawn and sys.platform != "win32" modified_env = os.environ.copy() modified_env.update(env_updates) @@ -1015,6 +1027,9 @@ def start_ray_process( "kernel-level fate-sharing must only be specified if " "detect_fate_sharing_support() has returned True" ) + if use_posix_spawn and fate_share: + raise ValueError("'use_posix_spawn' cannot be combined with 'fate_share'.") + close_fds = not use_posix_spawn or _posix_spawn_can_close_fds() def preexec_fn(): import signal @@ -1037,20 +1052,35 @@ def preexec_fn(): total_chrs = sum([len(x) for x in command]) if total_chrs > 31766: raise ValueError( - f"command is limited to a total of 31767 characters, " - f"got {total_chrs}" + f"command is limited to a total of 31767 characters, got {total_chrs}" ) - process = ConsolePopen( - command, - env=modified_env, - cwd=cwd, - stdout=stdout_file, - stderr=stderr_file, - stdin=subprocess.PIPE if pipe_stdin else None, - preexec_fn=preexec_fn if sys.platform != "win32" else None, - creationflags=CREATE_SUSPENDED if win32_fate_sharing else 0, + previous_sigmask = None + should_block_sigint_for_spawn = ( + use_posix_spawn + and hasattr(signal, "pthread_sigmask") + and hasattr(signal, "SIG_BLOCK") + and hasattr(signal, "SIG_SETMASK") ) + if should_block_sigint_for_spawn: + previous_sigmask = signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT}) + try: + process = ConsolePopen( + command, + env=modified_env, + cwd=cwd, + stdout=stdout_file, + stderr=stderr_file, + stdin=subprocess.PIPE if pipe_stdin else None, + preexec_fn=( + None if sys.platform == "win32" or use_posix_spawn else preexec_fn + ), + close_fds=close_fds, + creationflags=CREATE_SUSPENDED if win32_fate_sharing else 0, + ) + finally: + if previous_sigmask is not None: + signal.pthread_sigmask(signal.SIG_SETMASK, previous_sigmask) if win32_fate_sharing: try: @@ -2473,13 +2503,27 @@ def start_ray_client_server( if node_id: command.append(f"--node-id={node_id}") + use_posix_spawn = server_type == "specific-server" and sys.platform != "win32" + # Specific Ray Client servers are spawned by the proxier, which is itself a + # multi-threaded gRPC server. Avoid a fork+preexec path there: gRPC may have + # active poller threads and can skip fork handlers, leaving the child to + # crash before it opens its channel. Specific servers self-terminate after + # being idle, monitor stdin EOF from setup_worker for abnormal parent death, + # and inherit a temporarily-blocked SIGINT mask from the spawning thread, so + # they can trade kernel fate sharing for a fork-safe spawn path. + process_fate_share = False if use_posix_spawn else fate_share + if use_posix_spawn: + command.append("--monitor-parent-pipe") + process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, stdout_file=stdout_file, stderr_file=stderr_file, - fate_share=fate_share, + fate_share=process_fate_share, env_updates=env_updates, + pipe_stdin=use_posix_spawn, + use_posix_spawn=use_posix_spawn, ) return process_info diff --git a/python/ray/_private/workers/setup_worker.py b/python/ray/_private/workers/setup_worker.py index 23ba980a5bb2..d738fa6650b6 100644 --- a/python/ray/_private/workers/setup_worker.py +++ b/python/ray/_private/workers/setup_worker.py @@ -1,5 +1,8 @@ import argparse import logging +import os +import subprocess +import sys from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL from ray._private.ray_logging import setup_logger @@ -8,6 +11,30 @@ logger = logging.getLogger(__name__) +_PARENT_PIPE_MONITOR_SCRIPT = """ +import os +import select +import signal +import sys + +target_pid = int(sys.argv[1]) + +while True: + readable, _, _ = select.select([sys.stdin], [], [], 1.0) + if readable: + data = os.read(sys.stdin.fileno(), 1) + if not data: + try: + os.kill(target_pid, signal.SIGTERM) + except ProcessLookupError: + pass + sys.exit(0) + try: + os.kill(target_pid, 0) + except ProcessLookupError: + sys.exit(0) +""" + parser = argparse.ArgumentParser( description=("Set up the environment for a Ray worker and launch the worker.") ) @@ -20,10 +47,30 @@ parser.add_argument("--language", type=str, help="the language type of the worker") +parser.add_argument( + "--monitor-parent-pipe", + required=False, + action="store_true", + help="Internal: exit when inherited stdin pipe reaches EOF.", +) + + +def _start_parent_pipe_monitor(enabled: bool): + if not enabled: + return None + stdin = getattr(sys.stdin, "buffer", sys.stdin) + return subprocess.Popen( + [sys.executable, "-c", _PARENT_PIPE_MONITOR_SCRIPT, str(os.getpid())], + stdin=stdin, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if __name__ == "__main__": setup_logger(LOGGER_LEVEL, LOGGER_FORMAT) args, remaining_args = parser.parse_known_args() + _start_parent_pipe_monitor(args.monitor_parent_pipe) # NOTE(edoakes): args.serialized_runtime_env_context is only None when # we're starting the main Ray client proxy server. That case should # probably not even go through this codepath. diff --git a/python/ray/tests/test_debug_tools.py b/python/ray/tests/test_debug_tools.py index e1dc65e7d52b..8a3757ec39db 100644 --- a/python/ray/tests/test_debug_tools.py +++ b/python/ray/tests/test_debug_tools.py @@ -1,4 +1,5 @@ import os +import signal import subprocess import sys from pathlib import Path @@ -8,6 +9,7 @@ import ray import ray._private.ray_constants as ray_constants import ray._private.services as services +import ray.util.client.server.server as ray_client_server from ray._common.test_utils import wait_for_condition @@ -199,9 +201,216 @@ def fake_start_ray_process(command, process_type, **kwargs): assert captured["kwargs"]["env_updates"] == { ray_constants.RAY_REDIS_PASSWORD_ENV: "secret123" } + assert captured["kwargs"]["fate_share"] is False + assert captured["kwargs"]["use_posix_spawn"] is False assert ray_constants.RAY_REDIS_PASSWORD_ENV not in os.environ +def test_start_ray_client_specific_server_uses_fork_safe_spawn(monkeypatch): + captured = {} + expected_process_info = object() + + def fake_start_ray_process(command, process_type, **kwargs): + captured["command"] = command + captured["process_type"] = process_type + captured["kwargs"] = kwargs + return expected_process_info + + with monkeypatch.context() as m: + m.setattr(services.sys, "platform", "linux") + m.setattr(services, "start_ray_process", fake_start_ray_process) + + process_info = services.start_ray_client_server( + address="127.0.0.1:6379", + ray_client_server_ip="127.0.0.1", + ray_client_server_port=10001, + fate_share=True, + server_type="specific-server", + serialized_runtime_env_context="{}", + ) + + assert process_info is expected_process_info + assert captured["process_type"] == ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER + assert "--mode=specific-server" in captured["command"] + assert "--monitor-parent-pipe" in captured["command"] + assert captured["kwargs"]["fate_share"] is False + assert captured["kwargs"]["pipe_stdin"] is True + assert captured["kwargs"]["use_posix_spawn"] is True + + +def test_setup_worker_parent_pipe_monitor_starts_subprocess(monkeypatch): + from ray._private.workers import setup_worker + + captured = {} + expected_process = object() + fake_stdin = object() + + def fake_popen(command, **kwargs): + captured["command"] = command + captured["kwargs"] = kwargs + return expected_process + + with monkeypatch.context() as m: + m.setattr(setup_worker.sys, "stdin", fake_stdin) + m.setattr(setup_worker.os, "getpid", lambda: 12345) + m.setattr(setup_worker.subprocess, "Popen", fake_popen) + process = setup_worker._start_parent_pipe_monitor(True) + + assert process is expected_process + assert captured["command"] == [ + setup_worker.sys.executable, + "-c", + setup_worker._PARENT_PIPE_MONITOR_SCRIPT, + "12345", + ] + assert captured["kwargs"]["stdin"] is fake_stdin + assert captured["kwargs"]["stdout"] == setup_worker.subprocess.DEVNULL + assert captured["kwargs"]["stderr"] == setup_worker.subprocess.DEVNULL + assert setup_worker._start_parent_pipe_monitor(False) is None + + +def test_ray_client_specific_server_blocks_sigint(monkeypatch): + calls = [] + sig_block = object() + + def fake_pthread_sigmask(how, signals): + calls.append((how, signals)) + + with monkeypatch.context() as m: + m.setattr( + ray_client_server.signal, + "pthread_sigmask", + fake_pthread_sigmask, + raising=False, + ) + m.setattr(ray_client_server.signal, "SIG_BLOCK", sig_block, raising=False) + + ray_client_server._block_sigint_for_specific_server() + + assert calls == [(sig_block, {signal.SIGINT})] + + +def test_ray_client_specific_server_sigint_block_noops_without_posix_signal_support( + monkeypatch, +): + calls = [] + + def fake_pthread_sigmask(how, signals): + calls.append((how, signals)) + + with monkeypatch.context() as m: + m.setattr( + ray_client_server.signal, + "pthread_sigmask", + fake_pthread_sigmask, + raising=False, + ) + m.delattr(ray_client_server.signal, "SIG_BLOCK", raising=False) + + ray_client_server._block_sigint_for_specific_server() + + assert calls == [] + + +def test_start_ray_process_posix_spawn_close_fds_when_supported(monkeypatch): + captured = {} + expected_process = object() + + def fake_console_popen(command, **kwargs): + captured["command"] = command + captured["kwargs"] = kwargs + return expected_process + + with monkeypatch.context() as m: + m.setattr(services.sys, "platform", "linux") + m.setattr(services.os, "POSIX_SPAWN_CLOSEFROM", object(), raising=False) + m.setattr(services, "ConsolePopen", fake_console_popen) + + process_info = services.start_ray_process( + [sys.executable], + ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, + fate_share=False, + use_posix_spawn=True, + ) + + assert process_info.process is expected_process + assert captured["kwargs"]["preexec_fn"] is None + assert captured["kwargs"]["close_fds"] is True + + +def test_start_ray_process_posix_spawn_blocks_sigint_for_child(monkeypatch): + captured = {} + expected_process = object() + calls = [] + previous_mask = {signal.SIGTERM} + sig_block = object() + sig_setmask = object() + + def fake_console_popen(command, **kwargs): + captured["command"] = command + captured["kwargs"] = kwargs + return expected_process + + def fake_pthread_sigmask(how, signals): + calls.append((how, signals)) + return previous_mask + + with monkeypatch.context() as m: + m.setattr(services.sys, "platform", "linux") + m.setattr(services.os, "POSIX_SPAWN_CLOSEFROM", object(), raising=False) + m.setattr( + services.signal, + "pthread_sigmask", + fake_pthread_sigmask, + raising=False, + ) + m.setattr(services.signal, "SIG_BLOCK", sig_block, raising=False) + m.setattr(services.signal, "SIG_SETMASK", sig_setmask, raising=False) + m.setattr(services, "ConsolePopen", fake_console_popen) + + process_info = services.start_ray_process( + [sys.executable], + ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, + fate_share=False, + use_posix_spawn=True, + ) + + assert process_info.process is expected_process + assert captured["kwargs"]["preexec_fn"] is None + assert calls == [ + (sig_block, {signal.SIGINT}), + (sig_setmask, previous_mask), + ] + + +def test_start_ray_process_posix_spawn_leaves_fds_open_for_older_runtime( + monkeypatch, +): + captured = {} + expected_process = object() + + def fake_console_popen(command, **kwargs): + captured["command"] = command + captured["kwargs"] = kwargs + return expected_process + + with monkeypatch.context() as m: + m.setattr(services.sys, "platform", "linux") + m.delattr(services.os, "POSIX_SPAWN_CLOSEFROM", raising=False) + m.setattr(services, "ConsolePopen", fake_console_popen) + + process_info = services.start_ray_process( + [sys.executable], + ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, + fate_share=False, + use_posix_spawn=True, + ) + + assert process_info.process is expected_process + assert captured["kwargs"]["preexec_fn"] is None + assert captured["kwargs"]["close_fds"] is False + + if __name__ == "__main__": # Make subprocess happy in bazel. os.environ["LC_ALL"] = "en_US.UTF-8" diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 796421f5a7aa..7ae0fcc1c9f4 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -8,6 +8,7 @@ import os import pickle import queue +import signal import threading import time from collections import defaultdict @@ -851,6 +852,11 @@ def ray_connect_handler(job_config: JobConfig = None, **ray_init_kwargs): return ray_connect_handler +def _block_sigint_for_specific_server() -> None: + if hasattr(signal, "pthread_sigmask") and hasattr(signal, "SIG_BLOCK"): + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT}) + + def try_create_gcs_client(address: Optional[str]) -> Optional[GcsClient]: """ Try to create a gcs client based on the command line args or by @@ -903,6 +909,8 @@ def main(): args, _ = parser.parse_known_args() redis_password = os.environ.get(ray_constants.RAY_REDIS_PASSWORD_ENV) setup_logger(ray_constants.LOGGER_LEVEL, ray_constants.LOGGER_FORMAT) + if args.mode == "specific-server": + _block_sigint_for_specific_server() ray_connect_handler = create_ray_handler( args.address, redis_password, args.redis_username