feat(vllm): switch DP+EP launch to hybrid_lb (per-node process)#90

Draft
esmeetu wants to merge 1 commit into NVIDIA:main from esmeetu:yasong/vllm-hybrid-dp-lb

Conversation


@esmeetu esmeetu commented Apr 27, 2026

DP+EP mode previously launched one srun task per GPU with a restricted CUDA_VISIBLE_DEVICES, which makes vLLM auto-select external_lb (one pod per rank). Locally co-located ranks then each see only `cuda:0`, and CUDA Symmetric Memory rendezvous fails with:

    CUDASymmetricMemoryAllocator::rendezvous: detected allocations from
    overlapping devices from different ranks.

This blocks DSv4 MegaMoE (deep_gemm.get_symm_buffer_for_mega_moe), the SymmMemCommunicator all-reduce path, and any future shared-namespace fast paths on GB200/GB300 nodes that pack multiple DP ranks per node.
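The failure mode can be illustrated with plain Python (an illustrative sketch of CUDA_VISIBLE_DEVICES index remapping on a hypothetical 4-GPU node; no CUDA required, and none of this is dynamo or vLLM code):

```python
def visible_device_index(physical_gpu: int, mask: str) -> int:
    """Return the logical CUDA index a process sees for `physical_gpu`
    under a CUDA_VISIBLE_DEVICES `mask`, or -1 if it is hidden."""
    visible = [int(x) for x in mask.split(",") if x != ""]
    return visible.index(physical_gpu) if physical_gpu in visible else -1

# Old launch: one srun task per GPU, mask restricted to that one GPU.
# Every co-located rank therefore sees its own GPU as logical index 0,
# which is what the rendezvous reports as "overlapping devices".
per_gpu_masks = {rank: str(rank) for rank in range(4)}
old_indices = [visible_device_index(r, per_gpu_masks[r]) for r in range(4)]
# old_indices == [0, 0, 0, 0]

# hybrid_lb launch: one process per node with the full GPU set visible,
# so each local rank binds a distinct logical device.
new_indices = [visible_device_index(r, "0,1,2,3") for r in range(4)]
# new_indices == [0, 1, 2, 3]
```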

Switch DP+EP to hybrid_lb:

  • endpoints_to_processes(): one Process per node (full local GPU set) instead of one Process per GPU. Reserves local_dp_size kv-events ports per process so dynamo's per-rank ZMQ publishers do not collide.

  • build_worker_command(): drop --data-parallel-rank (which silently flips vLLM into external_lb and forces size_local=1, see vllm/engine/arg_utils.py:1702-1717). Pass instead:

    --data-parallel-hybrid-lb
    --data-parallel-size-local <local_dp>
    --data-parallel-start-rank <node_rank * local_dp>
    --data-parallel-address    <leader_ip>
    --data-parallel-rpc-port   <port>
    

    `--data-parallel-hybrid-lb` is passed explicitly because vLLM's auto-detect at arg_utils.py:1721 (`if self.data_parallel_start_rank and not headless`) uses Python truthiness, so the leader node (start_rank=0) silently falls out of hybrid_lb and rejects worker engines with "Remote engine N must use --headless unless in external or hybrid dp lb mode".
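The truthiness pitfall is easy to reproduce in isolation (a paraphrase of the quoted check, not vLLM's actual code path):

```python
def autodetect_hybrid_lb(data_parallel_start_rank, headless: bool) -> bool:
    # Paraphrase of the quoted check: start_rank=0 is falsy, so the
    # leader node never auto-detects hybrid_lb from start_rank alone.
    return bool(data_parallel_start_rank and not headless)

def detect_hybrid_lb_explicit(data_parallel_start_rank, headless: bool) -> bool:
    # What an explicit None-check (or passing --data-parallel-hybrid-lb
    # outright) buys us: rank 0 stays in hybrid_lb mode.
    return data_parallel_start_rank is not None and not headless

# Worker node (e.g. start_rank=8): both checks agree.
assert autodetect_hybrid_lb(8, headless=False)
assert detect_hybrid_lb_explicit(8, headless=False)

# Leader node (start_rank=0): truthiness silently drops hybrid_lb.
assert not autodetect_hybrid_lb(0, headless=False)
assert detect_hybrid_lb_explicit(0, headless=False)
```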

Side effect: worker_stage.py's `len(gpu_indices) < gpus_per_node` check is now False for DP processes, so `CUDA_VISIBLE_DEVICES` is no longer injected: all local GPUs share one CUDA namespace.
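A minimal sketch of that guard (the function name and signature here are hypothetical, not worker_stage.py's real API):

```python
def build_env(gpu_indices: list, gpus_per_node: int) -> dict:
    """Inject CUDA_VISIBLE_DEVICES only when the process owns a strict
    subset of the node's GPUs; a per-node DP process gets no mask and
    keeps one shared CUDA namespace."""
    env = {}
    if len(gpu_indices) < gpus_per_node:
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_indices)
    return env

# Per-GPU process (old DP launch): mask restricts visibility.
# build_env([2], 4) -> {"CUDA_VISIBLE_DEVICES": "2"}
# Per-node DP process (hybrid_lb): the check is False, no mask injected.
# build_env([0, 1, 2, 3], 4) -> {}
```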

Single-node DP (size_local == data_parallel_size) automatically collapses to internal_lb inside vLLM (arg_utils.py:1735-1737), so the flag is harmless there.
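The single-node collapse can be expressed as a small decision function (a sketch of the behavior described above, not vLLM's implementation):

```python
def effective_dp_lb_mode(data_parallel_size: int, size_local: int,
                         hybrid_lb: bool) -> str:
    """Sketch: hybrid_lb with every DP rank on one node is
    indistinguishable from internal_lb, so it collapses (the behavior
    attributed above to arg_utils.py:1735-1737)."""
    if hybrid_lb and size_local == data_parallel_size:
        return "internal_lb"  # a single node owns all DP ranks
    return "hybrid_lb" if hybrid_lb else "internal_lb"

# Single-node DEP4: the flag is harmless, mode collapses to internal_lb.
assert effective_dp_lb_mode(4, 4, hybrid_lb=True) == "internal_lb"
# 4-node DEP16 (4 local ranks per node): stays hybrid_lb.
assert effective_dp_lb_mode(16, 4, hybrid_lb=True) == "hybrid_lb"
```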

Tests: rewrite test_dp_mode_creates_per_node_processes and test_dp_mode_command_includes_dp_flags to assert the new shape; extend test_tp_mode_command_includes_multinode_flags to reject the hybrid_lb flags.

Verified on GB300 1P1D DEP4 (single-node) and 1P10D DEP16 (4-node) recipes via dry-run.
