3 files changed: 36 additions, 3 deletions

@@ -395,6 +395,21 @@ def test_expert_parallel_size_plumbed_to_sharding(self):
         self.assertEqual(sampler.args["tensor_parallel_size"], 4)
         self.assertEqual(sampler.args["data_parallel_size"], 1)
 
+    def test_reserved_keys_in_engine_kwargs_raise_value_error(self):
+        # Reserved VllmConfig fields (e.g. tp, dp, ep) must be set directly on
+        # VllmConfig, not smuggled through engine_kwargs. Passing them via
+        # engine_kwargs should raise a ValueError at config construction time,
+        # before any vLLM engine args are assembled.
+        mesh = self._make_mock_mesh(8)
+        for key in ("expert_parallel_size", "tensor_parallel_size", "data_parallel_size"):
+            with self.subTest(key=key):
+                with self.assertRaisesRegex(ValueError, key):
+                    vllm_sampler.VllmConfig(
+                        mesh=mesh,
+                        init_with_random_weights=False,
+                        engine_kwargs={key: 2},
+                    )
+
     def test_default_expert_parallel_size_is_one(self):
         mesh = self._make_mock_mesh(8)
         config = vllm_sampler.VllmConfig(
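For reference, the subTest/assertRaisesRegex pattern the new test relies on can be exercised standalone. A minimal sketch, independent of the real test harness; ToyConfig is a hypothetical stand-in for vllm_sampler.VllmConfig:

import unittest


class ToyConfig:
    """Hypothetical stand-in for vllm_sampler.VllmConfig."""

    _RESERVED_KEYS = frozenset(
        {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
    )

    def __init__(self, **engine_kwargs):
        illegal = self._RESERVED_KEYS & engine_kwargs.keys()
        if illegal:
            # The offending key name appears in the message, so
            # assertRaisesRegex(ValueError, key) matches per key.
            raise ValueError(f"reserved keys passed via engine_kwargs: {sorted(illegal)}")


class ReservedKeyTest(unittest.TestCase):
    def test_each_reserved_key_raises(self):
        for key in ToyConfig._RESERVED_KEYS:
            # subTest reports each failing key separately instead of
            # stopping at the first failure.
            with self.subTest(key=key):
                with self.assertRaisesRegex(ValueError, key):
                    ToyConfig(**{key: 2})


if __name__ == "__main__":
    unittest.main()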

@@ -71,8 +71,26 @@ class VllmConfig:
         init=False, default_factory=dict
     )
 
+    # VllmConfig fields that require special processing before being passed to
+    # vLLM and must not be passed via engine_kwargs, which is a raw pass-through
+    # to vLLM EngineArgs.
+    _RESERVED_KEYS: frozenset[str] = dataclasses.field(
+        default=frozenset(
+            {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
+        ),
+        init=False,
+        repr=False,
+        compare=False,
+    )
+
     def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
         engine_kwargs = engine_kwargs or {}
+        illegal = self._RESERVED_KEYS & engine_kwargs.keys()
+        if illegal:
+            raise ValueError(
+                "VllmConfig fields must be set directly on VllmConfig, not passed"
+                f" via engine_kwargs: {sorted(illegal)}"
+            )
         self._processed_engine_kwargs = engine_kwargs
         if engine_kwargs:
             for key, value in engine_kwargs.items():
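Outside the diff, the guard reduces to a small, reusable pattern: a class-level frozenset declared with dataclasses.field(init=False, repr=False, compare=False) so it stays out of the generated __init__, repr, and eq, plus a set intersection against the incoming kwargs. A self-contained sketch under those assumptions (the field names are copied from the diff; everything else is illustrative):

import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass
class MiniVllmConfig:
    # engine_kwargs is an InitVar: accepted by __init__ and handed to
    # __post_init__, but never stored as a field.
    engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
    _processed_engine_kwargs: Dict[str, Any] = dataclasses.field(
        init=False, default_factory=dict
    )

    # init=False keeps this out of __init__; repr=False and compare=False
    # keep it out of repr() and __eq__.
    _RESERVED_KEYS: frozenset = dataclasses.field(
        default=frozenset(
            {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
        ),
        init=False,
        repr=False,
        compare=False,
    )

    def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
        engine_kwargs = engine_kwargs or {}
        # dict.keys() is a set-like view, so frozenset & keys() computes the
        # intersection directly, with no intermediate set.
        illegal = self._RESERVED_KEYS & engine_kwargs.keys()
        if illegal:
            raise ValueError(
                "fields must be set directly on the config, not passed"
                f" via engine_kwargs: {sorted(illegal)}"
            )
        self._processed_engine_kwargs = engine_kwargs


# Usage: raises ValueError naming tensor_parallel_size.
# MiniVllmConfig(engine_kwargs={"tensor_parallel_size": 2})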

@@ -52,16 +52,16 @@ def __init__(
             hbm_utilization=rollout_config.rollout_vllm_hbm_utilization,
             lora_config=rollout_config.rollout_vllm_lora_config,
             mesh=mesh,
+            tensor_parallel_size=rollout_config.tensor_parallel_size,
+            data_parallel_size=rollout_config.data_parallel_size,
+            expert_parallel_size=rollout_config.expert_parallel_size,
             engine_kwargs={
                 "model": rollout_config.rollout_vllm_model_version,
                 "max_model_len": cache_config_or_size,
                 "swap_space": rollout_config.rollout_vllm_swap_space_size_gb,
                 "async_scheduling": (
                     rollout_config.rollout_vllm_async_scheduling
                 ),
-                "tensor_parallel_size": rollout_config.tensor_parallel_size,
-                "data_parallel_size": rollout_config.data_parallel_size,
-                "expert_parallel_size": rollout_config.expert_parallel_size,
                 "max_num_batched_tokens": (
                     rollout_config.rollout_vllm_max_num_batched_tokens
                 ),
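The call-site change above is the migration this guard enforces: parallelism sizes move out of engine_kwargs and onto VllmConfig itself. Sketched with only the arguments relevant here (all other arguments elided):

# After this change: parallelism sizes are direct VllmConfig arguments.
config = vllm_sampler.VllmConfig(
    mesh=mesh,
    tensor_parallel_size=rollout_config.tensor_parallel_size,
    data_parallel_size=rollout_config.data_parallel_size,
    expert_parallel_size=rollout_config.expert_parallel_size,
    engine_kwargs={"model": rollout_config.rollout_vllm_model_version},
)

# Before this change (now raises ValueError at construction time):
# vllm_sampler.VllmConfig(
#     mesh=mesh,
#     engine_kwargs={
#         "tensor_parallel_size": rollout_config.tensor_parallel_size,
#     },
# )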