
Process-Scoped GPU Memory Accounting #1204

Open

divyanshsinghvi wants to merge 8 commits into vllm-project:main from divyanshsinghvi:feature/gpu-memory-accounting

Conversation

divyanshsinghvi (Contributor) commented Feb 4, 2026


Purpose

#1147 ([RFC]: Process-Scoped GPU Memory Accounting for Concurrent Omni Stage Initialization in vLLM)

Test Plan

VLLM_LOGGING_LEVEL=DEBUG .venv/bin/python -c "
from vllm_omni.entrypoints.omni import Omni
from vllm.sampling_params import SamplingParams

omni = Omni(model='Qwen/Qwen2.5-Omni-3B')
prompts = [{'prompt': '<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n'}]
params = [SamplingParams(max_tokens=50)]
for output in omni.generate(prompts, params):
    print(output)
omni.close()
"

Test Result

(EngineCore_DP0 pid=51748) [Stage-1] DEBUG 02-05 02:59:46 [base.py:97] Process-scoped memory (PID 51748, GPU 0): requested=6.12, used=4.22, available=1.9
(EngineCore_DP0 pid=51748) [Stage-1] INFO 02-05 02:59:46 [base.py:105] Available KV cache memory: 1.9 GiB (process-scoped)

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.


divyanshsinghvi marked this pull request as ready for review February 4, 2026 21:35
divyanshsinghvi (Contributor, Author) commented Feb 4, 2026

cc: @tzhouam

Should I also remove the sequential-initialization logic in omni_stage.py in this PR, or send a follow-up PR once this is merged?

chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8e08010228


Comment on lines 23 to 27
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if visible_devices:
    try:
        physical_indices = [int(x.strip()) for x in visible_devices.split(",") if x.strip()]
        if local_rank < len(physical_indices):


P2: Handle UUID/MIG CUDA_VISIBLE_DEVICES mappings

When CUDA_VISIBLE_DEVICES is set to UUIDs or MIG IDs (common in Kubernetes and multi-tenant setups), parsing with int(...) fails and this helper falls back to local_rank. That means nvmlDeviceGetHandleByIndex is called with a physical index that does not correspond to the visible device list, so the process memory query can target the wrong GPU. In those environments, the available KV cache calculation can be significantly wrong and lead to over-allocation. Consider supporting UUID/MIG strings (e.g., use nvmlDeviceGetHandleByUUID when parsing fails) to keep the NVML lookup aligned with CUDA device visibility.
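A minimal sketch of the kind of lookup the review suggests, assuming a pynvml-based helper (the function name and structure are illustrative, not the PR's code; UUID/MIG lookups assume a reasonably recent driver and pynvml):

```python
# Illustrative sketch only: map a local rank to an NVML handle when
# CUDA_VISIBLE_DEVICES may hold integer ordinals, GPU UUIDs, or MIG UUIDs.
import os

import pynvml


def _nvml_handle_for_local_rank(local_rank: int):
    pynvml.nvmlInit()
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    entries = [x.strip() for x in visible.split(",") if x.strip()]
    if not entries or local_rank >= len(entries):
        # No remapping in effect: the local rank maps directly to a physical index.
        return pynvml.nvmlDeviceGetHandleByIndex(local_rank)
    entry = entries[local_rank]
    try:
        # Plain ordinal, e.g. CUDA_VISIBLE_DEVICES=2,3
        return pynvml.nvmlDeviceGetHandleByIndex(int(entry))
    except ValueError:
        # "GPU-<uuid>" or "MIG-<uuid>" strings, common in Kubernetes setups;
        # keeps the NVML lookup aligned with CUDA device visibility.
        return pynvml.nvmlDeviceGetHandleByUUID(entry)
```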


Comment on lines 49 to 53
local_rank,
physical_device,
e,
)
return 0


P2: Avoid treating NVML failure as zero process memory

If nvmlInit or the NVML query fails, _get_process_gpu_memory returns 0, and the caller then treats the full pre-load requested_memory as available. In environments where NVML is unavailable or permission-restricted but CUDA is functional (e.g., some containers), this overestimates KV cache capacity because model weights/activations already occupy GPU memory after profile_run, which can lead to OOM during initialization. A safer fallback would be to use a global free-memory query or the profiling result when NVML errors occur, rather than assuming zero process usage.
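One possible shape for such a fallback, as a sketch (the function name is illustrative, not the PR's code): when the process-scoped NVML query is unavailable, use the device-wide free-memory figure rather than assuming zero process usage.

```python
# Illustrative sketch only: prefer process-scoped accounting, but fall back to a
# device-wide free-memory query when NVML is unavailable or permission-restricted.
from typing import Optional

import torch


def available_kv_cache_bytes(requested_bytes: int,
                             device_index: int,
                             process_used_bytes: Optional[int]) -> int:
    if process_used_bytes is None:
        # NVML failed: the global free-memory view already reflects weights and
        # activations resident after profile_run, so it cannot over-report.
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device_index)
        return min(requested_bytes, free_bytes)
    # Process-scoped path: subtract only what this process is using.
    return max(requested_bytes - process_used_bytes, 0)
```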


tzhouam self-requested a review February 5, 2026 02:11
tzhouam (Collaborator) commented Feb 5, 2026

> cc: @tzhouam
>
> Should I also remove the sequential-initialization logic in omni_stage.py in this PR, or send a follow-up PR once this is merged?

Nice work!

I went through the code, and here’s my understanding:

  1. If we can successfully obtain process-level memory usage, we can safely disable sequential initialization, since the stages don't affect each other (as described in the RFC in #1147: Process-Scoped GPU Memory Accounting for Concurrent Omni Stage Initialization in vLLM).
  2. If process-level memory usage isn’t available (e.g., NVML isn’t accessible), we should fall back to sequential initialization.

Would it make sense to wrap the sequential init into a helper function and add a (mock) check for process-level memory availability?
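A rough sketch of what that helper could look like (the stage API names below, such as start() and wait_until_ready(), are invented for illustration; the availability check could be mocked in tests):

```python
# Illustrative sketch only: gate concurrent stage initialization on whether
# process-scoped memory accounting is actually usable on this host.
from typing import Optional


def process_memory_accounting_available() -> bool:
    try:
        import pynvml
        pynvml.nvmlInit()
        pynvml.nvmlShutdown()
        return True
    except Exception:
        return False


def initialize_stages(stages, concurrent: Optional[bool] = None) -> None:
    if concurrent is None:
        concurrent = process_memory_accounting_available()
    if concurrent:
        # Each stage accounts for its own process memory (see #1147), so the
        # stages can initialize at the same time without interfering.
        for stage in stages:
            stage.start()
        for stage in stages:
            stage.wait_until_ready()
    else:
        # NVML not accessible: keep the existing sequential path.
        for stage in stages:
            stage.start()
            stage.wait_until_ready()
```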



- class GPUGenerationWorker(OmniWorkerMixin, GPUWorker):
+ class GPUGenerationWorker(OmniWorkerMixin, OmniGPUWorkerBase):
Collaborator:
@princepride please check whether the diffusion worker needs to inherit from this OmniGPUWorkerBase

Contributor (Author):

I just created a wrapper over GPUWorker so that I can patch the corresponding functions. If there is a cleaner approach, I can shift to that.

Contributor:

Unfortunately, DiffusionWorker cannot inherit from it due to interface differences. However, I wonder if we could extract the NVML util functions into an independent module, so that DiffusionWorker could also use it to assess the GPU memory usage of the current process.
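A sketch of what such a standalone helper module could look like (module and function names are illustrative, not the PR's code); it depends only on pynvml, so both AR and diffusion workers could import it:

```python
# Illustrative sketch only: a self-contained NVML helper, independent of any
# worker class, returning the GPU memory used by one process on one device.
import os
from typing import Optional

import pynvml


def get_process_gpu_memory_bytes(device_index: int,
                                 pid: Optional[int] = None) -> int:
    pid = os.getpid() if pid is None else pid
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
            # usedGpuMemory can be None when the caller lacks permission.
            if proc.pid == pid and proc.usedGpuMemory is not None:
                return proc.usedGpuMemory
        return 0
    finally:
        pynvml.nvmlShutdown()
```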

Collaborator:
Would extracting an abstraction work? All workers would inherit from this abstraction.

Contributor:
Yes, I think it's a good idea.

Collaborator:
What's inside the worker base?

divyanshsinghvi (Contributor, Author), Feb 6, 2026:
Do you mean the vLLM one, or the Omni wrapper (which inherits from it, only patching the function and keeping the other functionality the same)?

Contributor:
The Omni wrapper; the diffusion worker can't inherit from the vLLM GPU worker.

Contributor (Author):
Yes, I saw the code; diffusion_worker.py doesn't inherit from GPUWorker.

I abstracted the functions but didn't introduce them to the diffusion worker, since I think diffusion_worker currently relies on the cuMemAllocator from vLLM to get memory usage.
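If the shared abstraction route is taken, one minimal shape could be something like the following (class and method names are purely illustrative); a worker that cannot use NVML, such as the diffusion worker relying on the cuMemAllocator, would supply its own implementation of the hook:

```python
# Illustrative sketch only: a shared mixin so AR and diffusion workers expose
# process-scoped memory usage through the same interface.
from abc import ABC, abstractmethod


class ProcessMemoryAccountingMixin(ABC):
    @abstractmethod
    def get_process_gpu_memory_bytes(self) -> int:
        """GPU memory currently used by this worker's own process, in bytes."""

    def available_kv_cache_bytes(self, requested_bytes: int) -> int:
        # Only this process's usage is subtracted, so concurrently
        # initializing stages do not see each other's allocations.
        return max(requested_bytes - self.get_process_gpu_memory_bytes(), 0)
```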

Collaborator:

The functionality for the vLLM engine looks good to me.
I have double-checked the code for the AR part's available memory determination and the dummy run in the diffusion engine. Since the diffusion engine's dummy run does not record memory usage, I think the current design should work even for mixed-structured models like Bagel.

princepride (Contributor) commented:

> cc: @tzhouam
> Should I also remove the sequential-initialization logic in omni_stage.py in this PR, or send a follow-up PR once this is merged?
>
> Nice work!
>
> I went through the code, and here's my understanding:
>
>   1. If we can successfully obtain process-level memory usage, we can safely disable sequential initialization, since the stages don't affect each other (as described in the RFC in #1147: Process-Scoped GPU Memory Accounting for Concurrent Omni Stage Initialization in vLLM).
>   2. If process-level memory usage isn't available (e.g., NVML isn't accessible), we should fall back to sequential initialization.
>
> Would it make sense to wrap the sequential init into a helper function and add a (mock) check for process-level memory availability?

Wow, this feature is very important to us, isn't it? The current stage loading is too slow, and we need to pass in additional parameters to extend the stage initialization timeout.

tzhouam (Collaborator) commented Feb 7, 2026

LGTM, marked as ready

tzhouam added the ready label (label to trigger buildkite CI) on Feb 7, 2026
