Skip to content

Commit 8bcf00b

Browse files
authored
fix(vllm): fail fast on single-process data parallel (#1236)
1 parent ffc62a2 commit 8bcf00b

3 files changed

Lines changed: 13 additions & 2 deletions

File tree

examples/models/vllm_qwen35.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ MODEL="Qwen/Qwen3.5-397B-A17B"
44
TASKS="mmmu_val,mme"
55

66
TENSOR_PARALLEL_SIZE=8
7+
# Global DP replica count across the full launch, not a per-GPU local value.
78
DATA_PARALLEL_SIZE=1
89
GPU_MEMORY_UTILIZATION=0.85
910
BATCH_SIZE=16

examples/models/vllm_qwen3vl.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
2929
# Adjust based on your GPU configuration.
3030
# If DATA_PARALLEL_SIZE > 1, this script automatically switches to torchrun.
3131
TENSOR_PARALLEL_SIZE=4 # Number of GPUs for tensor parallelism
32-
DATA_PARALLEL_SIZE=1 # Number of model replicas for data parallelism
32+
DATA_PARALLEL_SIZE=1 # Global number of data-parallel replicas, not a per-GPU local value
3333

3434
# Memory and Performance Settings
3535
GPU_MEMORY_UTILIZATION=0.85 # Fraction of GPU memory to use (0.0 - 1.0)

lmms_eval/models/simple/vllm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class VLLM(lmms):
5151
Default: "Qwen/Qwen2.5-VL-3B-Instruct"
5252
tensor_parallel_size (int): Number of GPUs to use for tensor parallelism.
5353
Default: 1
54+
data_parallel_size (int): Global number of data-parallel replicas across the
55+
distributed launch. This is a world-size value, not a per-node or per-GPU
56+
local count. Default: 1
5457
gpu_memory_utilization (float): Fraction of GPU memory to use for model weights.
5558
Should be between 0.0 and 1.0. Default: 0.8
5659
batch_size (int): Number of requests to process in parallel per GPU.
@@ -212,9 +215,16 @@ def __init__(
212215
self.accelerator = accelerator
213216
self._rank = self.accelerator.local_process_index
214217
self._world_size = self.accelerator.num_processes
218+
expected_world_size = self.tensor_parallel_size * self.data_parallel_size
219+
if self.data_parallel_size > 1 and accelerator.num_processes == 1:
220+
raise ValueError(
221+
"vLLM data parallel requires torchrun/accelerate multi-process launch. "
222+
f"Expected world_size = tensor_parallel_size * data_parallel_size = {expected_world_size}, "
223+
"but got single-process execution. Re-launch lmms_eval with torchrun or another "
224+
"distributed launcher instead of passing data_parallel_size to a single process."
225+
)
215226
if accelerator.num_processes > 1:
216227
kwargs["distributed_executor_backend"] = "external_launcher"
217-
expected_world_size = self.tensor_parallel_size * self.data_parallel_size
218228
if expected_world_size > 1 and accelerator.num_processes != expected_world_size:
219229
raise ValueError("For external_launcher mode, accelerate world size must equal " f"tensor_parallel_size * data_parallel_size ({expected_world_size}), " f"but got {accelerator.num_processes}.")
220230
self.client = LLM(

0 commit comments

Comments
 (0)