
Commit b8d95e2 (parent: a3389dc)

move sharding parallelism keys outside of engine kwargs

3 files changed (+36, -3 lines)

tests/generate/vllm_sampler_test.py (15 additions, 0 deletions)

@@ -395,6 +395,21 @@ def test_expert_parallel_size_plumbed_to_sharding(self):
     self.assertEqual(sampler.args["tensor_parallel_size"], 4)
     self.assertEqual(sampler.args["data_parallel_size"], 1)
 
+  def test_reserved_keys_in_engine_kwargs_raise_value_error(self):
+    # Reserved VllmConfig fields (e.g. tp, dp, ep) must be set directly on
+    # VllmConfig, not smuggled through engine_kwargs. Passing them via
+    # engine_kwargs should raise a ValueError at config construction time,
+    # before any vLLM engine args are assembled.
+    mesh = self._make_mock_mesh(8)
+    for key in ("expert_parallel_size", "tensor_parallel_size", "data_parallel_size"):
+      with self.subTest(key=key):
+        with self.assertRaisesRegex(ValueError, key):
+          vllm_sampler.VllmConfig(
+              mesh=mesh,
+              init_with_random_weights=False,
+              engine_kwargs={key: 2},
+          )
+
   def test_default_expert_parallel_size_is_one(self):
     mesh = self._make_mock_mesh(8)
     config = vllm_sampler.VllmConfig(
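
For context on what the new test asserts, here is a hypothetical repro of the guarded failure (`mesh` stands in for whatever the test's `_make_mock_mesh` helper returns; the error text matches the message added in vllm_sampler.py below):

```python
# Hypothetical repro; mesh construction elided.
config = vllm_sampler.VllmConfig(
    mesh=mesh,
    init_with_random_weights=False,
    engine_kwargs={"tensor_parallel_size": 2},
)
# Raises at construction time:
# ValueError: VllmConfig fields must be set directly on VllmConfig,
# not passed via engine_kwargs: ['tensor_parallel_size']
```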

tunix/generate/vllm_sampler.py (18 additions, 0 deletions)

@@ -71,8 +71,26 @@ class VllmConfig:
       init=False, default_factory=dict
   )
 
+  # VllmConfig fields that require special processing before being passed to
+  # vLLM and must not be passed via engine_kwargs, which is a raw pass-through
+  # to vLLM EngineArgs.
+  _RESERVED_KEYS: frozenset[str] = dataclasses.field(
+      default=frozenset(
+          {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
+      ),
+      init=False,
+      repr=False,
+      compare=False,
+  )
+
   def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
     engine_kwargs = engine_kwargs or {}
+    illegal = self._RESERVED_KEYS & engine_kwargs.keys()
+    if illegal:
+      raise ValueError(
+          f"VllmConfig fields must be set directly on VllmConfig, not passed"
+          f" via engine_kwargs: {sorted(illegal)}"
+      )
     self._processed_engine_kwargs = engine_kwargs
     if engine_kwargs:
       for key, value in engine_kwargs.items():
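
The guard above is a general pattern for dataclasses that expose a raw kwargs pass-through next to first-class fields. A minimal, self-contained sketch of the same idea follows; `Config` and its stored attribute are illustrative stand-ins, not the real VllmConfig:

```python
# Sketch: reject reserved names in a kwargs pass-through at construction time.
import dataclasses
from typing import Any, Dict, Optional

_RESERVED_KEYS = frozenset(
    {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
)


@dataclasses.dataclass
class Config:
  tensor_parallel_size: int = 1
  data_parallel_size: int = 1
  expert_parallel_size: int = 1
  # InitVar: consumed by __post_init__, never stored as a field.
  engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None

  def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
    engine_kwargs = engine_kwargs or {}
    # Set intersection between reserved names and the caller's keys.
    illegal = _RESERVED_KEYS & engine_kwargs.keys()
    if illegal:
      raise ValueError(
          "Config fields must be set directly on Config, not passed"
          f" via engine_kwargs: {sorted(illegal)}"
      )
    self._processed_engine_kwargs = engine_kwargs
```

Declaring `_RESERVED_KEYS` as an `init=False, repr=False, compare=False` field, as the diff does, has the same effect while keeping the constant out of the generated `__init__`, `repr`, and equality checks.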

tunix/rl/rollout/vllm_rollout.py (3 additions, 3 deletions)

@@ -52,16 +52,16 @@ def __init__(
         hbm_utilization=rollout_config.rollout_vllm_hbm_utilization,
         lora_config=rollout_config.rollout_vllm_lora_config,
         mesh=mesh,
+        tensor_parallel_size=rollout_config.tensor_parallel_size,
+        data_parallel_size=rollout_config.data_parallel_size,
+        expert_parallel_size=rollout_config.expert_parallel_size,
         engine_kwargs={
             "model": rollout_config.rollout_vllm_model_version,
             "max_model_len": cache_config_or_size,
             "swap_space": rollout_config.rollout_vllm_swap_space_size_gb,
             "async_scheduling": (
                 rollout_config.rollout_vllm_async_scheduling
             ),
-            "tensor_parallel_size": rollout_config.tensor_parallel_size,
-            "data_parallel_size": rollout_config.data_parallel_size,
-            "expert_parallel_size": rollout_config.expert_parallel_size,
             "max_num_batched_tokens": (
                 rollout_config.rollout_vllm_max_num_batched_tokens
             ),
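
After this change, the rollout call site passes the parallelism sizes as first-class arguments and keeps engine_kwargs purely as a raw vLLM EngineArgs pass-through. A trimmed sketch of the new shape (assuming the call constructs vllm_sampler.VllmConfig; other arguments from the diff are elided):

```python
# New call-site shape: sharding sizes are direct fields, not dict entries.
config = vllm_sampler.VllmConfig(
    mesh=mesh,
    tensor_parallel_size=rollout_config.tensor_parallel_size,
    data_parallel_size=rollout_config.data_parallel_size,
    expert_parallel_size=rollout_config.expert_parallel_size,
    engine_kwargs={
        "model": rollout_config.rollout_vllm_model_version,
        # ... remaining pass-through keys unchanged ...
    },
)
```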
