Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions tests/worker/test_process_gpu_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Tests for process-scoped GPU memory accounting."""

import os
from unittest import mock

import pytest


class TestParseCudaVisibleDevices:
def test_empty(self):
from vllm_omni.worker.base import _parse_cuda_visible_devices

with mock.patch.dict(os.environ, {}, clear=True):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
assert _parse_cuda_visible_devices() == []

with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": ""}):
assert _parse_cuda_visible_devices() == []

def test_integer_indices(self):
from vllm_omni.worker.base import _parse_cuda_visible_devices

with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "2,3,5"}):
assert _parse_cuda_visible_devices() == [2, 3, 5]

with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}):
assert _parse_cuda_visible_devices() == [0]

def test_uuids(self):
from vllm_omni.worker.base import _parse_cuda_visible_devices

uuid1 = "GPU-12345678-1234-1234-1234-123456789abc"
uuid2 = "GPU-87654321-4321-4321-4321-cba987654321"
with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": f"{uuid1},{uuid2}"}):
assert _parse_cuda_visible_devices() == [uuid1, uuid2]

def test_mig_ids(self):
from vllm_omni.worker.base import _parse_cuda_visible_devices

mig1 = "MIG-GPU-12345678-1234-1234-1234-123456789abc/0/0"
mig2 = "MIG-GPU-12345678-1234-1234-1234-123456789abc/1/0"
with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": f"{mig1},{mig2}"}):
assert _parse_cuda_visible_devices() == [mig1, mig2]

def test_spaces(self):
from vllm_omni.worker.base import _parse_cuda_visible_devices

with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": " 2 , 3 , 5 "}):
assert _parse_cuda_visible_devices() == [2, 3, 5]


class TestGetProcessGpuMemory:
@pytest.mark.skipif(not os.path.exists("/dev/nvidia0"), reason="No GPU")
def test_returns_memory_for_current_process(self):
import torch

from vllm_omni.worker.base import _get_process_gpu_memory

if not torch.cuda.is_available():
pytest.skip("CUDA not available")

device = torch.device("cuda:0")
tensor = torch.zeros(1000, 1000, device=device)

memory = _get_process_gpu_memory(0)
assert memory >= 0

del tensor
torch.cuda.empty_cache()

def test_raises_on_invalid_device(self):
from vllm_omni.worker.base import _get_process_gpu_memory

with (
mock.patch("vllm_omni.worker.base.nvmlInit"),
mock.patch("vllm_omni.worker.base.nvmlShutdown"),
mock.patch("vllm.third_party.pynvml.nvmlDeviceGetCount", return_value=1),
):
with pytest.raises(RuntimeError, match="Invalid GPU device"):
_get_process_gpu_memory(5)

def test_returns_zero_when_process_not_found(self):
from vllm_omni.worker.base import _get_process_gpu_memory

with (
mock.patch("vllm_omni.worker.base.nvmlInit"),
mock.patch("vllm_omni.worker.base.nvmlShutdown"),
mock.patch("vllm.third_party.pynvml.nvmlDeviceGetCount", return_value=8),
mock.patch("vllm_omni.worker.base.nvmlDeviceGetHandleByIndex"),
mock.patch("vllm_omni.worker.base.nvmlDeviceGetComputeRunningProcesses", return_value=[]),
):
memory = _get_process_gpu_memory(0)
assert memory == 0

def test_uses_uuid_when_provided(self):
from vllm_omni.worker.base import _get_process_gpu_memory

uuid = "GPU-12345678-1234-1234-1234-123456789abc"
mock_handle = mock.MagicMock()

with (
mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": uuid}),
mock.patch("vllm_omni.worker.base.nvmlInit"),
mock.patch("vllm_omni.worker.base.nvmlShutdown"),
mock.patch("vllm.third_party.pynvml.nvmlDeviceGetHandleByUUID", return_value=mock_handle) as mock_by_uuid,
mock.patch("vllm_omni.worker.base.nvmlDeviceGetComputeRunningProcesses", return_value=[]),
):
memory = _get_process_gpu_memory(0)
assert memory == 0
mock_by_uuid.assert_called_once_with(uuid)

def test_raises_on_invalid_uuid(self):
from vllm_omni.worker.base import _get_process_gpu_memory

uuid = "GPU-invalid-uuid"

with (
mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": uuid}),
mock.patch("vllm_omni.worker.base.nvmlInit"),
mock.patch("vllm_omni.worker.base.nvmlShutdown"),
mock.patch("vllm.third_party.pynvml.nvmlDeviceGetHandleByUUID", side_effect=Exception("Invalid UUID")),
):
with pytest.raises(RuntimeError, match="Failed to get NVML handle"):
_get_process_gpu_memory(0)

def test_returns_none_on_nvml_init_failure(self):
from vllm_omni.worker.base import _get_process_gpu_memory

with (
mock.patch("vllm_omni.worker.base.nvmlInit", side_effect=Exception("NVML unavailable")),
mock.patch("vllm_omni.worker.base.nvmlShutdown"),
):
result = _get_process_gpu_memory(0)
assert result is None
194 changes: 194 additions & 0 deletions vllm_omni/worker/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Base worker class for vLLM-Omni with process-scoped GPU memory accounting."""

