9 changes: 9 additions & 0 deletions vllm/model_executor/model_loader/default_loader.py
@@ -304,3 +304,12 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)


# Global progress state — written by weight loader, read by health monitor
_current_load_progress: tuple[int, int] = (0, 0)


def get_current_load_progress() -> tuple[int, int]:
"""Get current weight loading progress (files_loaded, total_files)."""
return _current_load_progress
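
This hunk adds only the reader side of the progress state; the code that updates it is not shown in this diff. A minimal sketch of how the two ends could interact, assuming a hypothetical loader loop and watchdog check (neither function below is part of this change):

    from vllm.model_executor.model_loader import default_loader

    def load_files(files: list[str]) -> None:
        # Hypothetical loader loop: publish (files_loaded, total_files)
        # after each checkpoint shard so a watchdog can see forward progress.
        total = len(files)
        for done, path in enumerate(files, start=1):
            ...  # load one shard from `path`
            default_loader._current_load_progress = (done, total)

    def made_progress_since(prev: tuple[int, int]) -> bool:
        # Hypothetical watchdog check: compare against the last observation
        # rather than relying on a fixed wall-clock timeout alone.
        return default_loader.get_current_load_progress() != prev
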
179 changes: 179 additions & 0 deletions vllm/v1/executor/orphan_cleanup.py
@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Detect and clean up orphaned vLLM/Ray GPU worker processes.

Orphaned processes occur when NeMo RL triggers an unclean shutdown
(timeout, OOM, signal) and vLLM workers under the Ray distributed
executor aren't properly terminated. They hold GPU memory and NCCL
communicators, blocking subsequent runs.
"""

import os
import signal
import subprocess

from vllm.logger import init_logger

logger = init_logger(__name__)


def cleanup_orphaned_gpu_workers() -> int:
"""Find and kill orphaned vLLM Ray worker processes holding GPU memory.

Strategy:
1. Use nvidia-smi to find processes using GPU memory
2. Filter to Python processes that look like vLLM/Ray workers
3. Check if they belong to the current Ray session (if Ray is running)
4. Kill orphans that don't belong to any active Ray session

Returns:
Number of orphaned processes killed.
"""
if os.environ.get("VLLM_CLEANUP_ORPHANS_ON_STARTUP", "1") == "0":
logger.info("Orphan cleanup disabled (VLLM_CLEANUP_ORPHANS_ON_STARTUP=0)")
return 0

try:
gpu_pids = _get_gpu_process_pids()
except Exception as e:
logger.warning("Failed to query GPU processes: %s", e)
return 0

if not gpu_pids:
logger.debug("No GPU processes found — nothing to clean up.")
return 0

current_pid = os.getpid()
current_ray_session = _get_current_ray_session_id()

killed = 0
for pid, gpu_mem_mb, cmdline in gpu_pids:
if pid == current_pid:
continue

# Only target vLLM/Ray worker processes
if not _is_vllm_ray_worker(cmdline):
continue

# If Ray is running, check if process belongs to current session
if current_ray_session and _belongs_to_ray_session(pid, current_ray_session):
continue

logger.warning(
"Killing orphaned vLLM/Ray worker process: pid=%d, gpu_mem=%dMB, cmd=%s",
pid, gpu_mem_mb, cmdline[:100],
)
try:
os.kill(pid, signal.SIGTERM)
killed += 1
except ProcessLookupError:
pass # Already dead
except PermissionError:
logger.warning("Permission denied killing pid %d", pid)

if killed:
logger.info("Cleaned up %d orphaned GPU worker process(es).", killed)
return killed


def _get_gpu_process_pids() -> list[tuple[int, int, str]]:
"""Get (pid, gpu_mem_mb, cmdline) for all GPU processes."""
try:
result = subprocess.run(
["nvidia-smi", "--query-compute-apps=pid,used_memory",
"--format=csv,noheader,nounits"],
capture_output=True, text=True, timeout=10,
)
if result.returncode != 0:
return []
except FileNotFoundError:
return []

processes = []
for line in result.stdout.strip().split("\n"):
if not line.strip():
continue
parts = line.strip().split(",")
if len(parts) != 2:
continue
try:
pid = int(parts[0].strip())
mem = int(parts[1].strip())
except ValueError:
continue

# Get command line for the process
cmdline = _get_process_cmdline(pid)
processes.append((pid, mem, cmdline))

return processes


def _get_process_cmdline(pid: int) -> str:
"""Get the command line for a process."""
try:
with open(f"/proc/{pid}/cmdline", "r") as f:
return f.read().replace("\0", " ").strip()
except (FileNotFoundError, PermissionError):
return ""


def _is_vllm_ray_worker(cmdline: str) -> bool:
"""Check if a process looks like a vLLM Ray worker."""
indicators = [
"ray::RayWorkerWrapper",
"vllm.v1.worker",
"vllm.worker",
"ray::IDLE", # Ray idle workers can hold GPU memory
]
return any(ind in cmdline for ind in indicators)


def _get_current_ray_session_id() -> str | None:
"""Get the current Ray session ID, if Ray is initialized."""
try:
import ray
if ray.is_initialized():
return ray.get_runtime_context().get_job_id()
except Exception:
pass
return None


def _belongs_to_ray_session(pid: int, session_id: str) -> bool:
"""Check if a process belongs to the given Ray session.

This is best-effort — checks environment variables of the process.
"""
try:
with open(f"/proc/{pid}/environ", "r") as f:
environ = f.read()
return session_id in environ
except (FileNotFoundError, PermissionError):
return False # Can't tell — assume orphan


def register_shutdown_handlers(executor):
"""Register signal handlers and atexit for clean shutdown.

    Ensures all Ray workers are killed and GPU memory released even on
    SIGTERM (handled explicitly below) or SIGINT (atexit still runs
    after an unhandled KeyboardInterrupt).
"""
import atexit

def _cleanup_on_signal(signum, frame):
logger.info("Received signal %d, shutting down executor...", signum)
try:
executor.shutdown()
except Exception as e:
logger.warning("Error during signal-triggered shutdown: %s", e)
# Re-raise the signal for the default handler
signal.signal(signum, signal.SIG_DFL)
os.kill(os.getpid(), signum)

# Register for SIGTERM (container kill, Slurm timeout)
signal.signal(signal.SIGTERM, _cleanup_on_signal)

# atexit for normal Python exit
atexit.register(executor.shutdown)
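
Illustrative usage only (the real wiring lives in ray_executor.py below); `executor` here stands for any object with a `shutdown()` method:

    from vllm.v1.executor.orphan_cleanup import (
        cleanup_orphaned_gpu_workers,
        register_shutdown_handlers,
    )

    killed = cleanup_orphaned_gpu_workers()  # reclaim GPUs before spawning workers
    if killed:
        print(f"cleaned up {killed} leftover worker process(es)")
    register_shutdown_handlers(executor)     # SIGTERM + atexit teardown
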
101 changes: 95 additions & 6 deletions vllm/v1/executor/ray_executor.py
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import atexit
import os
import signal
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
@@ -12,6 +14,7 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.v1.executor.orphan_cleanup import cleanup_orphaned_gpu_workers
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils.network_utils import (
@@ -44,6 +47,9 @@
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)

# Configurable timeouts (seconds). Set to 0 or negative to disable.
VLLM_RPC_TIMEOUT = int(os.environ.get("VLLM_RPC_TIMEOUT", "300"))


@dataclass
class RayWorkerMetaData:
@@ -92,9 +98,26 @@ def _init_executor(self) -> None:
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

# Clean up orphaned workers from previous runs
cleanup_orphaned_gpu_workers()

# Create the parallel GPU workers.
self._init_workers_ray(placement_group)

# Register shutdown handlers for clean teardown
atexit.register(self.shutdown)

def _signal_handler(signum, frame):
logger.info("Received signal %d, shutting down executor...", signum)
try:
self.shutdown()
except Exception:
pass
signal.signal(signum, signal.SIG_DFL)
os.kill(os.getpid(), signum)

signal.signal(signal.SIGTERM, _signal_handler)

# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None

@@ -384,8 +407,28 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):

is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
# init_device with timeout
try:
self.collective_rpc(
"init_device",
timeout=VLLM_RPC_TIMEOUT if VLLM_RPC_TIMEOUT > 0 else None,
)
except Exception as e:
logger.error("init_device failed or timed out: %s", e)
self.shutdown()
raise

# load_model with timeout
load_timeout = int(os.environ.get("VLLM_MODEL_LOAD_TIMEOUT", "600"))
try:
self.collective_rpc(
"load_model",
timeout=load_timeout if load_timeout > 0 else None,
)
except Exception as e:
logger.error("load_model failed or timed out: %s", e)
self.shutdown()
raise

for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
@@ -507,7 +550,29 @@ def collective_rpc( # type: ignore[override]
if non_block:
return FutureWrapper(ray_worker_outputs)

return ray.get(ray_worker_outputs, timeout=timeout)
        # Fall back to the default RPC timeout when the caller did not
        # specify one. Note: an explicit timeout=None is indistinguishable
        # from "unspecified" here, so it does not opt out of the default.
        effective_timeout = timeout
        if effective_timeout is None and VLLM_RPC_TIMEOUT > 0:
            effective_timeout = VLLM_RPC_TIMEOUT

try:
return ray.get(ray_worker_outputs, timeout=effective_timeout)
except ray.exceptions.GetTimeoutError:
            # Identify which workers are still running; request all refs,
            # since ray.wait's default num_returns=1 reports at most one.
            ready, _ = ray.wait(ray_worker_outputs,
                                num_returns=len(ray_worker_outputs), timeout=0)
            hung_ranks = [
                i for i, ref in enumerate(ray_worker_outputs) if ref not in ready
            ]
method_name = sent_method if isinstance(sent_method, str) else "<serialized>"
raise TimeoutError(
f"collective_rpc('{method_name}') timed out after "
f"{effective_timeout}s. Hung worker ranks: {hung_ranks}. "
f"This may indicate workers stalled due to Ray GCS overload, "
f"NCCL hang, or resource contention. "
f"Set VLLM_RPC_TIMEOUT to adjust (current: {VLLM_RPC_TIMEOUT}s)."
) from None

def _check_ray_cgraph_installation(self):
import importlib.metadata
@@ -638,6 +703,30 @@ def __del__(self):
self.shutdown()

def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
"""Check health of all Ray workers. Raises on failure."""
timeout = int(os.environ.get("VLLM_HEALTH_CHECK_TIMEOUT", "60"))

# First: check if Ray actors are still alive (catches killed workers)
for i, worker in enumerate(self.workers):
try:
ray.get(worker.get_node_ip.remote(), timeout=5)
except ray.exceptions.RayActorError as e:
raise RuntimeError(
f"Health check failed: Ray worker {i} is dead. {e}"
) from e
except ray.exceptions.GetTimeoutError:
raise RuntimeError(
f"Health check failed: Ray worker {i} unresponsive (5s timeout)"
)
except Exception as e:
raise RuntimeError(
f"Health check failed: Ray worker {i} error: {e}"
) from e

# Second: run the actual health check on each worker
try:
self.collective_rpc(
"check_health", timeout=timeout if timeout > 0 else None
)
except Exception as e:
raise RuntimeError(f"Health check failed: {e}") from e
15 changes: 15 additions & 0 deletions vllm/v1/executor/ray_utils.py
@@ -50,6 +50,21 @@ def __init__(self, *args, **kwargs) -> None:
# that thread.
self.compiled_dag_cuda_device_set = False

def __ray_check_health__(self):
"""Called periodically by Ray to verify actor health.

If this method raises, Ray marks the actor as dead and any
pending/future RPCs raise RayActorError.
"""
if self.worker is not None:
try:
self.worker.check_health()
except Exception as e:
raise RuntimeError(
f"Ray worker health check failed: {e}"
) from e
# If worker is None, we're still initializing — that's ok

def get_node_ip(self) -> str:
return get_ip()
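
Per the docstring's contract, a failing __ray_check_health__ surfaces as RayActorError on pending RPCs. A driver-side caller could translate that into the executor's RuntimeError convention roughly as follows (a sketch; `refs` stands in for in-flight ObjectRefs from collective_rpc):

    import ray

    def get_or_explain(refs, timeout_s: float = 60.0):
        # Map an actor death onto the same RuntimeError shape that
        # check_health() raises, so callers handle one exception type.
        try:
            return ray.get(refs, timeout=timeout_s)
        except ray.exceptions.RayActorError as exc:
            raise RuntimeError(f"Ray worker actor died: {exc}") from exc
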
