feat(vllm): switch DP+EP launch to hybrid_lb (per-node process) #90
Draft
esmeetu wants to merge 1 commit into NVIDIA:main from
Conversation
DP+EP mode previously launched one srun task per GPU with a restricted
CUDA_VISIBLE_DEVICES, which made vLLM auto-select external_lb (one pod
per rank). Locally co-located ranks then each saw only `cuda:0`, and
CUDA Symmetric Memory rendezvous failed with:
    CUDASymmetricMemoryAllocator::rendezvous: detected allocations from
    overlapping devices from different ranks.
This blocks DSv4 MegaMoE (deep_gemm.get_symm_buffer_for_mega_moe), the
SymmMemCommunicator all-reduce path, and any future shared-namespace
fast paths on GB200/GB300 nodes that pack multiple DP ranks per node.
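The failure mode can be sketched in plain Python (the helper below is hypothetical; it only mimics how CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES to local indices 0..N-1):

```python
def visible_local_devices(cuda_visible_devices: str) -> list[int]:
    """Mimic CUDA's renumbering: visible devices become local 0..N-1."""
    physical = [d for d in cuda_visible_devices.split(",") if d]
    return list(range(len(physical)))

# Old per-GPU launch: four co-located DP ranks, each restricted to one GPU.
old = [visible_local_devices(str(gpu)) for gpu in range(4)]
print(old)  # [[0], [0], [0], [0]] — every rank sees only "cuda:0"

# New per-node launch: one process with the full local GPU set.
print(visible_local_devices("0,1,2,3"))  # [0, 1, 2, 3]
```

With every rank reporting local device 0, the symmetric-memory rendezvous sees "overlapping devices from different ranks"; with one per-node process, each rank binds a distinct local device.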
Switch DP+EP to hybrid_lb:
- endpoints_to_processes(): one Process per node (full local GPU set)
instead of one Process per GPU. Reserves local_dp_size kv-events
ports per process so dynamo's per-rank ZMQ publishers do not collide.
- build_worker_command(): drop --data-parallel-rank (which silently
flips vLLM into external_lb and forces size_local=1, see
vllm/engine/arg_utils.py:1702-1717). Pass instead:
--data-parallel-hybrid-lb
--data-parallel-size-local <local_dp>
--data-parallel-start-rank <node_rank * local_dp>
--data-parallel-address <leader_ip>
--data-parallel-rpc-port <port>
--data-parallel-hybrid-lb is passed explicitly because vLLM's
auto-detect at arg_utils.py:1721 (`if self.data_parallel_start_rank
and not headless`) uses Python truthiness, so the leader node
(start_rank=0) silently falls out of hybrid_lb and rejects worker
engines with "Remote engine N must use --headless unless in
external or hybrid dp lb mode".
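Both points can be sketched as follows (the function names are assumptions for illustration, not the real build_worker_command or arg_utils code):

```python
def hybrid_lb_flags(node_rank: int, local_dp: int,
                    leader_ip: str, rpc_port: int) -> list[str]:
    """Per-node DP flags as described above; start rank = node_rank * local_dp."""
    return [
        "--data-parallel-hybrid-lb",
        "--data-parallel-size-local", str(local_dp),
        "--data-parallel-start-rank", str(node_rank * local_dp),
        "--data-parallel-address", leader_ip,
        "--data-parallel-rpc-port", str(rpc_port),
    ]

def auto_detects_hybrid_lb(data_parallel_start_rank: int, headless: bool) -> bool:
    """Mirror of the truthiness check at arg_utils.py:1721."""
    return bool(data_parallel_start_rank and not headless)

# Worker node (start_rank=4): auto-detected. Leader node (start_rank=0):
# 0 is falsy, so hybrid_lb is NOT inferred — hence the explicit flag.
```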
Side effect: worker_stage.py's `len(gpu_indices) < gpus_per_node`
check is now False for DP processes, so CUDA_VISIBLE_DEVICES is no
longer injected — all local GPUs share one CUDA namespace.
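A sketch of that guard (names assumed from the description above, not the actual worker_stage.py code):

```python
def injects_cuda_visible_devices(gpu_indices: list[int], gpus_per_node: int) -> bool:
    """CUDA_VISIBLE_DEVICES is only pinned for partial-node processes."""
    return len(gpu_indices) < gpus_per_node

# Old per-GPU process: one index out of 8 -> injected (isolated namespace).
assert injects_cuda_visible_devices([0], 8) is True
# New per-node DP process: full local set -> not injected (shared namespace).
assert injects_cuda_visible_devices(list(range(8)), 8) is False
```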
Single-node DP (size_local == data_parallel_size) automatically
collapses to internal_lb inside vLLM (arg_utils.py:1735-1737), so the
flag is harmless there.
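The single-node collapse can be sketched the same way (a simplification of the arg_utils.py:1735-1737 behavior, not the real logic):

```python
def effective_dp_lb_mode(hybrid_lb: bool, size_local: int, dp_size: int) -> str:
    """When all DP ranks are local, hybrid_lb degenerates to internal_lb."""
    if hybrid_lb and size_local == dp_size:
        return "internal_lb"
    return "hybrid_lb" if hybrid_lb else "internal_lb"

# Single-node DEP4: size_local == dp_size -> internal_lb, flag harmless.
# 4-node DEP16: size_local=4 < 16 -> hybrid_lb as intended.
```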
Tests: rewrite test_dp_mode_creates_per_node_processes and
test_dp_mode_command_includes_dp_flags to assert the new shape;
extend test_tp_mode_command_includes_multinode_flags to reject the
hybrid_lb flags.
Verified on GB300 1P1D DEP4 (single-node) and 1P10D DEP16 (4-node)
recipes via dry-run.