Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/generate/vllm_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,21 @@ def test_expert_parallel_size_plumbed_to_sharding(self):
self.assertEqual(sampler.args["tensor_parallel_size"], 4)
self.assertEqual(sampler.args["data_parallel_size"], 1)

def test_reserved_keys_in_engine_kwargs_raise_value_error(self):
    """engine_kwargs must reject reserved VllmConfig parallelism fields.

    Reserved VllmConfig fields (e.g. tp, dp, ep) must be set directly on
    VllmConfig, not smuggled through engine_kwargs. Passing them via
    engine_kwargs should raise a ValueError at config construction time,
    before any vLLM engine args are assembled.
    """
    mesh = self._make_mock_mesh(8)
    reserved_keys = (
        "expert_parallel_size",
        "tensor_parallel_size",
        "data_parallel_size",
    )
    for key in reserved_keys:
        with self.subTest(key=key):
            # The offending key name must appear in the error message.
            with self.assertRaisesRegex(ValueError, key):
                kwargs = {key: 2}
                vllm_sampler.VllmConfig(
                    mesh=mesh,
                    init_with_random_weights=False,
                    engine_kwargs=kwargs,
                )

def test_default_expert_parallel_size_is_one(self):
mesh = self._make_mock_mesh(8)
config = vllm_sampler.VllmConfig(
Expand Down
18 changes: 18 additions & 0 deletions tunix/generate/vllm_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,26 @@ class VllmConfig:
init=False, default_factory=dict
)

# VllmConfig fields that require special processing before being passed to
# vLLM and must not be passed via engine_kwargs, which is a raw pass-through
# to vLLM EngineArgs. Checked against engine_kwargs in __post_init__.
# NOTE(review): this is effectively a class-level constant; declaring it as
# typing.ClassVar[frozenset[str]] would keep it out of dataclasses.fields()
# entirely — confirm nothing iterates fields() before changing. As written,
# init=False/repr=False/compare=False already hides it from __init__,
# __repr__, and __eq__.
_RESERVED_KEYS: frozenset[str] = dataclasses.field(
    default=frozenset(
        {"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
    ),
    init=False,
    repr=False,
    compare=False,
)

def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
engine_kwargs = engine_kwargs or {}
illegal = self._RESERVED_KEYS & engine_kwargs.keys()
if illegal:
raise ValueError(
f"VllmConfig fields must be set directly on VllmConfig, not passed"
f" via engine_kwargs: {sorted(illegal)}"
)
self._processed_engine_kwargs = engine_kwargs
if engine_kwargs:
for key, value in engine_kwargs.items():
Expand Down
6 changes: 3 additions & 3 deletions tunix/rl/rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ def __init__(
hbm_utilization=rollout_config.rollout_vllm_hbm_utilization,
lora_config=rollout_config.rollout_vllm_lora_config,
mesh=mesh,
tensor_parallel_size=rollout_config.tensor_parallel_size,
data_parallel_size=rollout_config.data_parallel_size,
expert_parallel_size=rollout_config.expert_parallel_size,
engine_kwargs={
"model": rollout_config.rollout_vllm_model_version,
"max_model_len": cache_config_or_size,
"swap_space": rollout_config.rollout_vllm_swap_space_size_gb,
"async_scheduling": (
rollout_config.rollout_vllm_async_scheduling
),
"tensor_parallel_size": rollout_config.tensor_parallel_size,
"data_parallel_size": rollout_config.data_parallel_size,
"expert_parallel_size": rollout_config.expert_parallel_size,
"max_num_batched_tokens": (
rollout_config.rollout_vllm_max_num_batched_tokens
),
Expand Down