from __future__ import annotations

import os

import torch
from vllm.logger import init_logger
from vllm.third_party.pynvml import (
nvmlDeviceGetComputeRunningProcesses,
nvmlDeviceGetHandleByIndex,
nvmlInit,
nvmlShutdown,
)
from vllm.utils.mem_utils import format_gib, memory_profiling
from vllm.v1.worker.gpu_worker import Worker as GPUWorker

logger = init_logger(__name__)


def _parse_cuda_visible_devices() -> list[str | int]:
"""Parse CUDA_VISIBLE_DEVICES into a list of device identifiers.

Returns list of integers (physical indices) or strings (UUIDs/MIG IDs).
"""
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not visible_devices:
return []

result: list[str | int] = []
for item in visible_devices.split(","):
item = item.strip()
if not item:
continue
try:
result.append(int(item))
except ValueError:
# UUID (GPU-xxx) or MIG ID (MIG-xxx)
result.append(item)
return result


def _get_device_handle(device_id: str | int):
"""Get NVML device handle by index or UUID."""
if isinstance(device_id, int):
return nvmlDeviceGetHandleByIndex(device_id)
else:
from vllm.third_party.pynvml import nvmlDeviceGetHandleByUUID

return nvmlDeviceGetHandleByUUID(device_id)


def _get_process_gpu_memory(local_rank: int) -> int | None:
"""Get GPU memory used by current process via pynvml.

Supports CUDA_VISIBLE_DEVICES with integer indices, UUIDs, or MIG IDs.

Returns:
Memory in bytes used by this process, or None if NVML unavailable.

Raises:
RuntimeError: If device validation fails (invalid index or UUID).
"""
from vllm.third_party.pynvml import nvmlDeviceGetCount

my_pid = os.getpid()
visible_devices = _parse_cuda_visible_devices()

try:
nvmlInit()
except Exception as e:
logger.warning("NVML init failed, will use profiling fallback: %s", e)
return None

try:
if visible_devices and local_rank < len(visible_devices):
device_id = visible_devices[local_rank]
try:
handle = _get_device_handle(device_id)
except Exception as e:
raise RuntimeError(
f"Failed to get NVML handle for device '{device_id}' (local_rank={local_rank}). "
f"Check CUDA_VISIBLE_DEVICES or stage config 'devices' setting."
) from e
else:
# No CUDA_VISIBLE_DEVICES or local_rank out of range: use index directly
device_count = nvmlDeviceGetCount()
if local_rank >= device_count:
raise RuntimeError(
f"Invalid GPU device {local_rank}. Only {device_count} GPU(s) available. "
f"Check CUDA_VISIBLE_DEVICES or stage config 'devices' setting."
)
handle = nvmlDeviceGetHandleByIndex(local_rank)

for proc in nvmlDeviceGetComputeRunningProcesses(handle):
if proc.pid == my_pid:
return proc.usedGpuMemory
return 0
except RuntimeError:
raise
except Exception as e:
logger.warning("NVML query failed, will use profiling fallback: %s", e)
return None
finally:
try:
nvmlShutdown()
except Exception:
pass


class OmniGPUWorkerBase(GPUWorker):
"""Base GPU worker for vLLM-Omni with process-scoped memory accounting.

This class overrides determine_available_memory() to use per-process GPU
memory tracking via pynvml, allowing multiple stages to initialize
concurrently on the same GPU without memory accounting interference.
"""

@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Process-scoped GPU memory profiling for concurrent stage initialization.

Algorithm:
1. requested_memory = total_gpu_memory * gpu_memory_utilization
(computed in init_device from cache_config)

2. process_memory = memory used by THIS process only (via pynvml)
- Uses nvmlDeviceGetComputeRunningProcesses to get per-PID memory
- Supports CUDA_VISIBLE_DEVICES with indices, UUIDs, or MIG IDs

