diff --git a/ci/docker/serve.build.Dockerfile b/ci/docker/serve.build.Dockerfile index bd5d07c90f6d..c37f49b6048e 100644 --- a/ci/docker/serve.build.Dockerfile +++ b/ci/docker/serve.build.Dockerfile @@ -11,6 +11,46 @@ SHELL ["/bin/bash", "-ice"] COPY . . +# Install HAProxy from source +RUN < None: + """Broadcast target groups over long poll if they have changed. + + Keeps an in-memory record of the last target groups that were broadcast + to determine if they have changed. + """ + target_groups: List[TargetGroup] = self.get_target_groups( + from_proxy_manager=True, + ) + + # Check if target groups have changed by comparing the objects directly + if self._last_broadcasted_target_groups == target_groups: + return + + self.long_poll_host.notify_changed( + {LongPollNamespace.TARGET_GROUPS: target_groups} + ) + self._last_broadcasted_target_groups = target_groups + def _create_control_loop_metrics(self): self.node_update_duration_gauge_s = metrics.Gauge( "serve_controller_node_update_duration_s", @@ -1364,9 +1400,16 @@ def get_target_groups( that have running replicas, we return target groups for direct ingress. If there are multiple applications with no running replicas, we return one target group per application with unique route prefix. + 5. HAProxy is enabled and the caller is not an internal proxy manager. In + this case, we return target groups containing the proxies (e.g. haproxy). + 6. HAProxy is enabled and the caller is an internal proxy manager (e.g. + haproxy manager). In this case, we return target groups containing the + ingress replicas and possibly the Serve proxies. """ proxy_target_groups = self._get_proxy_target_groups() - if not self._direct_ingress_enabled: + if not self._direct_ingress_enabled or ( + self._ha_proxy_enabled and not from_proxy_manager + ): return proxy_target_groups # Get all applications and their metadata @@ -1387,6 +1430,10 @@ def get_target_groups( ] if not apps: + # When HAProxy is enabled and there are no apps, return empty target groups + # so that all requests fall through to the default_backend (404) + if self._ha_proxy_enabled and from_proxy_manager: + return [] return proxy_target_groups # Create target groups for each application @@ -1496,7 +1543,7 @@ def _get_target_groups_for_app_with_no_running_replicas( TargetGroup( protocol=RequestProtocol.HTTP, route_prefix=route_prefix, - targets=http_targets, + targets=[] if self._ha_proxy_enabled else http_targets, app_name=app_name, ) ) @@ -1505,7 +1552,7 @@ def _get_target_groups_for_app_with_no_running_replicas( TargetGroup( protocol=RequestProtocol.GRPC, route_prefix=route_prefix, - targets=grpc_targets, + targets=[] if self._ha_proxy_enabled else grpc_targets, app_name=app_name, ) ) diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 526b0dc0a3e5..9670b6b1415b 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -18,6 +18,7 @@ ) from ray.serve._private.constants import ( CONTROLLER_MAX_CONCURRENCY, + RAY_SERVE_ENABLE_HA_PROXY, RAY_SERVE_ENABLE_TASK_EVENTS, RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, RAY_SERVE_PROXY_USE_GRPC, @@ -255,3 +256,16 @@ def get_controller_impl(): )(ServeController) return controller_impl + + +def get_proxy_actor_class(): + # These imports are lazy to avoid circular imports + + if RAY_SERVE_ENABLE_HA_PROXY: + from ray.serve._private.haproxy import HAProxyManager + + return HAProxyManager + else: + from ray.serve._private.proxy import ProxyActor + + return ProxyActor diff --git a/python/ray/serve/_private/haproxy.py b/python/ray/serve/_private/haproxy.py new file mode 100644 index 000000000000..0737db03d40c --- /dev/null +++ b/python/ray/serve/_private/haproxy.py @@ -0,0 +1,1217 @@ +import asyncio +import csv +import fcntl +import io +import json +import logging +import os +import re +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + +from jinja2 import Environment + +import ray +from ray._common.utils import get_or_create_event_loop +from ray.serve._private.common import ( + NodeId, + ReplicaID, + RequestMetadata, +) +from ray.serve._private.constants import ( + DRAINING_MESSAGE, + HEALTHY_MESSAGE, + NO_REPLICAS_MESSAGE, + NO_ROUTES_MESSAGE, + PROXY_MIN_DRAINING_PERIOD_S, + RAY_SERVE_ENABLE_HAPROXY_OPTIMIZED_CONFIG, + RAY_SERVE_HAPROXY_CONFIG_FILE_LOC, + RAY_SERVE_HAPROXY_HARD_STOP_AFTER_S, + RAY_SERVE_HAPROXY_HEALTH_CHECK_DOWNINTER, + RAY_SERVE_HAPROXY_HEALTH_CHECK_FALL, + RAY_SERVE_HAPROXY_HEALTH_CHECK_FASTINTER, + RAY_SERVE_HAPROXY_HEALTH_CHECK_INTER, + RAY_SERVE_HAPROXY_HEALTH_CHECK_RISE, + RAY_SERVE_HAPROXY_MAXCONN, + RAY_SERVE_HAPROXY_METRICS_PORT, + RAY_SERVE_HAPROXY_NBTHREAD, + RAY_SERVE_HAPROXY_SERVER_STATE_BASE, + RAY_SERVE_HAPROXY_SERVER_STATE_FILE, + RAY_SERVE_HAPROXY_SOCKET_PATH, + RAY_SERVE_HAPROXY_SYSLOG_PORT, + RAY_SERVE_HAPROXY_TIMEOUT_CLIENT_S, + RAY_SERVE_HAPROXY_TIMEOUT_CONNECT_S, + RAY_SERVE_HAPROXY_TIMEOUT_SERVER_S, + SERVE_CONTROLLER_NAME, + SERVE_LOGGER_NAME, + SERVE_NAMESPACE, +) +from ray.serve._private.haproxy_templates import ( + HAPROXY_CONFIG_TEMPLATE, + HAPROXY_HEALTHZ_RULES_TEMPLATE, +) +from ray.serve._private.logging_utils import get_component_logger_file_path +from ray.serve._private.long_poll import LongPollClient, LongPollNamespace +from ray.serve._private.proxy import ProxyActorInterface +from ray.serve.config import HTTPOptions, gRPCOptions +from ray.serve.schema import ( + LoggingConfig, + Target, + TargetGroup, +) + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +@dataclass +class ServerStats: + """Server statistics from HAProxy.""" + + backend: str # Which backend this server belongs to + server: str # Server name within the backend + status: str # Current status: "UP", "DOWN", "DRAIN", etc. + current_sessions: int # Active sessions (HAProxy 'scur') + queued: int # Queued requests (HAProxy 'qcur') + + @property + def is_up(self) -> bool: + return self.status == "UP" + + @property + def is_draining(self) -> bool: + return self.status in ["DRAIN", "NOLB"] + + @property + def can_drain_safely(self) -> bool: + """ + Return True if the server can be drained safely based on the current load. + Safe to drain when: + - No current active sessions (0) + - No queued requests waiting + This ensures no active user sessions are disrupted during draining. + """ + return self.current_sessions == 0 and self.queued == 0 + + +@dataclass +class HAProxyStats: + """Complete HAProxy statistics with both individual server data and aggregate metrics.""" + + # Individual server statistics by backend and server name + backend_to_servers: Dict[str, Dict[str, ServerStats]] = field(default_factory=dict) + + # Computed aggregate metrics (calculated from server data) + @property + def total_backends(self) -> int: + """Total number of backends.""" + return len(self.backend_to_servers) + + @property + def total_servers(self) -> int: + """Total number of servers across all backends.""" + return sum( + len(backend_servers) for backend_servers in self.backend_to_servers.values() + ) + + @property + def active_servers(self) -> int: + """Number of servers currently UP.""" + return sum( + 1 + for backend_servers in self.backend_to_servers.values() + for server in backend_servers.values() + if server.is_up + ) + + @property + def draining_servers(self) -> int: + """Number of servers currently draining.""" + return sum( + 1 + for backend_servers in self.backend_to_servers.values() + for server in backend_servers.values() + if server.is_draining + ) + + @property + def total_active_sessions(self) -> int: + """Sum of all active sessions across all servers.""" + return sum( + server.current_sessions + for backend_servers in self.backend_to_servers.values() + for server in backend_servers.values() + ) + + @property + def total_queued_requests(self) -> int: + """Sum of all queued requests across all servers.""" + return sum( + server.queued + for backend_servers in self.backend_to_servers.values() + for server in backend_servers.values() + ) + + @property + def is_system_idle(self) -> bool: + """Return True if the entire system has no active load.""" + return self.total_active_sessions == 0 and self.total_queued_requests == 0 + + @property + def draining_progress_pct(self) -> float: + """Return percentage of servers currently draining (0.0 to 100.0).""" + if self.total_servers == 0: + return 0.0 + return (self.draining_servers / self.total_servers) * 100.0 + + +@dataclass +class HealthRouteInfo: + """Information regarding how proxy should respond to health and routes requests.""" + + healthy: bool = True + status: int = 200 + health_message: str = HEALTHY_MESSAGE + routes_message: str = "{}" + routes_content_type: str = "application/json" + + +@dataclass +class ServerConfig: + """Configuration for a single server.""" + + name: str # Server identifier for HAProxy config + host: str # IP/hostname to connect to + port: int # Port to connect to + + def __str__(self) -> str: + return f"ServerConfig(name='{self.name}', host='{self.host}', port={self.port})" + + def __repr__(self) -> str: + return str(self) + + +@dataclass +class BackendConfig: + """Configuration for a single application backend.""" + + # Name of the target group. + name: str + + # Path prefix for the target group. This will be used to route requests to the target group. + path_prefix: str + + # Maximum time HAProxy will wait for a successful TCP connection to be established with the backend server. + timeout_connect_s: Optional[int] = None + + # Maximum time that the backend server can be inactive while sending data back to HAProxy. + # This is also active during the initial response phase. + timeout_server_s: Optional[int] = None + + # Maximum time that the client can be inactive while sending data to HAProxy. + # This is active during the initial request phase. + timeout_client_s: Optional[int] = None + timeout_http_request_s: Optional[int] = None + + # Maximum time HAProxy will wait for a request in the queue. + timeout_queue_s: Optional[int] = None + + # Maximum time HAProxy will keep the connection alive. + # This has to be the same or greater than the client side keep-alive timeout. + timeout_http_keep_alive_s: Optional[int] = None + + # Control the inactivity timeout for established WebSocket connections. + # Without this setting, a WebSocket connection could be prematurely terminated by other, + # more general timeout settings like timeout client or timeout server, + # which are intended for the initial phases of a connection. + timeout_tunnel_s: Optional[int] = None + + # The number of consecutive failed health checks that must occur before a service instance is marked as unhealthy + health_check_fall: Optional[int] = None + + # Number of consecutive successful health checks required to mark an unhealthy service instance as healthy again + health_check_rise: Optional[int] = None + + # Interval, or the amount of time, between each health check attempt + health_check_inter: Optional[str] = None + + # The interval between two consecutive health checks when the server is in any of the transition states: UP - transitionally DOWN or DOWN - transitionally UP + health_check_fastinter: Optional[str] = None + + # The interval between two consecutive health checks when the server is in the DOWN state + health_check_downinter: Optional[str] = None + + # Endpoint path that the health check mechanism will send a request to. It's typically an HTTP path. + health_check_path: Optional[str] = "/-/healthz" + + # List of servers in this backend + servers: List[ServerConfig] = field(default_factory=list) + + # The app name for this backend. + app_name: str = field(default_factory=str) + + def build_health_check_config(self, global_config: "HAProxyConfig") -> dict: + """Build health check configuration for HAProxy backend. + + Returns a dict with: + - health_path: path for HTTP health checks (or None) + - default_server_directive: complete "default-server" line with all params + """ + # Resolve values: backend-specific overrides global defaults + fall = ( + self.health_check_fall + if self.health_check_fall is not None + else global_config.health_check_fall + ) + rise = ( + self.health_check_rise + if self.health_check_rise is not None + else global_config.health_check_rise + ) + inter = ( + self.health_check_inter + if self.health_check_inter is not None + else global_config.health_check_inter + ) + fastinter = ( + self.health_check_fastinter + if self.health_check_fastinter is not None + else global_config.health_check_fastinter + ) + downinter = ( + self.health_check_downinter + if self.health_check_downinter is not None + else global_config.health_check_downinter + ) + health_path = ( + self.health_check_path + if self.health_check_path is not None + else global_config.health_check_path + ) + + # Build default-server directive + parts = [] + + # Add optional fastinter/downinter only if provided + if fastinter is not None: + parts.append(f"fastinter {fastinter}") + if downinter is not None: + parts.append(f"downinter {downinter}") + + # Add required fall/rise/inter if any are set + if fall is not None: + parts.append(f"fall {fall}") + if rise is not None: + parts.append(f"rise {rise}") + if inter is not None: + parts.append(f"inter {inter}") + + # Always add check at the end + parts.append("check") + + default_server_directive = "default-server " + " ".join(parts) + + return { + "health_path": health_path, + "default_server_directive": default_server_directive, + } + + def __str__(self) -> str: + return f"BackendConfig(app_name='{self.app_name}', name='{self.name}', path_prefix='{self.path_prefix}', servers={self.servers})" + + def __repr__(self) -> str: + return str(self) + + +@dataclass +class HAProxyConfig: + """Configuration for HAProxy.""" + + socket_path: str = RAY_SERVE_HAPROXY_SOCKET_PATH + server_state_base: str = RAY_SERVE_HAPROXY_SERVER_STATE_BASE + server_state_file: str = RAY_SERVE_HAPROXY_SERVER_STATE_FILE + # Enable HAProxy optimizations (server state persistence, etc.) + # Disabled by default to prevent test suite interference + enable_hap_optimization: bool = RAY_SERVE_ENABLE_HAPROXY_OPTIMIZED_CONFIG + maxconn: int = RAY_SERVE_HAPROXY_MAXCONN + nbthread: int = RAY_SERVE_HAPROXY_NBTHREAD + stats_port: int = 8404 + stats_uri: str = "/stats" + metrics_port: int = RAY_SERVE_HAPROXY_METRICS_PORT + metrics_uri: str = "/metrics" + # All timeout values are in seconds + timeout_queue_s: Optional[int] = None + timeout_connect_s: Optional[int] = RAY_SERVE_HAPROXY_TIMEOUT_CONNECT_S + timeout_client_s: Optional[int] = RAY_SERVE_HAPROXY_TIMEOUT_CLIENT_S + timeout_server_s: Optional[int] = RAY_SERVE_HAPROXY_TIMEOUT_SERVER_S + timeout_http_request_s: Optional[int] = None + hard_stop_after_s: Optional[int] = RAY_SERVE_HAPROXY_HARD_STOP_AFTER_S + custom_global: Dict[str, str] = field(default_factory=dict) + custom_defaults: Dict[str, str] = field(default_factory=dict) + inject_process_id_header: bool = False + reload_id: Optional[str] = None # Unique ID for each reload + enable_so_reuseport: bool = ( + os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "0") == "1" + ) + has_received_routes: bool = False + has_received_servers: bool = False + pass_health_checks: bool = True + health_check_endpoint: str = "/-/healthz" + # Global health check parameters (used as defaults for backends) + # Number of consecutive failed health checks that must occur before a service instance is marked as unhealthy + health_check_fall: Optional[int] = RAY_SERVE_HAPROXY_HEALTH_CHECK_FALL + + # Number of consecutive successful health checks required to mark an unhealthy service instance as healthy again + health_check_rise: Optional[int] = RAY_SERVE_HAPROXY_HEALTH_CHECK_RISE + + # Interval, or the amount of time, between each health check attempt + health_check_inter: Optional[str] = RAY_SERVE_HAPROXY_HEALTH_CHECK_INTER + + # The interval between two consecutive health checks when the server is in any of the transition states: UP - transitionally DOWN or DOWN - transitionally UP + health_check_fastinter: Optional[str] = RAY_SERVE_HAPROXY_HEALTH_CHECK_FASTINTER + + # The interval between two consecutive health checks when the server is in the DOWN state + health_check_downinter: Optional[str] = RAY_SERVE_HAPROXY_HEALTH_CHECK_DOWNINTER + + health_check_path: Optional[str] = "/-/healthz" # For HTTP health checks + + http_options: HTTPOptions = field(default_factory=HTTPOptions) + + syslog_port: int = RAY_SERVE_HAPROXY_SYSLOG_PORT + + @property + def frontend_host(self) -> str: + if self.http_options.host is None or self.http_options.host == "0.0.0.0": + return "*" + + return self.http_options.host + + @property + def frontend_port(self) -> int: + return self.http_options.port + + @property + def timeout_http_keep_alive_s(self) -> int: + return self.http_options.keep_alive_timeout_s + + def build_health_route_info(self, backends: List[BackendConfig]) -> HealthRouteInfo: + if not self.has_received_routes: + router_ready_for_traffic = False + router_message = NO_ROUTES_MESSAGE + elif not self.has_received_servers: + router_ready_for_traffic = False + router_message = NO_REPLICAS_MESSAGE + else: + router_ready_for_traffic = True + router_message = "" + + if not self.pass_health_checks: + healthy = False + message = DRAINING_MESSAGE + elif not router_ready_for_traffic: + healthy = False + message = router_message + else: + healthy = True + message = HEALTHY_MESSAGE + + if healthy: + # Build routes JSON mapping: {"": "", ...} + routes = { + be.path_prefix: be.app_name + for be in backends + if be.app_name and be.path_prefix + } + routes_json = json.dumps(routes, separators=(",", ":"), ensure_ascii=False) + + # Escape for haproxy double-quoted string literal + routes_message = routes_json.replace("\\", "\\\\").replace('"', '\\"') + else: + routes_message = message + + return HealthRouteInfo( + healthy=healthy, + status=200 if healthy else 503, + health_message=message, + routes_message=routes_message, + routes_content_type="application/json" if healthy else "text/plain", + ) + + # TODO: support custom root_path and https + + +class ProxyApi(ABC): + """Generic interface for load balancer management operations.""" + + @abstractmethod + async def start(self) -> None: + """Initializes proxy configuration files.""" + pass + + @abstractmethod + async def get_all_stats(self) -> Dict[str, Dict[str, ServerStats]]: + """Get statistics for all servers in all backends.""" + pass + + @abstractmethod + async def stop(self) -> None: + """Stop the proxy.""" + pass + + @abstractmethod + async def disable(self) -> None: + """Disables the proxy from receiving any HTTP requests""" + pass + + @abstractmethod + async def enable(self) -> None: + """Enables the proxy from receiving any HTTP requests""" + pass + + @abstractmethod + async def reload(self) -> None: + """Gracefully reload the service.""" + pass + + +class HAProxyApi(ProxyApi): + """ProxyApi implementation for HAProxy.""" + + def __init__( + self, + cfg: HAProxyConfig, + backend_configs: Dict[str, BackendConfig] = None, + config_file_path: str = RAY_SERVE_HAPROXY_CONFIG_FILE_LOC, + ): + self.cfg = cfg + self.backend_configs = backend_configs or {} + self.config_file_path = config_file_path + # Lock to prevent concurrent config modifications + self._config_lock = asyncio.Lock() + self._proc = None + # Track old processes from graceful reloads that may still be draining + self._old_procs: List[asyncio.subprocess.Process] = [] + + # Ensure required directories exist during initialization + self._initialize_directories_and_error_files() + + def _initialize_directories_and_error_files(self) -> None: + """ + Ensures all required directories exist, creates a unified 500 error file, + and assigns its path to self.cfg.error_file_path. Called once during initialization. + """ + # Create a config file directory + config_dir = os.path.dirname(self.config_file_path) + os.makedirs(config_dir, exist_ok=True) + + # Create a socket directory + socket_dir = os.path.dirname(self.cfg.socket_path) + os.makedirs(socket_dir, exist_ok=True) + + # Create a server state directory only if optimization is enabled + if self.cfg.enable_hap_optimization: + server_state_dir = os.path.dirname(self.cfg.server_state_file) + os.makedirs(server_state_dir, exist_ok=True) + + # Create a single error file for both 502 and 504 errors + # Both will be normalized to 500 Internal Server Error + error_file_path = os.path.join(config_dir, "500.http") + with open(error_file_path, "w") as ef: + ef.write("HTTP/1.1 500 Internal Server Error\r\n") + ef.write("Content-Type: text/plain\r\n") + ef.write("Content-Length: 21\r\n") + ef.write("\r\n") + ef.write("Internal Server Error") + + self.cfg.error_file_path = error_file_path + + def _is_running(self) -> bool: + """Check if the HAProxy process is still running.""" + return self._proc is not None and self._proc.returncode is None + + async def _start_and_wait_for_haproxy( + self, *extra_args: str, timeout_s: int = 5 + ) -> asyncio.subprocess.Process: + # Build command args + args = ["haproxy", "-db", "-f", self.config_file_path] + + if not self.cfg.enable_so_reuseport: + args.append("-dR") + + # Add any extra args (like -sf for graceful reload) + args.extend(extra_args) + + logger.debug(f"Starting HAProxy with args: {args}") + + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + await self._wait_for_hap_availability(proc) + except Exception: + # If startup fails, ensure the process is killed to avoid orphaned processes + if proc.returncode is None: + proc.kill() + await proc.wait() + raise + + return proc + + async def _save_server_state(self) -> None: + """Save the server state to the file.""" + server_state = await self._send_socket_command("show servers state") + with open(self.cfg.server_state_file, "w") as f: + f.write(server_state) + + async def _graceful_reload(self) -> None: + """Perform a graceful reload of HAProxy by starting a new process with -sf.""" + try: + old_proc = self._proc + await self._wait_for_hap_availability(old_proc) + + # Save server state if optimization is enabled + if self.cfg.enable_hap_optimization: + await self._save_server_state() + + # Start new HAProxy process with -sf flag to gracefully take over from old process + # Use -x socket transfer for seamless reloads if optimization is enabled + reload_args = ["-sf", str(old_proc.pid)] + if self.cfg.enable_hap_optimization: + reload_args.extend(["-x", self.cfg.socket_path]) + + self._proc = await self._start_and_wait_for_haproxy(*reload_args) + + # Track old process so we can ensure it's cleaned up during shutdown + if old_proc is not None: + self._old_procs.append(old_proc) + + logger.info( + "Successfully performed graceful HAProxy reload with process restart." + ) + except Exception as e: + logger.error(f"HAProxy graceful reload failed: {e}") + raise + + async def _wait_for_hap_availability( + self, proc: asyncio.subprocess.Process, timeout_s: int = 5 + ) -> None: + start_time = time.time() + + # TODO: update this to use health checks + while time.time() - start_time < timeout_s: + if proc.returncode is not None: + stdout = await proc.stdout.read() if proc.stdout else b"" + stderr = await proc.stderr.read() if proc.stderr else b"" + output = ( + stderr.decode("utf-8", errors="ignore").strip() + or stdout.decode("utf-8", errors="ignore").strip() + ) + + raise RuntimeError( + f"HAProxy crashed during startup: {output or f'exit code {proc.returncode}'}" + ) + + if await self.is_running(): + return + + await asyncio.sleep(0.5) + + raise RuntimeError( + f"HAProxy did not enter running state within {timeout_s} seconds." + ) + + def _generate_config_file_internal(self) -> None: + """Internal config generation without locking (for use within locked sections).""" + try: + env = Environment() + + # Backends are sorted in decreasing order of length of path prefix + # to ensure that the longest path prefix match is taken first. + # Equal lengthed prefixes are then sorted alphabetically. + backends = sorted( + self.backend_configs.values(), + key=lambda be: (-len(be.path_prefix), be.path_prefix), + ) + + # Enrich backends with precomputed health check configuration strings + backends_with_health_config = [ + { + "backend": backend, + "health_config": backend.build_health_check_config(self.cfg), + } + for backend in backends + ] + + health_route_info = self.cfg.build_health_route_info(backends) + + # Render healthz rules separately for readability/reuse + healthz_template = env.from_string(HAPROXY_HEALTHZ_RULES_TEMPLATE) + healthz_rules = healthz_template.render( + { + "config": self.cfg, + "backends": backends, + "health_info": health_route_info, + } + ) + + config_template = env.from_string(HAPROXY_CONFIG_TEMPLATE) + config_content = config_template.render( + { + "config": self.cfg, + "backends": backends, + "backends_with_health_config": backends_with_health_config, + "healthz_rules": healthz_rules, + "route_info": health_route_info, + } + ) + + # Ensure the config ends with a newline + if not config_content.endswith("\n"): + config_content += "\n" + + # Use file locking to prevent concurrent writes from multiple processes + # This is important in test environments where multiple nodes may run + # on the same machine + with open(self.config_file_path, "w") as f: + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + try: + f.write(config_content) + finally: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + + logger.debug( + f"Succesfully generated HAProxy configuration: {self.config_file_path}." + ) + except Exception as e: + logger.error(f"Failed to create HAProxy configuration files: {e}") + raise + + async def start(self) -> None: + """ + Generate HAProxy configuration files and start the HAProxy server process. + + This method creates the necessary configuration files and launches the HAProxy + process in foreground mode, ensuring that the proxy is running with the latest + configuration and that the parent retains control of the subprocess handle. + """ + try: + async with self._config_lock: + # Set initial reload ID if header injection is enabled and ID is not set + if self.cfg.inject_process_id_header and self.cfg.reload_id is None: + self.cfg.reload_id = f"initial-{int(time.time() * 1000)}" + + self._generate_config_file_internal() + logger.info("Successfully generated HAProxy config file.") + + self._proc = await self._start_and_wait_for_haproxy() + logger.info("HAProxy started successfully.") + except Exception as e: + logger.error(f"Failed to initialize and start HAProxy configuration: {e}") + raise + + async def get_all_stats(self) -> Dict[str, Dict[str, ServerStats]]: + """Get statistics for all servers in all backends (implements abstract method). + + Returns only application backends configured in self.backend_configs, + excluding HAProxy internal components (frontends, default_backend, stats). + Also excludes BACKEND aggregate entries, returning only individual servers. + """ + try: + stats_output = await self._send_socket_command("show stat") + all_stats = self._parse_haproxy_csv_stats(stats_output) + + # Filter to only return application backends (ones in backend_configs) + # Exclude HAProxy internal components like frontends, default_backend, stats + # Also exclude BACKEND aggregate entries, keep only individual servers + return { + backend_name: { + server_name: stats + for server_name, stats in servers.items() + if server_name != "BACKEND" + } + for backend_name, servers in all_stats.items() + if backend_name in self.backend_configs + } + except Exception as e: + logger.error(f"Failed to get HAProxy stats: {e}") + return {} + + async def get_haproxy_stats(self) -> HAProxyStats: + """Get complete HAProxy statistics including both individual and aggregate data.""" + server_stats = await self.get_all_stats() + return HAProxyStats(backend_to_servers=server_stats) + + # TODO: use socket library instead of subprocess + async def _send_socket_command(self, command: str) -> str: + """Send a command to the HAProxy stats socket via subprocess.""" + try: + # Check if a socket file exists + if not os.path.exists(self.cfg.socket_path): + raise RuntimeError( + f"HAProxy socket file does not exist: {self.cfg.socket_path}." + ) + + proc = await asyncio.create_subprocess_exec( + "socat", + "-", + f"UNIX-CONNECT:{self.cfg.socket_path}", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(f"{command}\n".encode("utf-8")), timeout=5.0 + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise RuntimeError( + f"Timeout while sending command '{command}' to HAProxy socket" + ) + + if proc.returncode != 0: + err = stderr.decode("utf-8", errors="ignore").strip() + raise RuntimeError( + f"Command '{command}' failed with code {proc.returncode}: {err}" + ) + + result = stdout.decode("utf-8", errors="ignore") + logger.debug(f"Socket command '{command}' returned {len(result)} chars.") + return result + except Exception as e: + raise RuntimeError(f"Failed to send socket command '{command}': {e}") + + @staticmethod + def _parse_haproxy_csv_stats( + stats_output: str, + ) -> Dict[str, Dict[str, ServerStats]]: + """Parse HAProxy stats CSV output into structured data.""" + if not stats_output or not stats_output.strip(): + return {} + + # HAProxy stats start with '#' comment - replace with nothing for CSV parsing + csv_data = stats_output.replace("# ", "", 1) + backend_stats: Dict[str, Dict[str, ServerStats]] = {} + + def safe_int(v): + try: + return int(v) + except (TypeError, ValueError): + return 0 + + for row in csv.DictReader(io.StringIO(csv_data)): + backend = row.get("pxname", "").strip() + server = row.get("svname", "").strip() + status = row.get("status", "").strip() or "UNKNOWN" + + if not backend or not server: + continue + + backend_stats.setdefault(backend, {}) + backend_stats[backend][server] = ServerStats( + backend=backend, + server=server, + status=status, + current_sessions=safe_int(row.get("scur")), + queued=safe_int(row.get("qcur")), + ) + + return backend_stats + + async def stop(self) -> None: + proc = self._proc + if proc is None: + logger.info("HAProxy process not running, skipping shutdown.") + return + + try: + # Kill the current process + if proc.returncode is None: + proc.kill() + await proc.wait() + self._proc = None + + # Also kill any old processes from graceful reloads that might still be running + for old_proc in self._old_procs: + try: + if old_proc.returncode is None: + old_proc.kill() + await old_proc.wait() + logger.info(f"Killed old HAProxy process (PID: {old_proc.pid})") + except Exception as e: + logger.warning(f"Error killing old HAProxy process: {e}") + + self._old_procs.clear() + + logger.info("Stopped HAProxy process.") + except RuntimeError as e: + logger.error(f"Error during HAProxy shutdown: {e}") + + async def reload(self) -> None: + try: + self._generate_config_file_internal() + await self._graceful_reload() + except Exception as e: + raise RuntimeError(f"Failed to update and reload HAProxy: {e}") + + async def disable(self) -> None: + """Force haproxy health checks to fail.""" + try: + # Disable health checks (set to fail) + self.cfg.pass_health_checks = False + + # Regenerate the config file with the deny rule + self._generate_config_file_internal() + + # Perform a graceful reload to apply changes + await self._graceful_reload() + logger.info("Successfully disabled health checks.") + except Exception as e: + logger.error(f"Failed to disable health checks: {e}") + raise + + async def enable(self) -> None: + """Force haproxy health checks to pass.""" + try: + self.cfg.pass_health_checks = True + + self._generate_config_file_internal() + # Perform a graceful reload to apply changes + await self._graceful_reload() + logger.info("Successfully enabled health checks.") + except Exception as e: + logger.error(f"Failed to disable health checks: {e}") + raise + + def set_backend_configs( + self, + backend_configs: Dict[str, BackendConfig], + ) -> None: + if backend_configs: + self.cfg.has_received_routes = True + + self.backend_configs = backend_configs + + self.cfg.has_received_servers = self.cfg.has_received_servers or any( + len(bc.servers) > 0 for bc in backend_configs.values() + ) + + async def is_running(self) -> bool: + try: + await self._send_socket_command("show info") + return True + except Exception: + # During reload or shutdown, socket can be temporarily unavailable. + # Treat as unhealthy instead of raising. + return False + + +@ray.remote(num_cpus=0) +class HAProxyManager(ProxyActorInterface): + def __init__( + self, + http_options: HTTPOptions, + grpc_options: gRPCOptions, + *, + node_id: NodeId, + node_ip_address: str, + logging_config: LoggingConfig, + long_poll_client: Optional[LongPollClient] = None, + ): # noqa: F821 + super().__init__( + node_id=node_id, + node_ip_address=node_ip_address, + logging_config=logging_config, + # HAProxyManager is not on the request path, so we can disable + # the buffer to ensure logs are immediately flushed. + log_buffer_size=1, + ) + + self._grpc_options = grpc_options + self._http_options = http_options + + # The time when the node starts to drain. + # The node is not draining if it's None. + self._draining_start_time: Optional[float] = None + + self.event_loop = get_or_create_event_loop() + + self._target_groups: List[TargetGroup] = [] + + # Lock to serialize HAProxy reloads and prevent concurrent reload operations + # which can cause race conditions with SO_REUSEPORT + self._reload_lock = asyncio.Lock() + + self.long_poll_client = long_poll_client or LongPollClient( + ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE), + { + LongPollNamespace.GLOBAL_LOGGING_CONFIG: self._update_logging_config, + LongPollNamespace.TARGET_GROUPS: self.update_target_groups, + }, + call_in_event_loop=self.event_loop, + ) + + startup_msg = f"HAProxy starting on node {self._node_id} (HTTP port: {self._http_options.port})." + logger.info(startup_msg) + logger.debug( + f"Configure HAProxyManager actor {ray.get_runtime_context().get_actor_id()} " + f"logger with logging config: {logging_config}" + ) + + self._haproxy = HAProxyApi(cfg=HAProxyConfig(http_options=http_options)) + self._haproxy_start_task = self.event_loop.create_task(self._haproxy.start()) + + async def shutdown(self) -> None: + """Shutdown the HAProxyManager and clean up the HAProxy process. + + This method should be called before the actor is killed to ensure + the HAProxy subprocess is properly terminated. + """ + try: + logger.info( + f"Shutting down HAProxyManager on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + + await self._haproxy.stop() + + logger.info( + f"Successfully stopped HAProxy process on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + except Exception as e: + raise RuntimeError(f"Error stopping HAProxy during shutdown: {e}") + + async def ready(self) -> str: + try: + # Wait for haproxy to start. Internally, this starts the process and + # waits for it to be running by querying the stats socket. + await self._haproxy_start_task + except Exception as e: + logger.exception("Failed to start HAProxy.") + raise e from None + + # Return proxy metadata used by the controller. + # NOTE(zcin): We need to convert the metadata to a json string because + # of cross-language scenarios. Java can't deserialize a Python tuple. + return json.dumps( + [ + ray.get_runtime_context().get_worker_id(), + get_component_logger_file_path(), + ] + ) + + async def serving(self, wait_for_applications_running: bool = True) -> None: + """Wait for the HAProxy process to be ready to serve requests.""" + if not wait_for_applications_running: + return + + ready_to_serve = False + while not ready_to_serve: + try: + all_backends = set() + ready_backends = set() + stats = await self._haproxy.get_all_stats() + for backend, servers in stats.items(): + # The backend name is suffixed with the protocol. We omit + # grpc backends for now since they aren't supported yet. + if backend.lower().startswith("grpc"): + continue + all_backends.add(backend) + for server in servers.values(): + if server.is_up: + ready_backends.add(backend) + ready_to_serve = all_backends == ready_backends + except Exception: + pass + if not ready_to_serve: + await asyncio.sleep(0.2) + + def _is_draining(self) -> bool: + """Whether is haproxy is in the draining status or not.""" + return self._draining_start_time is not None + + async def update_draining( + self, draining: bool, _after: Optional[Any] = None + ) -> None: + """Update the draining status of the proxy. + + This is called by the proxy state manager + to drain or un-drain the haproxy. + """ + + if draining and (not self._is_draining()): + logger.info( + f"Start to drain the HAProxy on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + # Use the reload lock to serialize with other HAProxy reload operations + async with self._reload_lock: + await self._haproxy.disable() + self._draining_start_time = time.time() + if (not draining) and self._is_draining(): + logger.info( + f"Stop draining the HAProxy on node {self._node_id}.", + extra={"log_to_stderr": False}, + ) + # Use the reload lock to serialize with other HAProxy reload operations + async with self._reload_lock: + await self._haproxy.enable() + self._draining_start_time = None + + async def is_drained(self, _after: Optional[Any] = None) -> bool: + """Check whether the haproxy is drained or not. + + An haproxy is drained if it has no ongoing requests + AND it has been draining for more than + `PROXY_MIN_DRAINING_PERIOD_S` seconds. + """ + if not self._is_draining(): + return False + + haproxy_stats = await self._haproxy.get_haproxy_stats() + return haproxy_stats.is_system_idle and ( + (time.time() - self._draining_start_time) > PROXY_MIN_DRAINING_PERIOD_S + ) + + async def check_health(self) -> bool: + # If haproxy is already shutdown, return False. + if not self._haproxy or not self._haproxy._proc: + return False + + logger.debug("Received health check.", extra={"log_to_stderr": False}) + return await self._haproxy.is_running() + + def pong(self) -> str: + pass + + async def receive_asgi_messages(self, request_metadata: RequestMetadata) -> bytes: + raise NotImplementedError("Receive is handled by the ingress replicas.") + + def _get_http_options(self) -> HTTPOptions: + return self._http_options + + def _get_logging_config(self) -> Optional[str]: + """Get the logging configuration (for testing purposes).""" + log_file_path = None + for handler in logger.handlers: + if isinstance(handler, logging.handlers.MemoryHandler): + log_file_path = handler.target.baseFilename + + return log_file_path + + def _targets_to_servers(self, targets: List[Target]) -> List[ServerConfig]: + """Convert a list of targets to a list of servers.""" + # The server name is derived from the replica's actor name, with the + # format `SERVE_REPLICA::##`, or the + # proxy's actor name, with the format `SERVE_PROXY_ACTOR-`. + # Special characters in the names are converted to comply with haproxy + # config's allowed characters, e.g. `#` -> `-`. + return [ + ServerConfig( + name=self.get_safe_name(target.name), + # Use localhost if target is on the same node as HAProxy + host="127.0.0.1" if target.ip == self._node_ip_address else target.ip, + port=target.port, + ) + for target in targets + ] + + def _target_group_to_backend(self, target_group: TargetGroup) -> BackendConfig: + """Convert a target group to a backend name.""" + servers = self._targets_to_servers(target_group.targets) + # The name is lowercased and formatted as -. Special + # characters in the name are converted to comply with haproxy config's + # allowed characters, e.g. `#` -> `-`. + return BackendConfig( + name=self.get_safe_name( + f"{target_group.protocol.value.lower()}-{target_group.app_name}" + ), + path_prefix=target_group.route_prefix, + servers=servers, + app_name=target_group.app_name, + ) + + async def _reload_haproxy(self) -> None: + # To avoid dropping updates from a long poll, we wait until HAProxy + # is up and running before attempting to generate config and reload. + # Use lock to serialize reloads and prevent race conditions with SO_REUSEPORT + async with self._reload_lock: + await self._haproxy_start_task + await self._haproxy.reload() + + def update_target_groups(self, target_groups: List[TargetGroup]) -> None: + self._target_groups = target_groups + + backend_configs = [ + self._target_group_to_backend(target_group) + for target_group in target_groups + ] + + logger.info( + f"Got updated backend configs: {backend_configs}.", + extra={"log_to_stderr": True}, + ) + + name_to_backend_configs = { + backend_config.name: backend_config for backend_config in backend_configs + } + + self._haproxy.set_backend_configs(name_to_backend_configs) + self.event_loop.create_task(self._reload_haproxy()) + + def get_target_groups(self) -> List[TargetGroup]: + """Get current target groups.""" + return self._target_groups + + @staticmethod + def get_safe_name(name: str) -> str: + """Get a safe label name for the haproxy config.""" + name = name.replace("#", "-").replace("/", ".") + # replace all remaining non-alphanumeric and non-{".", "_", "-"} with "_" + return re.sub(r"[^A-Za-z0-9._-]+", "_", name) + + def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]: + """Return the set of replica IDs for targets matching the given route. + + Args: + route: The route prefix to match against target groups. + + Returns: + Set of ReplicaID objects for targets in the matching target group. + """ + replica_ids = set() + + if self._target_groups is None: + return replica_ids + + for target_group in self._target_groups: + if target_group.route_prefix == route: + for target in target_group.targets: + # Target names are in the format "SERVE_REPLICA::##" + if ReplicaID.is_full_id_str(target.name): + replica_id = ReplicaID.from_full_id_str(target.name) + replica_ids.add(replica_id) + + return replica_ids + + def _dump_ingress_cache_for_testing(self, route: str) -> Set[ReplicaID]: + """Return replica IDs that are cached/ready for the given route (for testing). + + For HAProxy, all registered replicas are immediately ready for routing + (no warm-up cache like the internal router), so this returns the same + set as _dump_ingress_replicas_for_testing. + + Args: + route: The route prefix to match against target groups. + + Returns: + Set of ReplicaID objects for targets in the matching target group. + """ + return self._dump_ingress_replicas_for_testing(route) diff --git a/python/ray/serve/_private/haproxy_templates.py b/python/ray/serve/_private/haproxy_templates.py new file mode 100644 index 000000000000..fca93ae2388b --- /dev/null +++ b/python/ray/serve/_private/haproxy_templates.py @@ -0,0 +1,132 @@ +HAPROXY_HEALTHZ_RULES_TEMPLATE = """ # Health check endpoint + acl healthcheck path -i {{ config.health_check_endpoint }} + # Suppress logging for health checks + http-request set-log-level silent if healthcheck +{%- if not health_info.healthy %} + # Override: force health checks to fail (used by drain/disable) + http-request return status {{ health_info.status }} content-type text/plain string "{{ health_info.health_message }}" if healthcheck +{%- elif backends %} + # 200 if any backend has at least one server UP +{%- for backend in backends %} + acl backend_{{ backend.name or 'unknown' }}_server_up nbsrv({{ backend.name or 'unknown' }}) ge 1 +{%- endfor %} + # Any backend with a server UP passes the health check (OR logic) +{%- for backend in backends %} + http-request return status {{ health_info.status }} content-type text/plain string "{{ health_info.health_message }}" if healthcheck backend_{{ backend.name or 'unknown' }}_server_up +{%- endfor %} + http-request return status 503 content-type text/plain string "Service Unavailable" if healthcheck +{%- endif %} +""" + +HAPROXY_CONFIG_TEMPLATE = """global + # Log to the standard system log socket with debug level. + log /dev/log local0 debug + log 127.0.0.1:{{ config.syslog_port }} local0 debug + stats socket {{ config.socket_path }} mode 666 level admin expose-fd listeners + stats timeout 30s + maxconn {{ config.maxconn }} + nbthread {{ config.nbthread }} + {%- if config.enable_hap_optimization %} + server-state-base {{ config.server_state_base }} + server-state-file {{ config.server_state_file }} + {%- endif %} + {%- if config.hard_stop_after_s is not none %} + hard-stop-after {{ config.hard_stop_after_s }}s + {%- endif %} +defaults + mode http + option log-health-checks + {% if config.timeout_connect_s is not none %}timeout connect {{ config.timeout_connect_s }}s{% endif %} + {% if config.timeout_client_s is not none %}timeout client {{ config.timeout_client_s }}s{% endif %} + {% if config.timeout_server_s is not none %}timeout server {{ config.timeout_server_s }}s{% endif %} + {% if config.timeout_http_request_s is not none %}timeout http-request {{ config.timeout_http_request_s }}s{% endif %} + {% if config.timeout_http_keep_alive_s is not none %}timeout http-keep-alive {{ config.timeout_http_keep_alive_s }}s{% endif %} + {% if config.timeout_queue_s is not none %}timeout queue {{ config.timeout_queue_s }}s{% endif %} + log global + option httplog + option abortonclose + {%- if config.enable_hap_optimization %} + option idle-close-on-response + {%- endif %} + # Normalize 502 and 504 errors to 500 per Serve's default behavior + {%- if config.error_file_path %} + errorfile 502 {{ config.error_file_path }} + errorfile 504 {{ config.error_file_path }} + {%- endif %} + {%- if config.enable_hap_optimization %} + load-server-state-from-file global + {%- endif %} +frontend prometheus + bind :{{ config.metrics_port }} + mode http + http-request use-service prometheus-exporter if { path {{ config.metrics_uri }} } + no log +frontend http_frontend + bind {{ config.frontend_host }}:{{ config.frontend_port }} +{{ healthz_rules|safe }} + # Routes endpoint + acl routes path -i /-/routes + http-request return status {{ route_info.status }} content-type {{ route_info.routes_content_type }} string "{{ route_info.routes_message }}" if routes + + {%- if config.inject_process_id_header and config.reload_id %} + # Inject unique reload ID as header to track which HAProxy instance handled the request (testing only) + http-request set-header x-haproxy-reload-id {{ config.reload_id }} + {%- endif %} + # Static routing based on path prefixes in decreasing length then alphabetical order +{%- for backend in backends %} + acl is_{{ backend.name or 'unknown' }} path_beg {{ '/' if not backend.path_prefix or backend.path_prefix == '/' else backend.path_prefix ~ '/' }} + acl is_{{ backend.name or 'unknown' }} path {{ backend.path_prefix or '/' }} + use_backend {{ backend.name or 'unknown' }} if is_{{ backend.name or 'unknown' }} +{%- endfor %} + default_backend default_backend +backend default_backend + http-request return status 404 content-type text/plain lf-string "Path \'%[path]\' not found. Ping http://.../-/routes for available routes." +{%- for item in backends_with_health_config %} +{%- set backend = item.backend %} +{%- set hc = item.health_config %} +backend {{ backend.name or 'unknown' }} + log global + balance leastconn + # Enable HTTP connection reuse for better performance + http-reuse always + # Set backend-specific timeouts, overriding defaults if specified + {%- if backend.timeout_connect_s is not none %} + timeout connect {{ backend.timeout_connect_s }}s + {%- endif %} + {%- if backend.timeout_server_s is not none %} + timeout server {{ backend.timeout_server_s }}s + {%- endif %} + {%- if backend.timeout_client_s is not none %} + timeout client {{ backend.timeout_client_s }}s + {%- endif %} + {%- if backend.timeout_http_request_s is not none %} + timeout http-request {{ backend.timeout_http_request_s }}s + {%- endif %} + {%- if backend.timeout_queue_s is not none %} + timeout queue {{ backend.timeout_queue_s }}s + {%- endif %} + # Set timeouts to support keep-alive connections + {%- if backend.timeout_http_keep_alive_s is not none %} + timeout http-keep-alive {{ backend.timeout_http_keep_alive_s }}s + {%- endif %} + {%- if backend.timeout_tunnel_s is not none %} + timeout tunnel {{ backend.timeout_tunnel_s }}s + {%- endif %} + # Health check configuration - use backend-specific or global defaults + {%- if hc.health_path %} + # HTTP health check with custom path + option httpchk GET {{ hc.health_path }} + http-check expect status 200 + {%- endif %} + {{ hc.default_server_directive }} + # Servers in this backend + {%- for server in backend.servers %} + server {{ server.name }} {{ server.host }}:{{ server.port }} check + {%- endfor %} +{%- endfor %} +listen stats + bind *:{{ config.stats_port }} + stats enable + stats uri {{ config.stats_uri }} + stats refresh 1s +""" diff --git a/python/ray/serve/_private/long_poll.py b/python/ray/serve/_private/long_poll.py index 2efabe489205..e735fdaed657 100644 --- a/python/ray/serve/_private/long_poll.py +++ b/python/ray/serve/_private/long_poll.py @@ -48,6 +48,7 @@ def __repr__(self): ROUTE_TABLE = auto() GLOBAL_LOGGING_CONFIG = auto() DEPLOYMENT_CONFIG = auto() + TARGET_GROUPS = auto() @dataclass diff --git a/python/ray/serve/tests/BUILD.bazel b/python/ray/serve/tests/BUILD.bazel index 35d84f964c92..464b688c5caf 100644 --- a/python/ray/serve/tests/BUILD.bazel +++ b/python/ray/serve/tests/BUILD.bazel @@ -657,3 +657,28 @@ py_test_module_list( "//python/ray/serve:serve_lib", ], ) + +# HAProxy tests (require RAY_SERVE_ENABLE_HA_PROXY=1). +py_test_module_list( + size = "large", + env = { + "RAY_SERVE_ENABLE_HA_PROXY": "1", + "RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S": "0.01", + }, + files = [ + "test_haproxy.py", + "test_haproxy_api.py", + "test_metrics_haproxy.py", + ], + tags = [ + "exclusive", + "haproxy", + "no_windows", + "team:serve", + ], + deps = [ + ":common", + ":conftest", + "//python/ray/serve:serve_lib", + ], +) diff --git a/python/ray/serve/tests/test_haproxy.py b/python/ray/serve/tests/test_haproxy.py new file mode 100644 index 000000000000..9dd7759c181a --- /dev/null +++ b/python/ray/serve/tests/test_haproxy.py @@ -0,0 +1,916 @@ +import asyncio +import logging +import subprocess +import sys +import threading +import time +from tempfile import NamedTemporaryFile + +import httpx +import pytest +import requests + +import ray +from ray import serve +from ray._common.test_utils import ( + SignalActor, + wait_for_condition, +) +from ray.actor import ActorHandle +from ray.cluster_utils import Cluster +from ray.serve._private.constants import ( + DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, + RAY_SERVE_ENABLE_HA_PROXY, + SERVE_NAMESPACE, +) +from ray.serve._private.haproxy import HAProxyManager +from ray.serve._private.test_utils import get_application_url +from ray.serve.context import _get_global_client +from ray.serve.schema import ( + ProxyStatus, + ServeDeploySchema, + ServeInstanceDetails, +) +from ray.serve.tests.conftest import * # noqa +from ray.serve.tests.test_cli_2 import ping_endpoint +from ray.tests.conftest import call_ray_stop_only # noqa: F401 +from ray.util.state import list_actors + +logger = logging.getLogger(__name__) + +# Skip all tests in this module if the HAProxy feature flag is not enabled +pytestmark = pytest.mark.skipif( + not RAY_SERVE_ENABLE_HA_PROXY, + reason="RAY_SERVE_ENABLE_HA_PROXY not set.", +) + + +@pytest.fixture(autouse=True) +def clean_up_haproxy_processes(): + """Clean up haproxy processes before and after each test.""" + subprocess.run( + ["pkill", "-x", "haproxy"], capture_output=True, text=True, check=False + ) + yield + # After test: verify no haproxy processes are running + result = subprocess.run( + ["pgrep", "-x", "haproxy"], capture_output=True, text=True, check=False + ) + assert ( + result.returncode != 0 + ), f"HAProxy processes still running after test: {result.stdout.strip()}" + + +@pytest.fixture +def shutdown_ray(): + if ray.is_initialized(): + ray.shutdown() + yield + if ray.is_initialized(): + ray.shutdown() + + +def test_deploy_with_no_applications(ray_shutdown): + """Deploy an empty list of applications, serve should just be started.""" + ray.init(num_cpus=8) + serve.start(http_options=dict(port=8003)) + client = _get_global_client() + config = ServeDeploySchema.parse_obj({"applications": []}) + client.deploy_apps(config) + + def serve_running(): + ServeInstanceDetails.parse_obj( + ray.get(client._controller.get_serve_instance_details.remote()) + ) + actors = list_actors( + filters=[ + ("ray_namespace", "=", SERVE_NAMESPACE), + ("state", "=", "ALIVE"), + ] + ) + actor_names = [actor["class_name"] for actor in actors] + return "ServeController" in actor_names and "HAProxyManager" in actor_names + + wait_for_condition(serve_running) + client.shutdown() + + +def test_single_app_shutdown_actors(ray_shutdown): + """Tests serve.shutdown() works correctly in single-app case + + Ensures that after deploying a (nameless) app using serve.run(), serve.shutdown() + deletes all actors (controller, haproxy, all replicas) in the "serve" namespace. + """ + address = ray.init(num_cpus=8)["address"] + serve.start(http_options=dict(port=8003)) + + @serve.deployment + def f(): + pass + + serve.run(f.bind(), name="app") + + actor_names = { + "ServeController", + "HAProxyManager", + "ServeReplica:app:f", + } + + def check_alive(): + actors = list_actors( + address=address, + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return {actor["class_name"] for actor in actors} == actor_names + + def check_dead(): + actors = list_actors( + address=address, + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return len(actors) == 0 + + wait_for_condition(check_alive) + serve.shutdown() + wait_for_condition(check_dead) + + +@pytest.mark.asyncio +async def test_single_app_shutdown_actors_async(ray_shutdown): + """Tests serve.shutdown_async() works correctly in single-app case + + Ensures that after deploying a (nameless) app using serve.run(), serve.shutdown_async() + deletes all actors (controller, haproxy, all replicas) in the "serve" namespace. + """ + address = ray.init(num_cpus=8)["address"] + serve.start(http_options=dict(port=8003)) + + @serve.deployment + def f(): + pass + + serve.run(f.bind(), name="app") + + actor_names = { + "ServeController", + "HAProxyManager", + "ServeReplica:app:f", + } + + def check_alive(): + actors = list_actors( + address=address, + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return {actor["class_name"] for actor in actors} == actor_names + + def check_dead(): + actors = list_actors( + address=address, + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return len(actors) == 0 + + wait_for_condition(check_alive) + await serve.shutdown_async() + wait_for_condition(check_dead) + + +def test_haproxy_subprocess_killed_on_manager_shutdown(ray_shutdown): + """Test that the HAProxy subprocess is killed when the HAProxyManager actor is shutdown. + + This ensures proper cleanup of HAProxy processes when the manager is killed, + preventing orphaned HAProxy processes. + """ + + def get_haproxy_pids(): + """Get all haproxy process PIDs.""" + result = subprocess.run( + ["pgrep", "-x", "haproxy"], capture_output=True, text=True, timeout=2 + ) + if result.returncode == 0 and result.stdout.strip(): + return [int(pid) for pid in result.stdout.strip().split("\n")] + + return [] + + wait_for_condition( + lambda: len(get_haproxy_pids()) == 0, timeout=5, retry_interval_ms=100 + ) + + @serve.deployment + def hello(): + return "hello" + + serve.run(hello.bind()) + wait_for_condition( + lambda: len(get_haproxy_pids()) == 1, timeout=10, retry_interval_ms=100 + ) + + serve.shutdown() + + wait_for_condition( + lambda: len(get_haproxy_pids()) == 0, timeout=10, retry_interval_ms=100 + ) + + +# TODO(alexyang): Delete these tests and run test_proxy.py instead once HAProxy is fully supported. +class TestTimeoutKeepAliveConfig: + """Test setting keep_alive_timeout_s in config and env.""" + + def get_proxy_actor(self) -> ActorHandle: + [proxy_actor] = list_actors(filters=[("class_name", "=", "HAProxyManager")]) + return ray.get_actor(proxy_actor.name, namespace=SERVE_NAMESPACE) + + def test_default_keep_alive_timeout_s(self, ray_shutdown): + """Test when no keep_alive_timeout_s is set. + + When the keep_alive_timeout_s is not set, the uvicorn keep alive is 5. + """ + serve.start() + proxy_actor = self.get_proxy_actor() + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s + == DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S + ) + + def test_set_keep_alive_timeout_in_http_configs(self, ray_shutdown): + """Test when keep_alive_timeout_s is in http configs. + + When the keep_alive_timeout_s is set in http configs, the uvicorn keep alive + is set correctly. + """ + keep_alive_timeout_s = 222 + serve.start(http_options={"keep_alive_timeout_s": keep_alive_timeout_s}) + proxy_actor = self.get_proxy_actor() + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s + == keep_alive_timeout_s + ) + + @pytest.mark.parametrize( + "ray_instance", + [ + {"RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S": "333"}, + ], + indirect=True, + ) + def test_set_keep_alive_timeout_in_env(self, ray_instance, ray_shutdown): + """Test when keep_alive_timeout_s is in env. + + When the keep_alive_timeout_s is set in env, the uvicorn keep alive + is set correctly. + """ + serve.start() + proxy_actor = self.get_proxy_actor() + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s == 333 + ) + + @pytest.mark.parametrize( + "ray_instance", + [ + {"RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S": "333"}, + ], + indirect=True, + ) + def test_set_timeout_keep_alive_in_both_config_and_env( + self, ray_instance, ray_shutdown + ): + """Test when keep_alive_timeout_s is in both http configs and env. + + When the keep_alive_timeout_s is set in env, the uvicorn keep alive + is set to the one in env. + """ + keep_alive_timeout_s = 222 + serve.start(http_options={"keep_alive_timeout_s": keep_alive_timeout_s}) + proxy_actor = self.get_proxy_actor() + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s == 333 + ) + + +@pytest.mark.asyncio +async def test_drain_and_undrain_haproxy_manager( + monkeypatch, shutdown_ray, call_ray_stop_only # noqa: F811 +): + """Test the state transtion of the haproxy manager between + HEALTHY, DRAINING and DRAINED + """ + monkeypatch.setenv("RAY_SERVE_PROXY_MIN_DRAINING_PERIOD_S", "10") + monkeypatch.setenv("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") + + cluster = Cluster() + head_node = cluster.add_node(num_cpus=0) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.wait_for_nodes() + ray.init(address=head_node.address) + serve.start(http_options={"location": "EveryNode"}) + + signal_actor = SignalActor.remote() + + @serve.deployment + class HelloModel: + async def __call__(self): + await signal_actor.wait.remote() + return "hello" + + serve.run(HelloModel.options(num_replicas=2).bind()) + + # 3 proxies, 1 controller, 2 replicas, 1 signal actor + wait_for_condition(lambda: len(list_actors()) == 7) + assert len(ray.nodes()) == 3 + + client = _get_global_client() + serve_details = ServeInstanceDetails( + **ray.get(client._controller.get_serve_instance_details.remote()) + ) + proxy_actor_ids = {proxy.actor_id for _, proxy in serve_details.proxies.items()} + + assert len(proxy_actor_ids) == 3 + + # Start a long-running request in background to test draining behavior + request_result = [] + + def make_blocking_request(): + try: + response = httpx.get("http://localhost:8000/", timeout=5) + request_result.append(("success", response.status_code)) + except Exception as e: + request_result.append(("error", str(e))) + + request_thread = threading.Thread(target=make_blocking_request) + request_thread.start() + + wait_for_condition( + lambda: ray.get(signal_actor.cur_num_waiters.remote()) >= 1, timeout=10 + ) + + serve.run(HelloModel.options(num_replicas=1).bind()) + + # 1 proxy should be draining + + def check_proxy_status(proxy_status_to_count): + serve_details = ServeInstanceDetails( + **ray.get(client._controller.get_serve_instance_details.remote()) + ) + proxy_status_list = [proxy.status for _, proxy in serve_details.proxies.items()] + print("all proxies!!!", [proxy for _, proxy in serve_details.proxies.items()]) + current_status = { + status: proxy_status_list.count(status) for status in proxy_status_list + } + return current_status == proxy_status_to_count, current_status + + wait_for_condition( + condition_predictor=check_proxy_status, + proxy_status_to_count={ProxyStatus.HEALTHY: 2, ProxyStatus.DRAINING: 1}, + ) + + # should stay in draining status until the signal is sent + await asyncio.sleep(1) + + assert check_proxy_status( + proxy_status_to_count={ProxyStatus.HEALTHY: 2, ProxyStatus.DRAINING: 1} + ) + + serve.run(HelloModel.options(num_replicas=2).bind()) + # The proxy should return to healthy status + wait_for_condition( + condition_predictor=check_proxy_status, + proxy_status_to_count={ProxyStatus.HEALTHY: 3}, + ) + serve_details = ServeInstanceDetails( + **ray.get(client._controller.get_serve_instance_details.remote()) + ) + + assert { + proxy.actor_id for _, proxy in serve_details.proxies.items() + } == proxy_actor_ids + + serve.run(HelloModel.options(num_replicas=1).bind()) + await signal_actor.send.remote() + # 1 proxy should be draining and eventually be drained. + wait_for_condition( + condition_predictor=check_proxy_status, + timeout=40, + proxy_status_to_count={ProxyStatus.HEALTHY: 2}, + ) + + # Verify the long-running request completed successfully + request_thread.join(timeout=5) + + # Clean up serve. + serve.shutdown() + + +def test_haproxy_failure(ray_shutdown): + """Test HAProxyManager is successfully restarted after being killed.""" + ray.init(num_cpus=1) + serve.start() + + @serve.deployment(name="proxy_failure") + def function(_): + return "hello1" + + serve.run(function.bind()) + + def check_proxy_alive(): + actors = list_actors( + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return "HAProxyManager" in {actor["class_name"] for actor in actors} + + wait_for_condition(check_proxy_alive) + + [proxy_actor] = list_actors( + filters=[("class_name", "=", "HAProxyManager"), ("state", "=", "ALIVE")] + ) + proxy_actor_id = proxy_actor.actor_id + + proxy_actor = ray.get_actor(proxy_actor.name, namespace=SERVE_NAMESPACE) + ray.kill(proxy_actor, no_restart=False) + + def check_new_proxy(): + proxies = list_actors( + filters=[("class_name", "=", "HAProxyManager"), ("state", "=", "ALIVE")] + ) + return len(proxies) == 1 and proxies[0].actor_id != proxy_actor_id + + wait_for_condition(check_new_proxy, timeout=45) + serve.shutdown() + + +def test_haproxy_get_target_groups(shutdown_ray): + """Test that haproxy get_target_groups retrieves the correct target groups.""" + ray.init(num_cpus=4) + serve.start() + + @serve.deployment + def function(_): + return "hello1" + + # Deploy the application + serve.run( + function.options(num_replicas=1).bind(), name="test_app", route_prefix="/test" + ) + + def check_proxy_alive(): + actors = list_actors( + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return "HAProxyManager" in {actor["class_name"] for actor in actors} + + wait_for_condition(check_proxy_alive) + + [proxy_actor] = list_actors( + filters=[("class_name", "=", "HAProxyManager"), ("state", "=", "ALIVE")] + ) + proxy_actor = ray.get_actor(proxy_actor.name, namespace=SERVE_NAMESPACE) + + def has_n_targets(route_prefix: str, n: int): + target_groups = ray.get(proxy_actor.get_target_groups.remote()) + for tg in target_groups: + if tg.route_prefix == route_prefix and len(tg.targets) == n: + return True + return False + + wait_for_condition(has_n_targets, route_prefix="/test", n=1) + + serve.run( + function.options(num_replicas=2).bind(), name="test_app", route_prefix="/test2" + ) + wait_for_condition(has_n_targets, route_prefix="/test2", n=2) + + serve.shutdown() + + +@pytest.mark.asyncio +async def test_haproxy_update_target_groups(ray_shutdown): + """Test that the haproxy correctly updates the target groups.""" + ray.init(num_cpus=4) + serve.start(http_options={"host": "0.0.0.0"}) + + @serve.deployment + def function(_): + return "hello1" + + serve.run( + function.options(num_replicas=1).bind(), name="app1", route_prefix="/test" + ) + assert httpx.get("http://localhost:8000/test").text == "hello1" + assert httpx.get("http://localhost:8000/test2").status_code == 404 + + serve.run( + function.options(num_replicas=1).bind(), name="app2", route_prefix="/test2" + ) + assert httpx.get("http://localhost:8000/test").text == "hello1" + assert httpx.get("http://localhost:8000/test2").text == "hello1" + + serve.delete("app1") + assert httpx.get("http://localhost:8000/test").status_code == 404 + assert httpx.get("http://localhost:8000/test2").text == "hello1" + + serve.run( + function.options(num_replicas=1).bind(), name="app1", route_prefix="/test" + ) + assert httpx.get("http://localhost:8000/test").text == "hello1" + assert httpx.get("http://localhost:8000/test2").text == "hello1" + + serve.shutdown() + + +@pytest.mark.asyncio +async def test_haproxy_update_draining_health_checks(ray_shutdown): + """Test that the haproxy update_draining method updates the HAProxy health checks.""" + ray.init(num_cpus=4) + serve.start() + + signal_actor = SignalActor.remote() + + @serve.deployment + async def function(_): + await signal_actor.wait.remote() + return "hello1" + + serve.run(function.bind()) + + def check_proxy_alive(): + actors = list_actors( + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return "HAProxyManager" in {actor["class_name"] for actor in actors} + + wait_for_condition(check_proxy_alive) + + [proxy_actor] = list_actors( + filters=[("class_name", "=", "HAProxyManager"), ("state", "=", "ALIVE")] + ) + proxy_actor = ray.get_actor(proxy_actor.name, namespace=SERVE_NAMESPACE) + + assert httpx.get("http://localhost:8000/-/healthz").status_code == 200 + + await proxy_actor.update_draining.remote(draining=True) + wait_for_condition( + lambda: httpx.get("http://localhost:8000/-/healthz").status_code == 503 + ) + + await proxy_actor.update_draining.remote(draining=False) + wait_for_condition( + lambda: httpx.get("http://localhost:8000/-/healthz").status_code == 200 + ) + assert not await proxy_actor._is_draining.remote() + + serve.shutdown() + + +def test_haproxy_http_options(ray_shutdown): + """Test that the haproxy config file is generated correctly with http options.""" + ray.init(num_cpus=4) + serve.start( + http_options={ + "host": "0.0.0.0", + "port": 8001, + "keep_alive_timeout_s": 30, + }, + ) + + @serve.deployment + def function(_): + return "hello1" + + serve.run(function.bind(), name="test_app", route_prefix="/test") + url = get_application_url(app_name="test_app", use_localhost=False) + assert httpx.get(url).text == "hello1" + with pytest.raises(httpx.ConnectError): + _ = httpx.get(url.replace(":8001", ":8000")).status_code + + serve.shutdown() + + +def test_haproxy_metrics(ray_shutdown): + """Test that the haproxy metrics are exported correctly.""" + ray.init(num_cpus=4) + serve.start( + http_options={ + "host": "0.0.0.0", + }, + ) + + @serve.deployment + def function(_): + return "hello1" + + serve.run(function.bind()) + + assert httpx.get("http://localhost:8000/").text == "hello1" + + metrics_response = httpx.get("http://localhost:9101/metrics") + assert metrics_response.status_code == 200 + + http_backend_metrics = ( + 'haproxy_backend_http_responses_total{proxy="http-default",code="2xx"} 1' + ) + assert http_backend_metrics in metrics_response.text + + serve.shutdown() + + +def test_haproxy_safe_name(): + """Test that the safe name is generated correctly.""" + assert HAProxyManager.get_safe_name("HTTP-test_foo.bar") == "HTTP-test_foo.bar" + assert HAProxyManager.get_safe_name("HTTP:test") == "HTTP_test" + assert HAProxyManager.get_safe_name("HTTP:test/foo") == "HTTP_test.foo" + assert HAProxyManager.get_safe_name("replica#abc") == "replica-abc" + + +@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.") +def test_build_multi_app(ray_start_stop): + with NamedTemporaryFile(mode="w+", suffix=".yaml") as tmp: + print('Building nodes "TestApp1Node" and "TestApp2Node".') + # Build an app + subprocess.check_output( + [ + "serve", + "build", + "ray.serve.tests.test_cli_3.TestApp1Node", + "ray.serve.tests.test_cli_3.TestApp2Node", + "-o", + tmp.name, + ] + ) + print("Build succeeded! Deploying node.") + + subprocess.check_output(["serve", "deploy", tmp.name]) + print("Deploy succeeded!") + wait_for_condition( + lambda: ping_endpoint("app1") == "wonderful world", timeout=15 + ) + print("App 1 is live and reachable over HTTP.") + wait_for_condition( + lambda: ping_endpoint("app2") == "wonderful world", timeout=15 + ) + print("App 2 is live and reachable over HTTP.") + + print("Deleting applications.") + app_urls = [ + get_application_url("HTTP", app_name=app) for app in ["app1", "app2"] + ] + subprocess.check_output(["serve", "shutdown", "-y"]) + + def check_no_apps(): + for url in app_urls: + with pytest.raises(httpx.HTTPError): + _ = httpx.get(url).text + return True + + wait_for_condition(check_no_apps, timeout=15) + print("Delete succeeded! Node is no longer reachable over HTTP.") + + +def test_haproxy_manager_ready_with_application(ray_shutdown): + """Test that HAProxyManager.ready() succeeds when an application is deployed.""" + ray.init(num_cpus=4) + serve.start() + + @serve.deployment + def function(_): + return "hello" + + # Deploy application + serve.run(function.bind(), name="test_app", route_prefix="/test") + + # Get HAProxyManager actor + def check_proxy_alive(): + actors = list_actors( + filters=[("ray_namespace", "=", SERVE_NAMESPACE), ("state", "=", "ALIVE")], + ) + return "HAProxyManager" in {actor["class_name"] for actor in actors} + + wait_for_condition(check_proxy_alive) + + [proxy_actor] = list_actors( + filters=[("class_name", "=", "HAProxyManager"), ("state", "=", "ALIVE")] + ) + proxy_actor = ray.get_actor(proxy_actor.name, namespace=SERVE_NAMESPACE) + + # Call ready() - should succeed with active targets + ready_result = ray.get(proxy_actor.ready.remote()) + assert ready_result is not None + + wait_for_condition(lambda: httpx.get("http://localhost:8000/test").text == "hello") + + serve.shutdown() + + +def test_504_error_translated_to_500(ray_shutdown, monkeypatch): + """Test that HAProxy translates 504 Gateway Timeout errors to 500 Internal Server Error.""" + monkeypatch.setenv("RAY_SERVE_HAPROXY_TIMEOUT_SERVER_S", "2") + monkeypatch.setenv("RAY_SERVE_HAPROXY_TIMEOUT_CONNECT_S", "1") + + ray.init(num_cpus=8) + serve.start(http_options=dict(port=8003)) + + @serve.deployment + class TimeoutDeployment: + def __call__(self, request): + # Sleep for 3 seconds, longer than HAProxy's 2s timeout + # Use regular time.sleep (not async) to avoid event loop issues + time.sleep(3) + return "This should not be reached" + + serve.run(TimeoutDeployment.bind(), name="timeout_app", route_prefix="/test") + + url = get_application_url("HTTP", app_name="timeout_app") + + # HAProxy should timeout after 2s and return 504->500 + # Client timeout is 10s to ensure HAProxy times out first + response = requests.get(f"{url}/test", timeout=10) + + # Verify we got 500 (translated from 504), not 504 or 200 + assert ( + response.status_code == 500 + ), f"Expected 500 Internal Server Error (translated from 504), got {response.status_code}" + assert ( + "Internal Server Error" in response.text + ), f"Response should contain 'Internal Server Error' message, got: {response.text}" + + +def test_502_error_translated_to_500(ray_shutdown): + """Test that HAProxy translates 502 Bad Gateway errors to 500 Internal Server Error.""" + ray.init(num_cpus=8) + serve.start(http_options=dict(port=8003)) + + @serve.deployment + class BrokenDeployment: + def __call__(self, request): + # Always raise an exception to simulate backend failure + raise RuntimeError("Simulated backend failure for 502 error") + + serve.run( + BrokenDeployment.bind(), name="broken_app", route_prefix="/test", blocking=False + ) + url = get_application_url("HTTP", app_name="broken_app") + response = requests.get(f"{url}/test", timeout=5) + + assert ( + response.status_code == 500 + ), f"Expected 500 Internal Server Error, got {response.status_code}" + assert ( + "Internal Server Error" in response.text + ), "Response should contain 'Internal Server Error' message" + + +def test_haproxy_healthcheck_multiple_apps_and_backends(ray_shutdown): + """Health check behavior with 3 apps and 2 servers per backend. + + Expectations: + - With two servers per backend, healthz returns 200 (all backends have a primary UP). + - Disabling one primary in each backend keeps health at 200 (the other primary is UP). + - Disabling all servers in each backend results in healthz 503. + """ + ray.init(num_cpus=8) + serve.start() + + @serve.deployment + def f(_): + return "hello" + + # Helpers + SOCKET_PATH = "/tmp/haproxy-serve/admin.sock" + + def app_to_backend(app: str) -> str: + return f"http-{app}" + + def haproxy_show_stat() -> str: + result = subprocess.run( + f'echo "show stat" | socat - {SOCKET_PATH}', + shell=True, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to query HAProxy stats: {result.stderr}") + return result.stdout + + def list_primary_servers(backend_name: str) -> list: + lines = haproxy_show_stat().strip().split("\n") + servers = [] + for line in lines: + parts = line.split(",") + if len(parts) < 2: + continue + pxname, svname = parts[0], parts[1] + if pxname == backend_name and svname not in [ + "FRONTEND", + "BACKEND", + ]: + servers.append(svname) + return servers + + def set_server_state(backend: str, server: str, state: str) -> None: + subprocess.run( + f'echo "set server {backend}/{server} state {state}" | socat - {SOCKET_PATH}', + shell=True, + capture_output=True, + timeout=5, + ) + + def wait_health(expected: int, timeout: float = 15.0) -> None: + wait_for_condition( + lambda: httpx.get("http://localhost:8000/-/healthz").status_code + == expected, + timeout=timeout, + ) + + # Deploy 3 apps, each with 2 replicas (servers) so each backend has 2 servers + 1 backup + apps = [ + ("app_a", "/a"), + ("app_b", "/b"), + ("app_c", "/c"), + ] + for app_name, route in apps: + serve.run(f.options(num_replicas=2).bind(), name=app_name, route_prefix=route) + + # Wait for all endpoints to be reachable + for _, route in apps: + wait_for_condition( + lambda r=route: httpx.get(f"http://localhost:8000{r}").text == "hello" + ) + + # Wait until each backend shows 2 primary servers in HAProxy stats + backends = [app_to_backend(app) for app, _ in apps] + for be in backends: + wait_for_condition(lambda b=be: len(list_primary_servers(b)) >= 2, timeout=20) + + # Initially healthy + wait_health(200, timeout=20) + + # Disable one primary per backend, should remain healthy (one primary still UP) + disabled_servers = [] + for be in backends: + servers = list_primary_servers(be) + set_server_state(be, servers[0], "maint") + disabled_servers.append((be, servers[0])) + + wait_health(200, timeout=20) + + # Disable the remaining primary per backend, should become unhealthy (no servers UP) + disabled_all = [] + for be in backends: + servers = list_primary_servers(be) + # Disable any remaining primary (skip ones already disabled) + for sv in servers: + if (be, sv) not in disabled_servers: + set_server_state(be, sv, "maint") + disabled_all.append((be, sv)) + break + + wait_health(503, timeout=20) + + # Re-enable all servers and expect health back to 200 + for be, sv in disabled_servers + disabled_all: + set_server_state(be, sv, "ready") + wait_health(200, timeout=20) + + # Sanity: all apps still respond + for _, route in apps: + resp = httpx.get(f"http://localhost:8000{route}") + assert resp.status_code == 200 and resp.text == "hello" + + serve.shutdown() + + +def test_haproxy_empty_backends_for_scaled_down_apps(ray_shutdown): + """Test that HAProxy has no backend servers for deleted apps. + + Verifies that when RAY_SERVE_ENABLE_HA_PROXY is True and apps are + deleted, the HAProxy stats show the backend is removed or has no servers. + """ + ray.init(num_cpus=4) + serve.start() + + @serve.deployment + def hello(): + return "hello" + + # Deploy app with 1 replica + serve.run( + hello.options(num_replicas=1).bind(), name="test_app", route_prefix="/test" + ) + + r = httpx.get("http://localhost:8000/test") + assert r.status_code == 200 + assert r.text == "hello" + + # Delete the app - this should remove or empty the backend + serve.delete("test_app") + + r = httpx.get("http://localhost:8000/test") + assert r.status_code == 404 + + serve.shutdown() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_haproxy_api.py b/python/ray/serve/tests/test_haproxy_api.py new file mode 100644 index 000000000000..eb589f3c69a1 --- /dev/null +++ b/python/ray/serve/tests/test_haproxy_api.py @@ -0,0 +1,1610 @@ +import asyncio +import logging +import os +import subprocess +import sys +import tempfile +import threading +import time +from typing import Optional +from unittest import mock + +import pytest +import pytest_asyncio +import requests +import uvicorn +from fastapi import FastAPI, Request, Response + +from ray._common.test_utils import async_wait_for_condition, wait_for_condition +from ray.serve._private.constants import ( + RAY_SERVE_ENABLE_HA_PROXY, +) +from ray.serve._private.haproxy import ( + BackendConfig, + HAProxyApi, + HAProxyConfig, + ServerConfig, +) +from ray.serve.config import HTTPOptions + +logger = logging.getLogger(__name__) + +# Skip all tests in this module if the HAProxy feature flag is not enabled +pytestmark = pytest.mark.skipif( + not RAY_SERVE_ENABLE_HA_PROXY, + reason="RAY_SERVE_ENABLE_HA_PROXY not set.", +) + +EXCLUDED_ACL_NAMES = ("healthcheck", "routes") + + +def check_haproxy_ready(stats_port: int, timeout: int = 2) -> bool: + """Check if HAProxy is ready by verifying the stats endpoint is accessible.""" + try: + response = requests.get(f"http://127.0.0.1:{stats_port}/stats", timeout=timeout) + return response.status_code == 200 + except Exception: + return False + + +def create_test_backend_server(port: int): + """Create a test backend server with slow and fast endpoints using uvicorn.""" + app = FastAPI() + + @app.get("/-/healthz") + async def health_endpoint(): + return {"status": "OK"} + + @app.get("/slow") + async def slow_endpoint(): + await asyncio.sleep(3) # 3-second delay + return "Slow response completed" + + @app.get("/fast") + async def fast_endpoint(req: Request, res: Response): + res.headers["x-haproxy-reload-id"] = req.headers.get("x-haproxy-reload-id", "") + + return "Fast response" + + # Configure uvicorn server with 60s keep-alive timeout + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="error", # Reduce log noise + access_log=False, + timeout_keep_alive=60, # 60 seconds keep-alive timeout + ) + server = uvicorn.Server(config) + + # Run server in a separate thread + def run_server(): + asyncio.run(server.serve()) + + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + + # Wait for the server to start + def wait_for_server(): + r = requests.get(f"http://127.0.0.1:{port}/-/healthz") + assert r.status_code == 200 + return True + + wait_for_condition(wait_for_server) + return server, thread + + +def process_exists(pid: int) -> bool: + """Check if a process with the given PID exists.""" + try: + # Send signal 0 to check if process exists without actually sending a signal + os.kill(pid, 0) + return True + except (OSError, ProcessLookupError): + return False + + +def make_test_request( + url: str, + track_results: list = None, + signal_started: threading.Event = None, + timeout: int = 10, +): + """Unified function to make test requests with optional result tracking.""" + try: + if signal_started: + signal_started.set() # Signal that request has started + + start_time = time.time() + response = requests.get(url, timeout=timeout) + end_time = time.time() + + if track_results is not None: + track_results.append( + { + "status": response.status_code, + "duration": end_time - start_time, + "content": response.content, + } + ) + except Exception as ex: + if track_results is not None: + track_results.append({"error": str(ex)}) + + +@pytest.fixture(autouse=True) +def clean_up_haproxy_processes(): + """Clean up haproxy processes before and after each test.""" + + subprocess.run( + ["pkill", "-x", "haproxy"], capture_output=True, text=True, check=False + ) + yield + # After test: verify no haproxy processes are running + result = subprocess.run( + ["pgrep", "-x", "haproxy"], capture_output=True, text=True, check=False + ) + assert ( + result.returncode != 0 or not result.stdout.strip() + ), f"HAProxy processes still running after test: {result.stdout.strip()}" + + +@pytest_asyncio.fixture +async def haproxy_api_cleanup(): + registered_apis = [] + + def register(api: Optional[HAProxyApi]) -> None: + if api is not None: + registered_apis.append(api) + + yield register + + for api in registered_apis: + proc = getattr(api, "_proc", None) + if proc and proc.returncode is None: + try: + await api.stop() + except Exception as exc: # pragma: no cover - best effort cleanup + logger.warning(f"Failed to stop HAProxy API cleanly: {exc}") + try: + proc.kill() + await proc.wait() + except Exception as kill_exc: + logger.error( + f"Failed to kill HAProxy process {proc.pid}: {kill_exc}" + ) + elif proc and proc.returncode is not None: + continue + + +def test_generate_config_file_internal(haproxy_api_cleanup): + """Test that initialize writes the correct config_stub file content using the actual template.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + config_stub = HAProxyConfig( + socket_path=socket_path, + maxconn=1000, + nbthread=2, + timeout_connect_s=5, + timeout_client_s=30, + timeout_server_s=30, + timeout_http_request_s=10, + timeout_queue_s=1, + stats_port=8080, + stats_uri="/mystats", + health_check_fall=3, + health_check_rise=2, + health_check_inter="2s", + health_check_path="/health", + http_options=HTTPOptions( + host="0.0.0.0", + port=8000, + keep_alive_timeout_s=55, + ), + has_received_routes=True, + has_received_servers=True, + enable_hap_optimization=True, + ) + backend_config_stub = { + "api_backend": BackendConfig( + name="api_backend", + path_prefix="/api", + app_name="api_backend", + timeout_http_keep_alive_s=60, + timeout_tunnel_s=60, + health_check_path="/api/health", + health_check_fall=2, + health_check_rise=3, + health_check_inter="5s", + servers=[ + ServerConfig(name="api_server1", host="127.0.0.1", port=8001), + ServerConfig(name="api_server2", host="127.0.0.1", port=8002), + ], + ), + "web_backend": BackendConfig( + name="web_backend", + path_prefix="/web", + app_name="web_backend", + timeout_connect_s=3, + timeout_server_s=25, + timeout_http_keep_alive_s=45, + timeout_tunnel_s=45, + servers=[ + ServerConfig(name="web_server1", host="127.0.0.1", port=8003), + ] + # No health check overrides - should use global defaults + ), + } + + with mock.patch( + "ray.serve._private.constants.RAY_SERVE_HAPROXY_CONFIG_FILE_LOC", + config_file_path, + ): + + api = HAProxyApi( + cfg=config_stub, + backend_configs=backend_config_stub, + config_file_path=config_file_path, + ) + + try: + api._generate_config_file_internal() + + # Read and verify the generated file + with open(config_file_path, "r") as f: + actual_content = f.read() + + routes = '{\\"/api\\":\\"api_backend\\",\\"/web\\":\\"web_backend\\"}' + # Expected configuration stub (matching the actual template output) + expected_config = f""" +global + # Log to the standard system log socket with debug level. + log /dev/log local0 debug + log 127.0.0.1:514 local0 debug + stats socket {socket_path} mode 666 level admin expose-fd listeners + stats timeout 30s + maxconn 1000 + nbthread 2 + server-state-base /tmp/haproxy-serve + server-state-file /tmp/haproxy-serve/server-state + hard-stop-after 120s +defaults + mode http + option log-health-checks + timeout connect 5s + timeout client 30s + timeout server 30s + timeout http-request 10s + timeout http-keep-alive 55s + timeout queue 1s + log global + option httplog + option abortonclose + option idle-close-on-response + # Normalize 502 and 504 errors to 500 per Serve's default behavior + errorfile 502 {temp_dir}/500.http + errorfile 504 {temp_dir}/500.http + load-server-state-from-file global +frontend prometheus + bind :9101 + mode http + http-request use-service prometheus-exporter if {{ path /metrics }} + no log +frontend http_frontend + bind *:8000 + # Health check endpoint + acl healthcheck path -i /-/healthz + # Suppress logging for health checks + http-request set-log-level silent if healthcheck + # 200 if any backend has at least one server UP + acl backend_api_backend_server_up nbsrv(api_backend) ge 1 + acl backend_web_backend_server_up nbsrv(web_backend) ge 1 + # Any backend with a server UP passes the health check (OR logic) + http-request return status 200 content-type text/plain string "success" if healthcheck backend_api_backend_server_up + http-request return status 200 content-type text/plain string "success" if healthcheck backend_web_backend_server_up + http-request return status 503 content-type text/plain string "Service Unavailable" if healthcheck + # Routes endpoint + acl routes path -i /-/routes + http-request return status 200 content-type application/json string "{routes}" if routes + # Static routing based on path prefixes in decreasing length then alphabetical order + acl is_api_backend path_beg /api/ + acl is_api_backend path /api + use_backend api_backend if is_api_backend + acl is_web_backend path_beg /web/ + acl is_web_backend path /web + use_backend web_backend if is_web_backend + default_backend default_backend +backend default_backend + http-request return status 404 content-type text/plain lf-string "Path \'%[path]\' not found. Ping http://.../-/routes for available routes." +backend api_backend + log global + balance leastconn + # Enable HTTP connection reuse for better performance + http-reuse always + # Set backend-specific timeouts, overriding defaults if specified + # Set timeouts to support keep-alive connections + timeout http-keep-alive 60s + timeout tunnel 60s + # Health check configuration - use backend-specific or global defaults + # HTTP health check with custom path + option httpchk GET /api/health + http-check expect status 200 + default-server fastinter 250ms downinter 250ms fall 2 rise 3 inter 5s check + # Servers in this backend + server api_server1 127.0.0.1:8001 check + server api_server2 127.0.0.1:8002 check +backend web_backend + log global + balance leastconn + # Enable HTTP connection reuse for better performance + http-reuse always + # Set backend-specific timeouts, overriding defaults if specified + timeout connect 3s + timeout server 25s + # Set timeouts to support keep-alive connections + timeout http-keep-alive 45s + timeout tunnel 45s + # Health check configuration - use backend-specific or global defaults + # HTTP health check with custom path + option httpchk GET /-/healthz + http-check expect status 200 + default-server fastinter 250ms downinter 250ms fall 3 rise 2 inter 2s check + # Servers in this backend + server web_server1 127.0.0.1:8003 check +listen stats + bind *:8080 + stats enable + stats uri /mystats + stats refresh 1s +""" + + # Compare the entire configuration + assert actual_content.strip() == expected_config.strip() + finally: + # Clean up any temporary files created by initialize() + temp_files = ["haproxy.cfg", "routes.map"] + for temp_file in temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + except (FileNotFoundError, OSError): + pass # File already removed or doesn't exist + + +def test_generate_backends_in_order(haproxy_api_cleanup): + """Test that the backends are generated in the correct order.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + backend_config_stub = { + "foo": BackendConfig( + name="foo", + path_prefix="/foo", + app_name="foo", + ), + "foobar": BackendConfig( + name="foobar", + path_prefix="/foo/bar", + app_name="foobar", + ), + "bar": BackendConfig( + name="bar", + path_prefix="/bar", + app_name="bar", + ), + "default": BackendConfig( + name="default", + path_prefix="/", + app_name="default", + ), + } + + with mock.patch( + "ray.serve._private.constants.RAY_SERVE_HAPROXY_CONFIG_FILE_LOC", + config_file_path, + ): + api = HAProxyApi( + cfg=HAProxyConfig(), + config_file_path=config_file_path, + backend_configs=backend_config_stub, + ) + + api._generate_config_file_internal() + + # Read and verify the generated file + lines = [] + with open(config_file_path, "r") as f: + lines = f.readlines() + + acl_names = [] + path_begs = [] + paths = [] + backend_lines = [] + for line in lines: + line = line.strip() + if line.startswith("acl"): + acl_name = line.split(" ")[1] + if acl_name in EXCLUDED_ACL_NAMES: + continue + + acl_names.append(acl_name) + + # strip prefix/suffix added for acl checks + backend_name = ( + acl_name.lstrip("is_") + .replace("backend_", "") + .replace("_server_up", "") + ) + assert backend_name in backend_config_stub + + condition = line.split(" ")[-2] + if condition == "path_beg": + path_prefix = line.split(" ")[-1].rstrip("/") or "/" + assert backend_config_stub[backend_name].path_prefix == path_prefix + path_begs.append(path_prefix) + elif condition == "path": + path_prefix = line.split(" ")[-1] + assert backend_config_stub[backend_name].path_prefix == path_prefix + paths.append(path_prefix) + else: + # gt condition is used for health check, no need to check. + continue + if line.startswith("use_backend"): + acl_name = line.split(" ")[-1] + assert acl_name in acl_names + backend_lines.append(acl_name) + + expected_order = ["is_foobar", "is_bar", "is_foo", "is_default"] + assert backend_lines == expected_order + + +@pytest.mark.asyncio +async def test_graceful_reload(haproxy_api_cleanup): + """Test that graceful reload preserves long-running connections.""" + + with tempfile.TemporaryDirectory() as temp_dir: + # Setup ports + haproxy_port = 8000 + backend_port = 8404 + stats_port = 8405 + + # Create and start a backend server + backend_server, backend_thread = create_test_backend_server(backend_port) + + # Configure HAProxy + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=haproxy_port, + keep_alive_timeout_s=58, + ), + stats_port=stats_port, + inject_process_id_header=True, # Enable for testing graceful reload + reload_id=f"initial-{int(time.time() * 1000)}", # Set initial reload ID + socket_path=os.path.join(temp_dir, "admin.sock"), + ) + + backend_config = BackendConfig( + name="test_backend", + path_prefix="/", + app_name="test_app", + servers=[ServerConfig(name="backend", host="127.0.0.1", port=backend_port)], + timeout_http_keep_alive_s=58, + ) + + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + + api = HAProxyApi( + cfg=config, + backend_configs={"test_backend": backend_config}, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + try: + await api.start() + + # Wait for HAProxy to be ready (check stat endpoint) + def check_stats_ready(): + try: + response = requests.get( + f"http://127.0.0.1:{config.stats_port}/stats", timeout=2 + ) + return response.status_code == 200 + except Exception: + return False + + wait_for_condition(check_stats_ready, timeout=10, retry_interval_ms=100) + + # Track slow request results + slow_results = [] + request_started = threading.Event() + + slow_thread = threading.Thread( + target=make_test_request, + args=[f"http://127.0.0.1:{haproxy_port}/slow"], + kwargs={ + "track_results": slow_results, + "signal_started": request_started, + }, + ) + + slow_thread.start() + wait_for_condition( + lambda: request_started.is_set(), timeout=5, retry_interval_ms=10 + ) + + assert api._proc is not None + original_pid = api._proc.pid + + await api._graceful_reload() + + assert api._proc is not None + new_pid = api._proc.pid + + def check_for_new_reload_id(): + fast_response = requests.get( + f"http://127.0.0.1:{haproxy_port}/fast", timeout=5 + ) + + # Reload ID should always match what exists in the config. + return ( + fast_response.headers.get("x-haproxy-reload-id") + == api.cfg.reload_id + and fast_response.status_code == 200 + ) + + wait_for_condition( + check_for_new_reload_id, timeout=5, retry_interval_ms=100 + ) + + slow_thread.join(timeout=10) + + assert ( + original_pid != new_pid + ), "Process should have been reloaded with new PID" + + wait_for_condition( + lambda: not process_exists(original_pid), + timeout=15, + retry_interval_ms=100, + ) + + assert len(slow_results) == 1, "Slow request should have completed" + + result = slow_results[0] + assert "error" not in result, f"Slow request failed: {result.get('error')}" + assert result["status"] == 200, "Slow request should have succeeded" + assert result["duration"] >= 3.0, "Slow request should have taken full time" + assert ( + b"Slow response completed" in result["content"] + ), "Slow request should have completed" + + finally: + # Backend server cleanup + try: + backend_server.should_exit = True + backend_thread.join(timeout=5) # Wait for thread to finish + except Exception as e: + print(f"Error occurred while shutting down server stub. Error: {e}") + + +@pytest.mark.asyncio +async def test_start(haproxy_api_cleanup): + """Test HAProxy start functionality.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + # Create HAProxy config + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + keep_alive_timeout_s=58, + ), + stats_port=8404, + pass_health_checks=True, + socket_path=socket_path, + has_received_routes=True, + has_received_servers=True, + ) + + # Add a backend so routes are populated + backend = BackendConfig( + name="test_backend", + path_prefix="/", + app_name="test_app", + servers=[ServerConfig(name="server", host="127.0.0.1", port=9999)], + ) + + api = HAProxyApi( + cfg=config, + backend_configs={"test_backend": backend}, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + await api.start() + + assert api._proc is not None, "HAProxy process should exist" + assert api._is_running(), "HAProxy should be running" + + # Verify config file contains expected content + with open(config_file_path, "r") as f: + config_content = f.read() + assert "frontend http_frontend" in config_content + assert f"bind 127.0.0.1:{config.frontend_port}" in config_content + assert "acl healthcheck path -i /-/healthz" in config_content + + health_response = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/healthz", timeout=5 + ) + assert ( + health_response.status_code == 503 + ), "Health check with no servers up should return 503" + + await api.stop() + assert api._proc is None + assert not api._is_running() + + +@pytest.mark.asyncio +async def test_stop(haproxy_api_cleanup): + """Test HAProxy stop functionality.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=os.path.join(temp_dir, "admin.sock"), + ) + + api = HAProxyApi(cfg=config, config_file_path=config_file_path) + + haproxy_api_cleanup(api) + + # Start HAProxy + await api.start() + + haproxy_api_cleanup(api) + + await api.stop() + + # Verify it's stopped + assert not api._is_running(), "HAProxy should be stopped after shutdown" + + +@pytest.mark.asyncio +async def test_stop_kills_haproxy_process(haproxy_api_cleanup): + """Test that stop() properly kills the HAProxy subprocess.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=os.path.join(temp_dir, "admin.sock"), + ) + + api = HAProxyApi(cfg=config, config_file_path=config_file_path) + haproxy_api_cleanup(api) + + # Start HAProxy + await api.start() + assert api._proc is not None, "HAProxy process should exist after start" + + haproxy_pid = api._proc.pid + assert process_exists(haproxy_pid), "HAProxy process should be running" + + # Stop HAProxy + await api.stop() + + # Verify the process is killed + assert api._proc is None, "HAProxy proc should be None after stop" + + # Wait a bit for process cleanup + def haproxy_process_killed(): + return not process_exists(haproxy_pid) + + wait_for_condition( + haproxy_process_killed, + timeout=1, + retry_interval_ms=100, + ) + + +@pytest.mark.asyncio +async def test_get_stats_integration(haproxy_api_cleanup): + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + # Create test backend servers + backend_port1 = 9900 + backend_port2 = 9901 + backend_server1, backend_thread1 = create_test_backend_server(backend_port1) + backend_server2, backend_thread2 = create_test_backend_server(backend_port2) + + # Configure HAProxy with multiple backends + config = HAProxyConfig( + http_options=HTTPOptions( + port=8000, + keep_alive_timeout_s=58, + ), + socket_path=socket_path, + stats_port=8404, + ) + + backend_configs = { + "test_backend1": BackendConfig( + name="test_backend1", + path_prefix="/api", + app_name="test_app1", + servers=[ + ServerConfig(name="server1", host="127.0.0.1", port=backend_port1) + ], + timeout_http_keep_alive_s=58, + ), + "test_backend2": BackendConfig( + name="test_backend2", + path_prefix="/web", + app_name="test_app2", + servers=[ + ServerConfig(name="server2", host="127.0.0.1", port=backend_port2) + ], + timeout_http_keep_alive_s=58, + ), + } + + api = HAProxyApi( + cfg=config, + backend_configs=backend_configs, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + try: + # Start HAProxy + await api.start() + + # Wait for HAProxy to be ready + wait_for_condition( + lambda: check_haproxy_ready(config.stats_port), + timeout=10, + retry_interval_ms=500, + ) + + # Make some API calls to generate sessions and traffic + request_threads = [] + + for i in range(3): + thread = threading.Thread( + target=make_test_request, + args=[f"http://127.0.0.1:{config.frontend_port}/api/slow"], + ) + thread.start() + request_threads.append(thread) + + for i in range(3): + thread = threading.Thread( + target=make_test_request, + args=[f"http://127.0.0.1:{config.frontend_port}/web/slow"], + ) + thread.start() + request_threads.append(thread) + + # Get actual stats + async def two_servers_up(): + stats = await api.get_haproxy_stats() + return stats.active_servers == 2 + + await async_wait_for_condition( + two_servers_up, timeout=10, retry_interval_ms=200 + ) + + async def wait_for_running(): + return await api.is_running() + + await async_wait_for_condition( + wait_for_running, timeout=10, retry_interval_ms=200 + ) + + all_stats = await api.get_all_stats() + haproxy_stats = await api.get_haproxy_stats() + + # Assert against the expected stub with exact values + assert ( + len(all_stats) == 2 + ), f"Should have exactly 2 backends, got {len(all_stats)}" + assert ( + haproxy_stats.total_backends == 2 + ), f"Should have exactly 2 backends, got {haproxy_stats.total_backends}" + assert ( + haproxy_stats.total_servers == 2 + ), f"Should have exactly 2 servers, got {haproxy_stats.total_servers}" + assert ( + haproxy_stats.active_servers == 2 + ), f"Should have exactly 2 active servers, got {haproxy_stats.active_servers}" + + # Wait for request threads to complete + for thread in request_threads: + thread.join(timeout=1) + finally: + # Stop HAProxy + await api.stop() + + # Cleanup backend servers + try: + backend_server1.should_exit = True + backend_server2.should_exit = True + backend_thread1.join(timeout=5) # Wait for the thread to finish + backend_thread2.join(timeout=5) # Wait for the thread to finish + except Exception as e: + print(f"Error cleaning up backend servers: {e}") + + +@pytest.mark.asyncio +async def test_update_and_reload(haproxy_api_cleanup): + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + backend = BackendConfig( + name="backend", + path_prefix="/", + app_name="backend_app", + servers=[ServerConfig(name="server", host="127.0.0.1", port=9999)], + ) + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=socket_path, + ) + + api = HAProxyApi( + cfg=config, + backend_configs={backend.name: backend}, + config_file_path=config_file_path, + ) + + await api.start() + haproxy_api_cleanup(api) + + with open(config_file_path, "r") as f: + actual_content = f.read() + assert "backend_2" not in actual_content + + original_proc = api._proc + original_pid = original_proc.pid + + # Add another backend + backend2 = BackendConfig( + name="backend_2", + path_prefix="/", + app_name="backend_app_2", + servers=[ServerConfig(name="server", host="127.0.0.1", port=9999)], + ) + + api.set_backend_configs({backend.name: backend, backend2.name: backend2}) + await api.reload() + + assert api._proc is not None + assert api._proc.pid != original_pid + + with open(config_file_path, "r") as f: + actual_content = f.read() + assert "backend_2" in actual_content + + wait_for_condition( + lambda: not process_exists(original_pid), + timeout=5, + retry_interval_ms=100, + ) + + +@pytest.mark.asyncio +async def test_haproxy_start_should_throw_error_when_already_running( + haproxy_api_cleanup, +): + """Test that HAProxy throws an error when trying to start on an already-used port (SO_REUSEPORT disabled).""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=socket_path, + enable_so_reuseport=False, # Disable SO_REUSEPORT + ) + + api = HAProxyApi(cfg=config, config_file_path=config_file_path) + haproxy_api_cleanup(api) + + # Start HAProxy with SO_REUSEPORT disabled + await api.start() + + assert api._proc is not None, "HAProxy process should be running" + first_pid = api._proc.pid + + # Verify we can't start another instance on the same port (SO_REUSEPORT disabled) + config2 = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=config.frontend_port, # Same port + ), + stats_port=8404, + socket_path=os.path.join(temp_dir, "admin2.sock"), + enable_so_reuseport=False, # Disable SO_REUSEPORT + ) + + api2 = HAProxyApi( + cfg=config2, config_file_path=os.path.join(temp_dir, "haproxy2.cfg") + ) + + # This should fail because SO_REUSEPORT is disabled + with pytest.raises(RuntimeError, match="(Address already in use)"): + await api2.start() + + # Cleanup first instance + await api.stop() + assert not process_exists(first_pid), "HAProxy process should be stopped" + + +@pytest.mark.asyncio +async def test_toggle_health_checks(haproxy_api_cleanup): + """Test that disable()/enable() toggle HAProxy health checks end-to-end.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + backend = BackendConfig( + name="backend", + path_prefix="/", + app_name="backend_app", + servers=[ServerConfig(name="server", host="127.0.0.1", port=9999)], + ) + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=socket_path, + inject_process_id_header=True, + has_received_routes=True, + has_received_servers=True, + ) + + # Start a real backend server so HAProxy can mark the server UP + backend_server, backend_thread = create_test_backend_server(9999) + try: + api = HAProxyApi( + cfg=config, + backend_configs={backend.name: backend}, + config_file_path=config_file_path, + ) + + await api.start() + haproxy_api_cleanup(api) + + # Verify HAProxy is running + assert api._is_running(), "HAProxy should be running" + + # Health requires servers; wait until health passes + def health_ok(): + resp = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + return resp.status_code == 200 + + wait_for_condition(health_ok, timeout=10) + + # Verify a config file contains health check enabled + with open(api.config_file_path, "r") as f: + config_content = f.read() + assert ( + "http-request return status 200" in config_content + ), "Health checks should be enabled in config" + + # Disable health checks + await api.disable() + + # Verify HAProxy is still running after calling disable() + assert api._is_running(), "HAProxy should still be running after disable" + + # Config should now deny the health endpoint + with open(api.config_file_path, "r") as f: + config_content = f.read() + assert ( + "http-request return status 503" in config_content + ), "Health checks should be disabled in config" + + def health_check_condition(status_code: int): + # Test health check endpoint now fails + health_response = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + + return health_response.status_code == status_code + + wait_for_condition(health_check_condition, timeout=2, status_code=503) + + # Re-enable health checks + await api.enable() + + # Config should contain the 200 response again + with open(api.config_file_path, "r") as f: + config_content = f.read() + assert ( + "http-request return status 200" in config_content + ), "Health checks should be re-enabled in config" + + wait_for_condition(health_check_condition, timeout=5, status_code=200) + + finally: + backend_server.should_exit = True + backend_thread.join(timeout=5) + + +@pytest.mark.asyncio +async def test_health_endpoint_or_logic_multiple_backends(haproxy_api_cleanup): + """Test that the health endpoint returns 200 if ANY backend has at least one server UP (OR logic).""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + backend1_port = 9996 + backend2_port = 9997 + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=socket_path, + has_received_routes=True, + has_received_servers=True, + ) + + backend1 = BackendConfig( + name="backend1", + path_prefix="/api1", + servers=[ + ServerConfig(name="server1", host="127.0.0.1", port=backend1_port) + ], + health_check_fall=1, + health_check_rise=1, + health_check_inter="1s", + ) + + backend2 = BackendConfig( + name="backend2", + path_prefix="/api2", + servers=[ + ServerConfig(name="server2", host="127.0.0.1", port=backend2_port) + ], + health_check_fall=1, + health_check_rise=1, + health_check_inter="1s", + ) + + backend1_server, backend1_thread = create_test_backend_server(backend1_port) + backend2_server, backend2_thread = create_test_backend_server(backend2_port) + + try: + api = HAProxyApi( + cfg=config, + backend_configs={backend1.name: backend1, backend2.name: backend2}, + config_file_path=config_file_path, + ) + + await api.start() + haproxy_api_cleanup(api) + + # Wait for health check to pass (both servers are UP) + def health_ok(): + resp = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + return resp.status_code == 200 + + wait_for_condition(health_ok, timeout=10, retry_interval_ms=200) + + # Verify health check returns 200 when both servers are UP + health_response = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + assert ( + health_response.status_code == 200 + ), "Health check should return 200 when both servers are UP" + assert b"success" in health_response.content + + # Stop backend1 server + backend1_server.should_exit = True + backend1_thread.join(timeout=5) + + # Wait a bit for HAProxy to detect backend1 is down + await asyncio.sleep(2) + + # Verify health check STILL returns 200 (backend2 is still UP - OR logic) + health_response = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + assert ( + health_response.status_code == 200 + ), "Health check should return 200 when at least one backend (backend2) is UP (OR logic)" + assert b"success" in health_response.content + + # Stop backend2 server as well + backend2_server.should_exit = True + backend2_thread.join(timeout=5) + + # Wait for health check to fail (both servers are DOWN) + def health_fails(): + resp = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + return resp.status_code == 503 + + wait_for_condition(health_fails, timeout=10, retry_interval_ms=200) + + # Verify health check returns 503 when ALL servers are DOWN + health_response = requests.get( + f"http://127.0.0.1:{config.frontend_port}{config.health_check_endpoint}", + timeout=5, + ) + assert ( + health_response.status_code == 503 + ), "Health check should return 503 when all servers are DOWN" + assert b"Service Unavailable" in health_response.content + + await api.stop() + finally: + # Cleanup + try: + if not backend1_server.should_exit: + backend1_server.should_exit = True + backend1_thread.join(timeout=5) + except Exception: + pass + try: + if not backend2_server.should_exit: + backend2_server.should_exit = True + backend2_thread.join(timeout=5) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_errorfile_creation_and_config(haproxy_api_cleanup): + """Test that the errorfile is created and configured correctly for both 502 and 504.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + # Launch a simple backend server with /fast endpoint + backend_port = 9107 + backend_server, backend_thread = create_test_backend_server(backend_port) + + # Configure HAProxy with one backend under root ('/') so upstream sees '/fast' + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + keep_alive_timeout_s=58, + ), + stats_port=8404, + socket_path=socket_path, + ) + + api = HAProxyApi(cfg=config, config_file_path=config_file_path) + haproxy_api_cleanup(api) + + # Verify the error file was created during initialization + expected_error_file_path = os.path.join(temp_dir, "500.http") + assert os.path.exists( + expected_error_file_path + ), "Error file 500.http should be created" + assert ( + api.cfg.error_file_path == expected_error_file_path + ), "Error file path should be set in config" + + # Verify the error file content + with open(expected_error_file_path, "r") as ef: + error_content = ef.read() + assert ( + "HTTP/1.1 500 Internal Server Error" in error_content + ), "Error file should contain 500 status" + assert ( + "Content-Type: text/plain" in error_content + ), "Error file should contain content-type header" + assert ( + "Internal Server Error" in error_content + ), "Error file should contain error message" + + # Start HAProxy and verify config contains errorfile directives + await api.start() + + # Verify config file contains errorfile directives for both 502 and 504 pointing to the same file + with open(config_file_path, "r") as f: + config_content = f.read() + assert ( + f"errorfile 502 {expected_error_file_path}" in config_content + ), "HAProxy config should contain 502 errorfile directive" + assert ( + f"errorfile 504 {expected_error_file_path}" in config_content + ), "HAProxy config should contain 504 errorfile directive" + + await api.stop() + backend = BackendConfig( + name="app_backend", + path_prefix="/", + app_name="app", + servers=[ServerConfig(name="server1", host="127.0.0.1", port=backend_port)], + timeout_http_keep_alive_s=58, + ) + + api = HAProxyApi( + cfg=config, + backend_configs={backend.name: backend}, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + try: + await api.start() + + # Ensure HAProxy is up (stats endpoint reachable) + wait_for_condition( + lambda: check_haproxy_ready(config.stats_port), + timeout=10, + retry_interval_ms=100, + ) + + # Route exists -> expect 200 + r = requests.get("http://127.0.0.1:8000/fast", timeout=5) + assert r.status_code == 200 + + # Remove backend (no targets for /app) and reload + api.set_backend_configs({}) + await api.reload() + + # After removal, route should fall back to default backend -> 404 + def get_status(): + resp = requests.get("http://127.0.0.1:8000/fast", timeout=5) + return resp.status_code + + # Allow a brief window for reload to take effect + wait_for_condition( + lambda: get_status() == 404, timeout=5, retry_interval_ms=100 + ) + + finally: + try: + await api.stop() + except Exception: + pass + + try: + backend_server.should_exit = True + backend_thread.join(timeout=5) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_routes_endpoint_returns_backends_and_respects_health( + haproxy_api_cleanup, +): + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + # Start two backend servers; health endpoint exists at '/-/healthz'. + backend_port1 = 9910 + backend_port2 = 9911 + backend_server1, backend_thread1 = create_test_backend_server(backend_port1) + backend_server2, backend_thread2 = create_test_backend_server(backend_port2) + + # Configure HAProxy with two prefixed backends + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8013, + keep_alive_timeout_s=58, + ), + stats_port=8413, + socket_path=socket_path, + ) + + backend_api = BackendConfig( + name="api_backend", + path_prefix="/api", + app_name="api_app", + servers=[ + ServerConfig(name="server1", host="127.0.0.1", port=backend_port1) + ], + timeout_http_keep_alive_s=58, + ) + backend_web = BackendConfig( + name="web_backend", + path_prefix="/web", + app_name="web_app", + servers=[ + ServerConfig(name="server2", host="127.0.0.1", port=backend_port2) + ], + timeout_http_keep_alive_s=58, + ) + + api = HAProxyApi( + cfg=config, + backend_configs={ + backend_api.name: backend_api, + backend_web.name: backend_web, + }, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + try: + await api.start() + + # Wait for HAProxy to be ready + wait_for_condition( + lambda: check_haproxy_ready(config.stats_port), + timeout=10, + retry_interval_ms=100, + ) + + # Helper to get fresh routes response (avoids connection reuse) + def get_routes(): + with requests.Session() as session: + return session.get("http://127.0.0.1:8013/-/routes", timeout=1) + + # Initial state: no routes + r = requests.get("http://127.0.0.1:8013/-/routes", timeout=5) + assert r.status_code == 503 + assert r.headers.get("content-type", "").startswith("text/plain") + assert r.text == "Route table is not populated yet." + + # Set has_received_routes but not has_received_servers -> should show "No replicas available" + api.cfg.has_received_routes = True + api.cfg.has_received_servers = False + await api.reload() + get_routes().text == "No replicas are available yet.", + r = get_routes() + assert r.status_code == 503 + assert r.headers.get("content-type", "").startswith("text/plain") + + # Set both flags -> should show routes JSON + api.cfg.has_received_routes = True + api.cfg.has_received_servers = True + await api.reload() + + # Reload is not synchronous, so we need to wait for the config to be applied + def check_json_routes(): + r = get_routes() + return r.status_code == 200 and r.headers.get( + "content-type", "" + ).startswith("application/json") + + wait_for_condition(check_json_routes, timeout=5, retry_interval_ms=50) + r = get_routes() + data = r.json() + assert data == {"/api": "api_app", "/web": "web_app"} + + # Disable (simulate draining/unhealthy) -> wait for healthz to flip, then routes 503 + await api.disable() + + def health_is(code: int): + resp = requests.get("http://127.0.0.1:8013/-/healthz", timeout=5) + return resp.status_code == code + + wait_for_condition(health_is, timeout=5, retry_interval_ms=100, code=503) + r = requests.get("http://127.0.0.1:8013/-/routes", timeout=5) + assert r.status_code == 503 + assert r.headers.get("content-type", "").startswith("text/plain") + assert r.text == "This node is being drained." + + # Re-enable -> wait for healthz to flip back, then routes 200 + await api.enable() + wait_for_condition(health_is, timeout=5, retry_interval_ms=100, code=200) + r = requests.get("http://127.0.0.1:8013/-/routes", timeout=5) + assert r.status_code == 200 + + finally: + try: + await api.stop() + except Exception: + pass + + +@pytest.mark.asyncio +async def test_routes_endpoint_no_routes(haproxy_api_cleanup): + """When no backends are configured, /-/routes should return {} and respect health gating.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8014, + keep_alive_timeout_s=58, + ), + stats_port=8414, + socket_path=socket_path, + ) + + api = HAProxyApi( + cfg=config, + backend_configs={}, + config_file_path=config_file_path, + ) + + haproxy_api_cleanup(api) + + try: + await api.start() + + # Wait for HAProxy to be ready + wait_for_condition( + lambda: check_haproxy_ready(config.stats_port), + timeout=10, + retry_interval_ms=100, + ) + + # Healthy -> expect 200 and empty JSON + r = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/routes", timeout=5 + ) + assert r.status_code == 503 + assert r.headers.get("content-type", "").startswith("text/plain") + assert r.text == "Route table is not populated yet." + + # Disable -> wait for healthz to flip, then expect 503 with draining message + await api.disable() + + def health_is(code: int): + resp = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/healthz", timeout=5 + ) + return resp.status_code == code + + wait_for_condition(health_is, timeout=5, retry_interval_ms=100, code=503) + + # Wait for routes endpoint to also return draining message (graceful reload might take a moment) + def routes_is_draining(): + try: + resp = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/routes", timeout=5 + ) + return ( + resp.status_code == 503 + and resp.text == "This node is being drained." + ) + except Exception: + return False + + wait_for_condition(routes_is_draining, timeout=5, retry_interval_ms=100) + + r = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/routes", timeout=5 + ) + assert r.status_code == 503 + assert r.headers.get("content-type", "").startswith("text/plain") + assert r.text == "This node is being drained." + + # Re-enable -> wait for healthz back to 200, then routes 200 + await api.enable() + wait_for_condition(health_is, timeout=5, retry_interval_ms=100, code=503) + + def routes_is_healthy(): + try: + r = requests.get( + f"http://127.0.0.1:{config.frontend_port}/-/routes", timeout=5 + ) + return ( + r.status_code == 503 + and r.text == "Route table is not populated yet." + ) + except Exception: + return False + + wait_for_condition(routes_is_healthy, timeout=5, retry_interval_ms=100) + finally: + try: + await api.stop() + except Exception: + pass + + +@pytest.mark.asyncio +async def test_404_error_message(haproxy_api_cleanup): + """Test that HAProxy returns the correct 404 error message for non-existent paths.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file_path = os.path.join(temp_dir, "haproxy.cfg") + socket_path = os.path.join(temp_dir, "admin.sock") + + # Create a backend that serves /api + backend = BackendConfig( + name="api_backend", + path_prefix="/api", + servers=[], # No servers, but we're testing the 404 path anyway + ) + + config = HAProxyConfig( + http_options=HTTPOptions( + host="127.0.0.1", + port=8000, + ), + stats_port=8404, + socket_path=socket_path, + ) + + api = HAProxyApi( + cfg=config, + backend_configs={"api_backend": backend}, + config_file_path=config_file_path, + ) + + await api.start() + haproxy_api_cleanup(api) + + # Verify HAProxy is running + assert api._is_running(), "HAProxy should be running" + + # Wait for HAProxy to be ready + wait_for_condition( + lambda: check_haproxy_ready(config.stats_port), + timeout=10, + retry_interval_ms=500, + ) + + # Request a non-existent path and verify the error message + response = requests.get( + f"http://127.0.0.1:{config.frontend_port}/nonexistent", + timeout=5, + ) + + assert response.status_code == 404, "Should return 404 for non-existent path" + assert ( + "Path '/nonexistent' not found" in response.text + ), f"Error message should contain path. Got: {response.text}" + assert ( + "Ping http://.../-/routes for available routes" in response.text + ), f"Error message should contain routes hint. Got: {response.text}" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_metrics_haproxy.py b/python/ray/serve/tests/test_metrics_haproxy.py new file mode 100644 index 000000000000..409a75227402 --- /dev/null +++ b/python/ray/serve/tests/test_metrics_haproxy.py @@ -0,0 +1,958 @@ +""" +HAProxy metrics tests for Ray Serve. + +These tests verify that Ray Serve metrics work correctly when HAProxy is enabled +as a replacement for the default Serve HTTP proxy. + +Key differences from the default Serve proxy: +1. When HAProxy is enabled, RAY_SERVE_ENABLE_DIRECT_INGRESS is automatically set. +2. HTTP proxy metrics (serve_num_http_requests, etc.) are emitted from replicas when + they receive direct ingress requests from HAProxy. +3. 404 errors for non-existent routes are handled by HAProxy itself (not forwarded to + replicas), so these won't generate Serve metrics. Tests that need to verify 404 + metrics must deploy an application that returns 404s. +4. HAProxy has its own metrics exposed on a separate port (default 9101), but these + tests focus on Serve metrics exposed via the Ray metrics port (9999). +""" +import http +import json +import sys +from typing import Dict, Optional + +import httpx +import pytest +from fastapi import FastAPI +from starlette.requests import Request +from starlette.responses import PlainTextResponse + +import ray +from ray import serve +from ray._common.network_utils import parse_address +from ray._common.test_utils import SignalActor, wait_for_condition +from ray._common.utils import reset_ray_address +from ray._private.test_utils import ( + fetch_prometheus_metrics, +) +from ray.serve import HTTPOptions +from ray.serve._private.long_poll import LongPollHost, UpdatedObject +from ray.serve._private.test_utils import ( + get_application_url, +) +from ray.serve._private.utils import block_until_http_ready +from ray.serve.tests.conftest import TEST_METRICS_EXPORT_PORT +from ray.serve.tests.test_metrics import get_metric_dictionaries +from ray.util.state import list_actors + + +@pytest.fixture +def metrics_start_shutdown(request): + """Fixture provides a fresh Ray cluster to prevent metrics state sharing.""" + param = request.param if hasattr(request, "param") else None + request_timeout_s = param if param else None + ray.init( + _metrics_export_port=TEST_METRICS_EXPORT_PORT, + _system_config={ + "metrics_report_interval_ms": 100, + "task_retry_delay_ms": 50, + }, + ) + yield serve.start( + http_options=HTTPOptions( + host="0.0.0.0", + request_timeout_s=request_timeout_s, + ), + ) + serve.shutdown() + ray.shutdown() + reset_ray_address() + + +def extract_tags(line: str) -> Dict[str, str]: + """Extracts any tags from the metrics line.""" + + try: + tags_string = line.replace("{", "}").split("}")[1] + except IndexError: + # No tags were found in this line. + return {} + + detected_tags = {} + for tag_pair in tags_string.split(","): + sanitized_pair = tag_pair.replace('"', "") + tag, value = sanitized_pair.split("=") + detected_tags[tag] = value + + return detected_tags + + +def contains_tags(line: str, expected_tags: Optional[Dict[str, str]] = None) -> bool: + """Checks if the metrics line contains the expected tags. + + Does nothing if expected_tags is None. + """ + + if expected_tags is not None: + detected_tags = extract_tags(line) + + # Check if expected_tags is a subset of detected_tags + return expected_tags.items() <= detected_tags.items() + else: + return True + + +def get_metric_float( + metric: str, expected_tags: Optional[Dict[str, str]] = None +) -> float: + """Gets the float value of metric. + + If tags is specified, searched for metric with matching tags. + + Returns -1 if the metric isn't available. + """ + + metrics = httpx.get("http://127.0.0.1:9999").text + metric_value = -1 + for line in metrics.split("\n"): + if metric in line and contains_tags(line, expected_tags): + metric_value = line.split(" ")[-1] + return metric_value + + +def check_metric_float_eq( + metric: str, expected: float, expected_tags: Optional[Dict[str, str]] = None +) -> bool: + metric_value = get_metric_float(metric, expected_tags) + assert float(metric_value) == expected + return True + + +def check_sum_metric_eq( + metric_name: str, + expected: float, + tags: Optional[Dict[str, str]] = None, +) -> bool: + if tags is None: + tags = {} + + metrics = fetch_prometheus_metrics(["localhost:9999"]) + metrics = {k: v for k, v in metrics.items() if "ray_serve_" in k} + metric_samples = metrics.get(metric_name, None) + if metric_samples is None: + metric_sum = 0 + else: + metric_samples = [ + sample for sample in metric_samples if tags.items() <= sample.labels.items() + ] + metric_sum = sum(sample.value for sample in metric_samples) + + # Check the metrics sum to the expected number + assert float(metric_sum) == float(expected), ( + f"The following metrics don't sum to {expected}: " + f"{json.dumps(metric_samples, indent=4)}\n." + f"All metrics: {json.dumps(metrics, indent=4)}" + ) + + # # For debugging + if metric_samples: + print(f"The following sum to {expected} for '{metric_name}' and tags {tags}:") + for sample in metric_samples: + print(sample) + + return True + + +def test_serve_metrics_for_successful_connection(metrics_start_shutdown): + @serve.deployment(name="metrics") + async def f(request): + return "hello" + + app_name = "app1" + handle = serve.run(target=f.bind(), name=app_name) + + http_url = get_application_url(app_name=app_name) + # send 10 concurrent requests + ray.get([block_until_http_ready.remote(http_url) for _ in range(10)]) + [handle.remote(http_url) for _ in range(10)] + + def verify_metrics(do_assert=False): + try: + resp = httpx.get("http://127.0.0.1:9999").text + # Requests will fail if we are crashing the controller + except httpx.HTTPError: + return False + + # NOTE: These metrics should be documented at + # https://docs.ray.io/en/latest/serve/monitoring.html#metrics + # Any updates to here should be reflected there too. + expected_metrics = [ + # counter + "ray_serve_num_router_requests", + "ray_serve_num_http_requests", + "ray_serve_deployment_queued_queries", + "ray_serve_deployment_request_counter", + "ray_serve_deployment_replica_starts", + # histogram + "ray_serve_deployment_processing_latency_ms_bucket", + "ray_serve_deployment_processing_latency_ms_count", + "ray_serve_deployment_processing_latency_ms_sum", + "ray_serve_deployment_processing_latency_ms", + # gauge + "ray_serve_replica_processing_queries", + "ray_serve_deployment_replica_healthy", + # handle + "ray_serve_handle_request_counter", + ] + + for metric in expected_metrics: + # For the final error round + if do_assert: + assert metric in resp + # For the wait_for_condition + else: + if metric not in resp: + return False + return True + + try: + wait_for_condition(verify_metrics, retry_interval_ms=500) + except RuntimeError: + verify_metrics(do_assert=True) + + +def test_http_replica_gauge_metrics(metrics_start_shutdown): + """Test http replica gauge metrics""" + signal = SignalActor.remote() + + @serve.deployment(graceful_shutdown_timeout_s=0.0001) + class A: + async def __call__(self): + await signal.wait.remote() + + handle = serve.run(A.bind(), name="app1") + _ = handle.remote() + + processing_requests = get_metric_dictionaries( + "ray_serve_replica_processing_queries", timeout=5 + ) + assert len(processing_requests) == 1 + assert processing_requests[0]["deployment"] == "A" + assert processing_requests[0]["application"] == "app1" + print("ray_serve_replica_processing_queries exists.") + + def ensure_request_processing(): + resp = httpx.get("http://127.0.0.1:9999").text + resp = resp.split("\n") + for metrics in resp: + if "# HELP" in metrics or "# TYPE" in metrics: + continue + if "ray_serve_replica_processing_queries" in metrics: + assert "1.0" in metrics + return True + + wait_for_condition(ensure_request_processing, timeout=5) + + +def test_proxy_metrics_not_found(metrics_start_shutdown): + # NOTE: When using HAProxy, 404 errors for non-existent routes are handled + # by HAProxy itself (not forwarded to replicas), so we need to deploy an + # application and test 404s within that application's context. + # These metrics should be documented at + # https://docs.ray.io/en/latest/serve/monitoring.html#metrics + # Any updates here should be reflected there too. + expected_metrics = [ + "ray_serve_num_http_requests", + "ray_serve_num_http_error_requests_total", + "ray_serve_num_deployment_http_error_requests", + "ray_serve_http_request_latency_ms", + ] + + app = FastAPI() + + @serve.deployment(name="A") + @serve.ingress(app) + class A: + @app.get("/existing-path") # Only this path is defined + async def handler(self, request: Request): + return {"message": "success"} + + app_name = "app" + serve.run(A.bind(), name=app_name, route_prefix="/A") + + def verify_metrics(_expected_metrics, do_assert=False): + try: + resp = httpx.get("http://127.0.0.1:9999").text + # Requests will fail if we are crashing the controller + except httpx.HTTPError: + return False + for metric in _expected_metrics: + if do_assert: + assert metric in resp + if metric not in resp: + return False + return True + + # Trigger HTTP 404 error via the deployed application + httpx.get("http://127.0.0.1:8000/A/nonexistent") + httpx.get("http://127.0.0.1:8000/A/nonexistent") + + # Ensure all expected metrics are present. + try: + wait_for_condition( + verify_metrics, + retry_interval_ms=1000, + timeout=10, + expected_metrics=expected_metrics, + ) + except RuntimeError: + verify_metrics(expected_metrics, True) + + def verify_error_count(do_assert=False): + resp = httpx.get("http://127.0.0.1:9999").text + resp = resp.split("\n") + http_error_count = 0 + deployment_404_count = 0 + + for metrics in resp: + if "# HELP" in metrics or "# TYPE" in metrics: + continue + # Skip health check metrics + if "/-/healthz" in metrics: + continue + if ( + "ray_serve_num_http_error_requests_total" in metrics + and 'route="/A"' in metrics + ): + # Accumulate error counts from route "/A" + http_error_count += int(float(metrics.split(" ")[-1])) + elif ( + "ray_serve_num_deployment_http_error_requests_total" in metrics + and 'route="/A"' in metrics + and 'error_code="404"' in metrics + ): + # Count deployment 404 errors + deployment_404_count += int(float(metrics.split(" ")[-1])) + + # We expect 2 requests total, both should be 404 errors from the deployment + if do_assert: + assert ( + http_error_count == 2 + ), f"Expected at least 2 HTTP errors, got {http_error_count}" + assert ( + deployment_404_count == 2 + ), f"Expected 2 deployment 404 errors, got {deployment_404_count}" + + return http_error_count >= 2 and deployment_404_count == 2 + + # There is a latency in updating the counter + try: + wait_for_condition(verify_error_count, retry_interval_ms=1000, timeout=20) + except RuntimeError: + verify_error_count(do_assert=True) + + +def test_proxy_metrics_internal_error(metrics_start_shutdown): + # NOTE: When using HAProxy, we need the replica to stay alive to emit metrics. + # Instead of crashing the actor (which prevents metric emission), we return + # a 500 error explicitly. + # These metrics should be documented at + # https://docs.ray.io/en/latest/serve/monitoring.html#metrics + # Any updates here should be reflected there too. + expected_metrics = [ + "ray_serve_num_http_requests", + "ray_serve_num_http_error_requests_total", + "ray_serve_num_deployment_http_error_requests", + "ray_serve_http_request_latency_ms", + ] + + def verify_metrics(_expected_metrics, do_assert=False): + try: + resp = httpx.get("http://127.0.0.1:9999", timeout=None).text + # Requests will fail if we are crashing the controller + except httpx.HTTPError: + return False + for metric in _expected_metrics: + if do_assert: + assert metric in resp + if metric not in resp: + return False + return True + + @serve.deployment(name="A") + class A: + async def __init__(self): + pass + + async def __call__(self, request: Request): + # Return 500 Internal Server Error + return PlainTextResponse("Internal Server Error", status_code=500) + + app_name = "app" + serve.run(A.bind(), name=app_name, route_prefix="/") + + httpx.get("http://localhost:8000/", timeout=None) + httpx.get("http://localhost:8000/", timeout=None) + + # Ensure all expected metrics are present. + try: + wait_for_condition( + verify_metrics, + retry_interval_ms=1000, + timeout=10, + expected_metrics=expected_metrics, + ) + except RuntimeError: + verify_metrics(expected_metrics, True) + + def verify_error_count(do_assert=False): + resp = httpx.get("http://127.0.0.1:9999", timeout=None).text + resp = resp.split("\n") + for metrics in resp: + if "# HELP" in metrics or "# TYPE" in metrics: + continue + if "ray_serve_num_http_error_requests_total" in metrics: + # route "/" should have error count 2 (HTTP 500) + if do_assert: + assert "2.0" in metrics + if "2.0" not in metrics: + return False + elif "ray_serve_num_deployment_http_error_requests" in metrics: + # deployment A should have error count 2 (HTTP 500) + if do_assert: + assert 'deployment="A"' in metrics and "2.0" in metrics + if 'deployment="A"' not in metrics or "2.0" not in metrics: + return False + return True + + # There is a latency in updating the counter + try: + wait_for_condition(verify_error_count, retry_interval_ms=1000, timeout=10) + except RuntimeError: + verify_error_count(do_assert=True) + + +def test_proxy_metrics_fields_not_found(metrics_start_shutdown): + """Tests the proxy metrics' fields' behavior for not found. + + Note: When using HAProxy, we need to deploy an application that returns 404, + as HAProxy handles non-existent route 404s itself without forwarding to replicas. + """ + # These metrics should be documented at + # https://docs.ray.io/en/latest/serve/monitoring.html#metrics + # Any updates here should be reflected there too. + expected_metrics = [ + "ray_serve_num_http_requests", + "ray_serve_num_http_error_requests_total", + "ray_serve_num_deployment_http_error_requests", + "ray_serve_http_request_latency_ms", + ] + + app = FastAPI() + + @serve.deployment(name="test_app") + @serve.ingress(app) + class NotFoundApp: + @app.get("/existing-path") # Only this path is defined + async def handler(self, request: Request): + return {"message": "success"} + + app_name = "app" + serve.run(NotFoundApp.bind(), name=app_name, route_prefix="/test") + + def verify_metrics(_expected_metrics, do_assert=False): + try: + resp = httpx.get("http://127.0.0.1:9999").text + # Requests will fail if we are crashing the controller + except httpx.HTTPError: + return False + for metric in _expected_metrics: + if do_assert: + assert metric in resp + if metric not in resp: + return False + return True + + # Trigger HTTP 404 error via the deployed application + httpx.get("http://127.0.0.1:8000/test/nonexistent") + httpx.get("http://127.0.0.1:8000/test/nonexistent") + + # Ensure all expected metrics are present. + try: + wait_for_condition( + verify_metrics, + retry_interval_ms=1000, + timeout=10, + expected_metrics=expected_metrics, + ) + except RuntimeError: + verify_metrics(expected_metrics, True) + + def verify_error_count(do_assert=False): + resp = httpx.get("http://127.0.0.1:9999").text + resp = resp.split("\n") + http_error_count = 0 + deployment_404_count = 0 + + for metrics in resp: + if "# HELP" in metrics or "# TYPE" in metrics: + continue + # Skip health check metrics + if "/-/healthz" in metrics: + continue + if ( + "ray_serve_num_http_error_requests_total" in metrics + and 'route="/test"' in metrics + ): + # Accumulate error counts from route "/test" + http_error_count += int(float(metrics.split(" ")[-1])) + elif ( + "ray_serve_num_deployment_http_error_requests_total" in metrics + and 'route="/test"' in metrics + and 'error_code="404"' in metrics + ): + # Count deployment 404 errors + deployment_404_count += int(float(metrics.split(" ")[-1])) + + # We expect 2 requests total, both should be 404 errors from the deployment + if do_assert: + assert ( + http_error_count == 2 + ), f"Expected at least 2 HTTP errors, got {http_error_count}" + assert ( + deployment_404_count == 2 + ), f"Expected 2 deployment 404 errors, got {deployment_404_count}" + + return http_error_count >= 2 and deployment_404_count == 2 + + # There is a latency in updating the counter + try: + wait_for_condition(verify_error_count, retry_interval_ms=1000, timeout=20) + except RuntimeError: + verify_error_count(do_assert=True) + + +@pytest.mark.parametrize( + "metrics_start_shutdown", + [ + 1, + ], + indirect=True, +) +def test_proxy_timeout_metrics(metrics_start_shutdown): + """Test that HTTP timeout metrics are reported correctly.""" + signal = SignalActor.remote() + + @serve.deployment + async def return_status_code_with_timeout(request: Request): + await signal.wait.remote() + return + + serve.run( + return_status_code_with_timeout.bind(), + route_prefix="/status_code_timeout", + name="status_code_timeout", + ) + + http_url = get_application_url("HTTP", app_name="status_code_timeout") + + r = httpx.get(http_url) + assert r.status_code == 408 + ray.get(signal.send.remote(clear=True)) + + num_errors = get_metric_dictionaries("ray_serve_num_http_error_requests_total") + assert len(num_errors) == 1 + assert num_errors[0]["route"] == "/status_code_timeout" + assert num_errors[0]["error_code"] == "408" + assert num_errors[0]["method"] == "GET" + assert num_errors[0]["application"] == "status_code_timeout" + + +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows") +def test_proxy_disconnect_http_metrics(metrics_start_shutdown): + """Test that HTTP disconnect metrics are reported correctly.""" + + signal = SignalActor.remote() + + @serve.deployment + class Disconnect: + async def __call__(self, request: Request): + await signal.wait.remote() + return + + serve.run( + Disconnect.bind(), + route_prefix="/disconnect", + name="disconnect", + ) + + # Simulate an HTTP disconnect + http_url = get_application_url("HTTP", app_name="disconnect") + ip_port = http_url.replace("http://", "").split("/")[0] # remove the route prefix + ip, port = parse_address(ip_port) + conn = http.client.HTTPConnection(ip, int(port)) + conn.request("GET", "/disconnect") + wait_for_condition( + lambda: ray.get(signal.cur_num_waiters.remote()) == 1, timeout=10 + ) + conn.close() # Forcefully close the connection + ray.get(signal.send.remote(clear=True)) + + num_errors = get_metric_dictionaries("ray_serve_num_http_error_requests_total") + assert len(num_errors) == 1 + assert num_errors[0]["route"] == "/disconnect" + assert num_errors[0]["error_code"] == "499" + assert num_errors[0]["method"] == "GET" + assert num_errors[0]["application"] == "disconnect" + + +def test_proxy_metrics_fields_internal_error(metrics_start_shutdown): + """Tests the proxy metrics' fields' behavior for internal error.""" + + @serve.deployment() + def f(*args): + return 1 / 0 + + real_app_name = "app" + real_app_name2 = "app2" + serve.run(f.bind(), name=real_app_name, route_prefix="/real_route") + serve.run(f.bind(), name=real_app_name2, route_prefix="/real_route2") + + # Deployment should generate divide-by-zero errors + correct_url = get_application_url("HTTP", real_app_name) + _ = httpx.get(correct_url).text + print("Sent requests to correct URL.") + + num_deployment_errors = get_metric_dictionaries( + "ray_serve_num_deployment_http_error_requests_total" + ) + assert len(num_deployment_errors) == 1 + assert num_deployment_errors[0]["deployment"] == "f" + assert num_deployment_errors[0]["error_code"] == "500" + assert num_deployment_errors[0]["method"] == "GET" + assert num_deployment_errors[0]["application"] == "app" + print("ray_serve_num_deployment_http_error_requests working as expected.") + + latency_metrics = get_metric_dictionaries("ray_serve_http_request_latency_ms_sum") + # Filter out health check metrics - HAProxy generates health checks to /-/healthz + latency_metrics = [m for m in latency_metrics if m["route"] != "/-/healthz"] + assert len(latency_metrics) == 1 + assert latency_metrics[0]["method"] == "GET" + assert latency_metrics[0]["route"] == "/real_route" + assert latency_metrics[0]["application"] == "app" + assert latency_metrics[0]["status_code"] == "500" + print("ray_serve_http_request_latency_ms working as expected.") + + +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows") +def test_proxy_metrics_http_status_code_is_error(metrics_start_shutdown): + """Verify that 2xx and 3xx status codes aren't errors, others are.""" + + def check_request_count_metrics( + expected_error_count: int, + expected_success_count: int, + ): + resp = httpx.get("http://127.0.0.1:9999").text + error_count = 0 + success_count = 0 + for line in resp.split("\n"): + # Skip health check metrics + if "/-/healthz" in line: + continue + if line.startswith("ray_serve_num_http_error_requests_total"): + error_count += int(float(line.split(" ")[-1])) + if line.startswith("ray_serve_num_http_requests_total"): + success_count += int(float(line.split(" ")[-1])) + + assert error_count == expected_error_count + assert success_count == expected_success_count + return True + + @serve.deployment + async def return_status_code(request: Request): + code = int((await request.body()).decode("utf-8")) + return PlainTextResponse("", status_code=code) + + serve.run(return_status_code.bind()) + + http_url = get_application_url("HTTP") + + # 200 is not an error. + r = httpx.request("GET", http_url, content=b"200") + assert r.status_code == 200 + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=1, + ) + + # 2xx is not an error. + r = httpx.request("GET", http_url, content=b"250") + assert r.status_code == 250 + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=2, + ) + + # 3xx is not an error. + r = httpx.request("GET", http_url, content=b"300") + assert r.status_code == 300 + wait_for_condition( + check_request_count_metrics, + expected_error_count=0, + expected_success_count=3, + ) + + # 4xx is an error. + r = httpx.request("GET", http_url, content=b"400") + assert r.status_code == 400 + wait_for_condition( + check_request_count_metrics, + expected_error_count=1, + expected_success_count=4, + ) + + # 5xx is an error. + r = httpx.request("GET", http_url, content=b"500") + assert r.status_code == 500 + wait_for_condition( + check_request_count_metrics, + expected_error_count=2, + expected_success_count=5, + ) + + +def test_replica_metrics_fields(metrics_start_shutdown): + """Test replica metrics fields""" + + @serve.deployment + def f(): + return "hello" + + @serve.deployment + def g(): + return "world" + + serve.run(f.bind(), name="app1", route_prefix="/f") + serve.run(g.bind(), name="app2", route_prefix="/g") + url_f = get_application_url("HTTP", "app1") + url_g = get_application_url("HTTP", "app2") + + assert "hello" == httpx.post(url_f).text + assert "world" == httpx.post(url_g).text + + wait_for_condition( + lambda: len( + get_metric_dictionaries("ray_serve_deployment_request_counter_total") + ) + == 2, + timeout=40, + ) + + metrics = get_metric_dictionaries("ray_serve_deployment_request_counter_total") + assert len(metrics) == 2 + expected_output = { + ("/f", "f", "app1"), + ("/g", "g", "app2"), + } + assert { + ( + metric["route"], + metric["deployment"], + metric["application"], + ) + for metric in metrics + } == expected_output + + start_metrics = get_metric_dictionaries("ray_serve_deployment_replica_starts_total") + assert len(start_metrics) == 2 + expected_output = {("f", "app1"), ("g", "app2")} + assert { + (start_metric["deployment"], start_metric["application"]) + for start_metric in start_metrics + } == expected_output + + # Latency metrics + wait_for_condition( + lambda: len( + get_metric_dictionaries("ray_serve_deployment_processing_latency_ms_count") + ) + == 2, + timeout=40, + ) + for metric_name in [ + "ray_serve_deployment_processing_latency_ms_count", + "ray_serve_deployment_processing_latency_ms_sum", + ]: + latency_metrics = get_metric_dictionaries(metric_name) + print(f"checking metric {metric_name}, {latency_metrics}") + assert len(latency_metrics) == 2 + expected_output = {("f", "app1"), ("g", "app2")} + assert { + (latency_metric["deployment"], latency_metric["application"]) + for latency_metric in latency_metrics + } == expected_output + + wait_for_condition( + lambda: len(get_metric_dictionaries("ray_serve_replica_processing_queries")) + == 2 + ) + processing_queries = get_metric_dictionaries("ray_serve_replica_processing_queries") + expected_output = {("f", "app1"), ("g", "app2")} + assert { + (processing_query["deployment"], processing_query["application"]) + for processing_query in processing_queries + } == expected_output + + @serve.deployment + def h(): + return 1 / 0 + + serve.run(h.bind(), name="app3", route_prefix="/h") + url_h = get_application_url("HTTP", "app3") + assert 500 == httpx.get(url_h).status_code + wait_for_condition( + lambda: len(get_metric_dictionaries("ray_serve_deployment_error_counter_total")) + == 1, + timeout=40, + ) + err_requests = get_metric_dictionaries("ray_serve_deployment_error_counter_total") + assert len(err_requests) == 1 + expected_output = ("/h", "h", "app3") + assert ( + err_requests[0]["route"], + err_requests[0]["deployment"], + err_requests[0]["application"], + ) == expected_output + + wait_for_condition( + lambda: len(get_metric_dictionaries("ray_serve_deployment_replica_healthy")) + == 3, + ) + health_metrics = get_metric_dictionaries("ray_serve_deployment_replica_healthy") + expected_output = { + ("f", "app1"), + ("g", "app2"), + ("h", "app3"), + } + assert { + (health_metric["deployment"], health_metric["application"]) + for health_metric in health_metrics + } == expected_output + + +@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows") +def test_multiplexed_metrics(metrics_start_shutdown): + """Tests multiplexed API corresponding metrics.""" + + @serve.deployment + class Model: + @serve.multiplexed(max_num_models_per_replica=2) + async def get_model(self, model_id: str): + return model_id + + async def __call__(self, model_id: str): + await self.get_model(model_id) + return + + handle = serve.run(Model.bind(), name="app", route_prefix="/app") + handle.remote("model1") + handle.remote("model2") + # Trigger model eviction. + handle.remote("model3") + expected_metrics = [ + "ray_serve_multiplexed_model_load_latency_ms", + "ray_serve_multiplexed_model_unload_latency_ms", + "ray_serve_num_multiplexed_models", + "ray_serve_multiplexed_models_load_counter", + "ray_serve_multiplexed_models_unload_counter", + ] + + def verify_metrics(): + try: + resp = httpx.get("http://127.0.0.1:9999").text + # Requests will fail if we are crashing the controller + except httpx.HTTPError: + return False + for metric in expected_metrics: + assert metric in resp + return True + + wait_for_condition( + verify_metrics, + timeout=40, + retry_interval_ms=1000, + ) + + +def test_long_poll_host_sends_counted(serve_instance): + """Check that the transmissions by the long_poll are counted.""" + + host = ray.remote(LongPollHost).remote( + listen_for_change_request_timeout_s=(0.01, 0.01) + ) + + # Write a value. + ray.get(host.notify_changed.remote({"key_1": 999})) + object_ref = host.listen_for_change.remote({"key_1": -1}) + + # Check that the result's size is reported. + result_1: Dict[str, UpdatedObject] = ray.get(object_ref) + wait_for_condition( + check_metric_float_eq, + timeout=15, + metric="ray_serve_long_poll_host_transmission_counter", + expected=1, + expected_tags={"namespace_or_state": "key_1"}, + ) + + # Write two new values. + ray.get(host.notify_changed.remote({"key_1": 1000})) + ray.get(host.notify_changed.remote({"key_2": 1000})) + object_ref = host.listen_for_change.remote( + {"key_1": result_1["key_1"].snapshot_id, "key_2": -1} + ) + + # Check that the new objects are transmitted. + result_2: Dict[str, UpdatedObject] = ray.get(object_ref) + wait_for_condition( + check_metric_float_eq, + timeout=15, + metric="ray_serve_long_poll_host_transmission_counter", + expected=1, + expected_tags={"namespace_or_state": "key_2"}, + ) + wait_for_condition( + check_metric_float_eq, + timeout=15, + metric="ray_serve_long_poll_host_transmission_counter", + expected=2, + expected_tags={"namespace_or_state": "key_1"}, + ) + + # Check that a timeout result is counted. + object_ref = host.listen_for_change.remote({"key_2": result_2["key_2"].snapshot_id}) + _ = ray.get(object_ref) + wait_for_condition( + check_metric_float_eq, + timeout=15, + metric="ray_serve_long_poll_host_transmission_counter", + expected=1, + expected_tags={"namespace_or_state": "TIMEOUT"}, + ) + + +def test_actor_summary(serve_instance): + @serve.deployment + def f(): + pass + + serve.run(f.bind(), name="app") + actors = list_actors(filters=[("state", "=", "ALIVE")]) + class_names = {actor["class_name"] for actor in actors} + assert class_names.issuperset( + {"ServeController", "HAProxyManager", "ServeReplica:app:f"} + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/BUILD.bazel b/python/ray/serve/tests/unit/BUILD.bazel index bfab69748aae..d1c6f00822a4 100644 --- a/python/ray/serve/tests/unit/BUILD.bazel +++ b/python/ray/serve/tests/unit/BUILD.bazel @@ -9,7 +9,9 @@ py_library( py_test_run_all_subdirectory( size = "small", include = glob(["test_*.py"]), - exclude = [], + exclude = [ + "test_controller_haproxy.py", # Requires RAY_SERVE_ENABLE_HA_PROXY=1 + ], extra_srcs = [], tags = ["team:serve"], deps = [ @@ -19,6 +21,28 @@ py_test_run_all_subdirectory( ], ) +# HAProxy controller unit tests (require RAY_SERVE_ENABLE_HA_PROXY=1). +py_test_module_list( + size = "small", + env = { + "RAY_SERVE_ENABLE_HA_PROXY": "1", + "RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S": "0.01", + }, + files = [ + "test_controller_haproxy.py", + ], + tags = [ + "haproxy", + "no_windows", + "team:serve", + ], + deps = [ + ":conftest", + "//python/ray/serve:serve_lib", + "//python/ray/serve/tests:common", + ], +) + py_test_module_list( size = "small", timeout = "short", diff --git a/python/ray/serve/tests/unit/test_controller_haproxy.py b/python/ray/serve/tests/unit/test_controller_haproxy.py new file mode 100644 index 000000000000..c5ec31a566f5 --- /dev/null +++ b/python/ray/serve/tests/unit/test_controller_haproxy.py @@ -0,0 +1,171 @@ +from unittest import mock + +import pytest + +from ray.serve._private.common import ( + DeploymentID, + ReplicaID, + RequestProtocol, + RunningReplicaInfo, +) +from ray.serve.schema import ( + Target, + TargetGroup, +) +from ray.serve.tests.unit.test_controller_direct_ingress import ( + FakeApplicationStateManager, + FakeDeploymentStateManager, + FakeDirectIngressController, + FakeKVStore, + FakeLongPollHost, + FakeProxyStateManager, +) + + +# Test Controller that overrides methods and dependencies for HAProxy testing +class FakeHAProxyController(FakeDirectIngressController): + def __init__( + self, + kv_store, + long_poll_host, + application_state_manager, + deployment_state_manager, + proxy_state_manager, + ): + super().__init__( + kv_store=kv_store, + long_poll_host=long_poll_host, + application_state_manager=application_state_manager, + deployment_state_manager=deployment_state_manager, + proxy_state_manager=proxy_state_manager, + ) + + self._ha_proxy_enabled = True + + +@pytest.fixture +def haproxy_controller(): + kv_store = FakeKVStore() + long_poll_host = FakeLongPollHost() + app_state_manager = FakeApplicationStateManager({}, {}, {}) + deployment_state_manager = FakeDeploymentStateManager({}) + proxy_state_manager = FakeProxyStateManager() + + controller = FakeHAProxyController( + kv_store=kv_store, + long_poll_host=long_poll_host, + application_state_manager=app_state_manager, + deployment_state_manager=deployment_state_manager, + proxy_state_manager=proxy_state_manager, + ) + + yield controller + + +@pytest.mark.parametrize("from_proxy_manager", [True, False]) +@pytest.mark.parametrize("ha_proxy_enabled", [True, False]) +def test_get_target_groups_haproxy( + haproxy_controller: FakeHAProxyController, + from_proxy_manager: bool, + ha_proxy_enabled: bool, +): + """Tests get_target_groups returns the appropriate target groups based on the + ha_proxy_enabled and from_proxy_manager parameters.""" + + haproxy_controller._ha_proxy_enabled = ha_proxy_enabled + + # Setup test data with running applications + app_statuses = {"app1": {}} + route_prefixes = {"app1": "/app1"} + ingress_deployments = {"app1": "app1_ingress"} + + deployment_id1 = DeploymentID(name="app1_ingress", app_name="app1") + + # Create replica info + replica_id1 = ReplicaID(unique_id="replica1", deployment_id=deployment_id1) + replica_info1 = RunningReplicaInfo( + replica_id=replica_id1, + node_id="node1", + node_ip="10.0.0.1", + availability_zone="az1", + actor_name=mock.Mock(), + max_ongoing_requests=100, + ) + + running_replica_infos = {deployment_id1: [replica_info1]} + + # Setup test application state manager + haproxy_controller.application_state_manager = FakeApplicationStateManager( + app_statuses=app_statuses, + route_prefixes=route_prefixes, + ingress_deployments=ingress_deployments, + ) + + # Setup test deployment state manager + haproxy_controller.deployment_state_manager = FakeDeploymentStateManager( + running_replica_infos=running_replica_infos, + ) + + # Setup proxy state manager + haproxy_controller.proxy_state_manager.add_proxy_details("proxy_node1", "10.0.1.1") + haproxy_controller.proxy_state_manager.add_proxy_details("proxy_node2", "10.0.1.2") + + # Allocate ports for replicas using controller's methods + http_port1 = haproxy_controller.allocate_replica_port( + "node1", replica_id1.unique_id, RequestProtocol.HTTP + ) + grpc_port1 = haproxy_controller.allocate_replica_port( + "node1", replica_id1.unique_id, RequestProtocol.GRPC + ) + + target_groups = haproxy_controller.get_target_groups( + from_proxy_manager=from_proxy_manager + ) + + # Create expected target groups + if ha_proxy_enabled and not from_proxy_manager: + expected_target_groups = [ + TargetGroup( + protocol=RequestProtocol.HTTP, + route_prefix="/", + targets=[ + Target(ip="10.0.1.1", port=8000, instance_id=""), + Target(ip="10.0.1.2", port=8000, instance_id=""), + ], + ), + TargetGroup( + protocol=RequestProtocol.GRPC, + route_prefix="/", + targets=[ + Target(ip="10.0.1.1", port=9000, instance_id=""), + Target(ip="10.0.1.2", port=9000, instance_id=""), + ], + ), + ] + else: + expected_target_groups = [ + TargetGroup( + protocol=RequestProtocol.HTTP, + route_prefix="/app1", + targets=[ + Target(ip="10.0.0.1", port=http_port1, instance_id=""), + ], + ), + TargetGroup( + protocol=RequestProtocol.GRPC, + route_prefix="/app1", + targets=[ + Target(ip="10.0.0.1", port=grpc_port1, instance_id=""), + ], + ), + ] + + # Sort both lists to ensure consistent comparison + target_groups.sort(key=lambda g: (g.protocol, g.route_prefix)) + expected_target_groups.sort(key=lambda g: (g.protocol, g.route_prefix)) + + assert target_groups == expected_target_groups + + +if __name__ == "__main__": + pytest.main()