@@ -51,6 +51,9 @@ class VLLM(lmms):
5151 Default: "Qwen/Qwen2.5-VL-3B-Instruct"
5252 tensor_parallel_size (int): Number of GPUs to use for tensor parallelism.
5353 Default: 1
54+ data_parallel_size (int): Global number of data-parallel replicas across the
55+ distributed launch. This is a world-size value, not a per-node or per-GPU
56+ local count. Default: 1
5457 gpu_memory_utilization (float): Fraction of GPU memory to use for model weights.
5558 Should be between 0.0 and 1.0. Default: 0.8
5659 batch_size (int): Number of requests to process in parallel per GPU.
@@ -212,9 +215,16 @@ def __init__(
212215 self .accelerator = accelerator
213216 self ._rank = self .accelerator .local_process_index
214217 self ._world_size = self .accelerator .num_processes
218+ expected_world_size = self .tensor_parallel_size * self .data_parallel_size
219+ if self .data_parallel_size > 1 and accelerator .num_processes == 1 :
220+ raise ValueError (
221+ "vLLM data parallel requires torchrun/accelerate multi-process launch. "
222+ f"Expected world_size = tensor_parallel_size * data_parallel_size = { expected_world_size } , "
223+ "but got single-process execution. Re-launch lmms_eval with torchrun or another "
224+ "distributed launcher instead of passing data_parallel_size to a single process."
225+ )
215226 if accelerator .num_processes > 1 :
216227 kwargs ["distributed_executor_backend" ] = "external_launcher"
217- expected_world_size = self .tensor_parallel_size * self .data_parallel_size
218228 if expected_world_size > 1 and accelerator .num_processes != expected_world_size :
219229 raise ValueError ("For external_launcher mode, accelerate world size must equal " f"tensor_parallel_size * data_parallel_size ({ expected_world_size } ), " f"but got { accelerator .num_processes } ." )
220230 self .client = LLM (
0 commit comments