Skip to content

Commit f96d9d5

Browse files
committed
generalized dict and error message
1 parent bb85d89 commit f96d9d5

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

tests/generate/vllm_sampler_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,11 @@ def test_expert_parallel_size_plumbed_to_sharding(self):
395395
self.assertEqual(sampler.args["tensor_parallel_size"], 4)
396396
self.assertEqual(sampler.args["data_parallel_size"], 1)
397397

398-
def test_sharding_keys_in_engine_kwargs_raise_value_error(self):
399-
# Sharding parallelism sizes (tp, dp, ep) are tunix-owned VllmConfig fields
400-
# and must be set directly on VllmConfig, not smuggled through engine_kwargs.
401-
# Passing them via engine_kwargs should raise a ValueError at config
402-
# construction time before any vLLM engine args are assembled.
398+
def test_reserved_keys_in_engine_kwargs_raise_value_error(self):
399+
# Reserved VllmConfig fields (e.g. tp, dp, ep) must be set directly on
400+
# VllmConfig, not smuggled through engine_kwargs. Passing them via
401+
# engine_kwargs should raise a ValueError at config construction time
402+
# before any vLLM engine args are assembled.
403403
mesh = self._make_mock_mesh(8)
404404
for key in ("expert_parallel_size", "tensor_parallel_size", "data_parallel_size"):
405405
with self.subTest(key=key):

tunix/generate/vllm_sampler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ class VllmConfig:
7171
init=False, default_factory=dict
7272
)
7373

74-
# Parallelism sizes are tunix-owned VllmConfig fields that require special
75-
# processing before being passed to vLLM. They must not be passed via
76-
# engine_kwargs, which is a direct pass-through to vLLM EngineArgs.
77-
_SHARDING_KEYS: frozenset[str] = dataclasses.field(
74+
# VllmConfig fields that require special processing before being passed to
75+
# vLLM and must not be passed via engine_kwargs, which is a raw pass-through
76+
# to vLLM EngineArgs.
77+
_RESERVED_KEYS: frozenset[str] = dataclasses.field(
7878
default=frozenset(
7979
{"tensor_parallel_size", "data_parallel_size", "expert_parallel_size"}
8080
),
@@ -85,11 +85,11 @@ class VllmConfig:
8585

8686
def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
8787
engine_kwargs = engine_kwargs or {}
88-
illegal = self._SHARDING_KEYS & engine_kwargs.keys()
88+
illegal = self._RESERVED_KEYS & engine_kwargs.keys()
8989
if illegal:
9090
raise ValueError(
91-
f"Sharding parallelism sizes must be set directly on VllmConfig, not"
92-
f" passed via engine_kwargs: {sorted(illegal)}"
91+
f"VllmConfig fields must be set directly on VllmConfig, not passed"
92+
f" via engine_kwargs: {sorted(illegal)}"
9393
)
9494
self._processed_engine_kwargs = engine_kwargs
9595
if engine_kwargs:

0 commit comments

Comments
 (0)