Skip to content

Commit e82c781

Browse files
committed
fix expert_parallel_size
1 parent a3389dc commit e82c781

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

tests/generate/vllm_sampler_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,30 @@ def test_expert_parallel_size_plumbed_to_sharding(self):
395395
self.assertEqual(sampler.args["tensor_parallel_size"], 4)
396396
self.assertEqual(sampler.args["data_parallel_size"], 1)
397397

398+
def test_expert_parallel_size_via_engine_kwargs_not_leaked_to_vllm(self):
399+
# Regression test: expert_parallel_size passed via engine_kwargs should be
400+
# consumed by tunix config processing and translated into
401+
# additional_config["sharding"]["sharding_strategy"]["expert_parallelism"].
402+
# It must NOT appear as a top-level key in sampler.args, because vLLM's
403+
# EngineArgs has no such parameter and would raise an error.
404+
mesh = self._make_mock_mesh(8)
405+
config = vllm_sampler.VllmConfig(
406+
mesh=mesh,
407+
init_with_random_weights=False,
408+
engine_kwargs={"expert_parallel_size": 2},
409+
)
410+
sampler = self._make_sampler(config)
411+
412+
self.assertNotIn(
413+
"expert_parallel_size",
414+
sampler.args,
415+
"expert_parallel_size must not be passed directly to vLLM engine args",
416+
)
417+
sharding_strategy = sampler.args["additional_config"]["sharding"][
418+
"sharding_strategy"
419+
]
420+
self.assertEqual(sharding_strategy["expert_parallelism"], 2)
421+
398422
def test_default_expert_parallel_size_is_one(self):
399423
mesh = self._make_mock_mesh(8)
400424
config = vllm_sampler.VllmConfig(

tunix/generate/vllm_sampler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def load_checkpoint(self, path_or_weights: str | jaxtyping.PyTree):
202202
def _vllm_config(self, config: VllmConfig):
203203
"""Setup vllm config from Tunix Vllm config."""
204204
args = config._processed_engine_kwargs.copy()
205+
# expert_parallel_size is a tunix-owned concept translated into
206+
# additional_config["sharding"]["sharding_strategy"]["expert_parallelism"].
207+
# It is not a vLLM EngineArgs parameter and must not be passed through.
208+
args.pop("expert_parallel_size", None)
205209

206210
# Init vLLM model with random weights to speed up bootstrap time, because
207211
# model weights are synced from trainer later on

0 commit comments

Comments (0)