3. available_kv_cache = requested_memory - process_memory

Fallback:
If NVML is unavailable, falls back to profiling data:
available = requested - (weights + activations + non_torch)
"""
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
self.model_runner.profile_run()
logger.info(
"Using explicit kv_cache_memory_bytes: %s GiB",
format_gib(kv_cache_memory_bytes),
)
return kv_cache_memory_bytes

with memory_profiling(
self.init_snapshot,
weights_memory=int(self.model_runner.model_memory_usage),
) as profile_result:
self.model_runner.profile_run()

self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase

process_memory = _get_process_gpu_memory(self.local_rank)

if process_memory is not None:
# NVML available: use per-process memory
self.available_kv_cache_memory_bytes = max(0, self.requested_memory - process_memory)
logger.debug(
"Process-scoped memory (PID %d, GPU %d): requested=%s, used=%s, available=%s",
os.getpid(),
self.local_rank,
format_gib(self.requested_memory),
format_gib(process_memory),
format_gib(self.available_kv_cache_memory_bytes),
)
logger.info_once(
"Available KV cache memory: %s GiB (process-scoped)",
format_gib(self.available_kv_cache_memory_bytes),
scope="local",
)
else:
# NVML unavailable: use profiling data as conservative fallback
profiled_usage = (
int(self.model_runner.model_memory_usage)
+ profile_result.torch_peak_increase
+ profile_result.non_torch_increase
)
self.available_kv_cache_memory_bytes = max(0, self.requested_memory - profiled_usage)
logger.debug(
"Profiling fallback (PID %d, GPU %d): requested=%s, profiled=%s, available=%s",
os.getpid(),
self.local_rank,
format_gib(self.requested_memory),
format_gib(profiled_usage),
format_gib(self.available_kv_cache_memory_bytes),
)
logger.info_once(
"Available KV cache memory: %s GiB (profiling fallback)",
format_gib(self.available_kv_cache_memory_bytes),
scope="local",
)

return int(self.available_kv_cache_memory_bytes)
4 changes: 2 additions & 2 deletions vllm_omni/worker/gpu_ar_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
from vllm.v1.worker.utils import request_memory
from vllm.v1.worker.workspace import init_workspace_manager

from vllm_omni.worker.base import OmniGPUWorkerBase
from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
from vllm_omni.worker.mixins import OmniWorkerMixin

logger = init_logger(__name__)


class GPUARWorker(OmniWorkerMixin, GPUWorker):
class GPUARWorker(OmniWorkerMixin, OmniGPUWorkerBase):
"""GPU worker for autoregressive omni model stages.

Extends the base GPUWorker to initialize and manage autoregressive
Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/worker/gpu_generation_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
from vllm.v1.worker.utils import request_memory
from vllm.v1.worker.workspace import init_workspace_manager

from vllm_omni.worker.base import OmniGPUWorkerBase
from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner
from vllm_omni.worker.mixins import OmniWorkerMixin

logger = init_logger(__name__)


class GPUGenerationWorker(OmniWorkerMixin, GPUWorker):
class GPUGenerationWorker(OmniWorkerMixin, OmniGPUWorkerBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@princepride please check whether the diffusion worker needs to inherit from this OmniGPUWorkerBase

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just created a wrapper over GPUWorker, so that I can patch the corressponding functions. If there is a cleaner approach can shift to that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, DiffusionWorker cannot inherit from it due to interface differences. However, I wonder if we could extract the NVML util functions into an independent module, so that DiffusionWorker could also use it to assess the GPU memory usage of the current process.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extract an abstraction would work? ALL workers inherits this abstraction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it's a good idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whats inside the workerbase?

Copy link
Contributor Author

@divyanshsinghvi divyanshsinghvi Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the vllm one or omni wrapper (inheriting from it which is only patching the function and keeping other functionality same) ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Omni wrapper, diffusion worker can't inherit from vLLM GPU worker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I saw the code and diffusion_worker.py doesnt inherit from GPUWorker.

I abstracted the functions but didnt introduce them to the diffusion_worker as I think currently diffusion_worker require it cuMemAllocator from vllm to get memory usage.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functionality for the vLLM engine looks good to me.
I have double-checked the code for the AR part's available memory determination and the dummy run in the diffusion engine. Since the diffusion engine's dummy run does not record memory usage, I think the current design should work even for mixed-structured models like Bagel.

"""GPU Worker for Generation model (non-autoregressive waveform generation).

Usage in stage config:
Expand Down