Skip to content

Commit a512490

Browse files
committed
expert parallelism config
1 parent efb4913 commit a512490

File tree

6 files changed

+153
-23
lines changed

6 files changed

+153
-23
lines changed

tests/generate/utils_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from jax import sharding
2121
import jax.numpy as jnp
2222
import numpy as np
23+
from unittest import mock
2324
from tunix.generate import utils
2425
from tunix.rl import reshard
2526

@@ -1185,5 +1186,47 @@ def test_transfer_state_directly_scanned_layers_casting(self):
11851186
)
11861187

11871188

1189+
class ResolveParallelismSizesTest(parameterized.TestCase):
1190+
1191+
def _make_mesh(self, total_devices):
1192+
"""Returns a mock mesh with the given total device count."""
1193+
mesh = mock.MagicMock()
1194+
mesh.shape = {"axis": total_devices}
1195+
return mesh
1196+
1197+
@parameterized.named_parameters(
1198+
("tp_and_dp_inferred_no_ep", 8, -1, -1, 1, 8, 1, 1),
1199+
("tp_and_dp_inferred_with_ep", 8, -1, -1, 2, 4, 1, 2),
1200+
("tp_inferred_with_ep", 8, -1, 2, 2, 2, 2, 2),
1201+
("dp_inferred_with_ep", 8, 2, -1, 2, 2, 2, 2),
1202+
("all_explicit", 8, 4, 2, 1, 4, 2, 1),
1203+
)
1204+
def test_resolve_parallelism_sizes(
1205+
self,
1206+
total_devices,
1207+
tp_in,
1208+
dp_in,
1209+
ep_in,
1210+
expected_tp,
1211+
expected_dp,
1212+
expected_ep,
1213+
):
1214+
mesh = self._make_mesh(total_devices)
1215+
tp, dp, ep = utils.resolve_parallelism_sizes(
1216+
mesh=mesh,
1217+
tensor_parallel_size=tp_in,
1218+
data_parallel_size=dp_in,
1219+
expert_parallel_size=ep_in,
1220+
)
1221+
self.assertEqual(tp, expected_tp)
1222+
self.assertEqual(dp, expected_dp)
1223+
self.assertEqual(ep, expected_ep)
1224+
1225+
def test_resolve_parallelism_sizes_indivisible_ep_raises(self):
1226+
mesh = self._make_mesh(8)
1227+
with self.assertRaisesRegex(ValueError, "expert_parallel_size"):
1228+
utils.resolve_parallelism_sizes(mesh=mesh, expert_parallel_size=3)
1229+
1230+
11881231
if __name__ == "__main__":
11891232
absltest.main()

tests/generate/vllm_sampler_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,5 +358,58 @@ async def dispatch_requests():
358358
)
359359

360360

361+
class VllmSamplerConfigTest(absltest.TestCase):
362+
"""Unit tests for VllmSampler config plumbing (no hardware required)."""
363+
364+
def _make_mock_mesh(self, total_devices):
365+
mesh = mock.MagicMock()
366+
mesh.shape = {"axis": total_devices}
367+
mesh.device_ids.flatten.return_value.tolist.return_value = list(
368+
range(total_devices)
369+
)
370+
return mesh
371+
372+
def _make_sampler(self, config):
373+
with mock.patch("tunix.generate.vllm_sampler.LLM"), mock.patch(
374+
"tunix.generate.vllm_sampler.tok_adapter.TokenizerAdapter"
375+
):
376+
return vllm_sampler.VllmSampler(
377+
tokenizer=mock.MagicMock(), config=config
378+
)
379+
380+
def test_expert_parallel_size_plumbed_to_sharding(self):
381+
mesh = self._make_mock_mesh(8)
382+
config = vllm_sampler.VllmConfig(
383+
mesh=mesh,
384+
expert_parallel_size=2,
385+
init_with_random_weights=False,
386+
)
387+
sampler = self._make_sampler(config)
388+
389+
sharding_strategy = sampler.args["additional_config"]["sharding"][
390+
"sharding_strategy"
391+
]
392+
# EP=2 should appear in the sharding strategy passed to vLLM.
393+
self.assertEqual(sharding_strategy["expert_parallelism"], 2)
394+
# With 8 total devices and EP=2, TP should be inferred as 4 and DP as 1.
395+
self.assertEqual(sampler.args["tensor_parallel_size"], 4)
396+
self.assertEqual(sampler.args["data_parallel_size"], 1)
397+
398+
def test_default_expert_parallel_size_is_one(self):
399+
mesh = self._make_mock_mesh(8)
400+
config = vllm_sampler.VllmConfig(
401+
mesh=mesh,
402+
init_with_random_weights=False,
403+
)
404+
sampler = self._make_sampler(config)
405+
406+
sharding_strategy = sampler.args["additional_config"]["sharding"][
407+
"sharding_strategy"
408+
]
409+
self.assertEqual(sharding_strategy["expert_parallelism"], 1)
410+
self.assertEqual(sampler.args["tensor_parallel_size"], 8)
411+
self.assertEqual(sampler.args["data_parallel_size"], 1)
412+
413+
361414
if __name__ == "__main__":
362415
absltest.main()

tunix/generate/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,50 @@ def intersect_trees(
991991
gc.collect()
992992

993993

994+
def resolve_parallelism_sizes(
995+
mesh: jax.sharding.Mesh,
996+
tensor_parallel_size: int = -1,
997+
data_parallel_size: int = -1,
998+
expert_parallel_size: int = 1,
999+
) -> tuple[int, int, int]:
1000+
"""Resolves tensor, data, and expert parallelism sizes from the mesh.
1001+
1002+
Any size passed as -1 is inferred from the total number of mesh devices and
1003+
the other sizes. Raises ValueError if the mesh size is not divisible by
1004+
expert_parallel_size.
1005+
1006+
Args:
1007+
mesh: The JAX device mesh.
1008+
tensor_parallel_size: Desired tensor parallelism degree, or -1 to infer.
1009+
data_parallel_size: Desired data parallelism degree, or -1 to infer.
1010+
expert_parallel_size: Desired expert parallelism degree.
1011+
1012+
Returns:
1013+
A tuple of (tensor_parallel_size, data_parallel_size, expert_parallel_size).
1014+
"""
1015+
total_mesh_devices = math.prod(mesh.shape.values())
1016+
1017+
if total_mesh_devices % expert_parallel_size != 0:
1018+
raise ValueError(
1019+
f"Total mesh devices ({total_mesh_devices}) must be divisible by"
1020+
f" expert_parallel_size ({expert_parallel_size})."
1021+
)
1022+
1023+
if tensor_parallel_size == -1 and data_parallel_size == -1:
1024+
tensor_parallel_size = total_mesh_devices // expert_parallel_size
1025+
data_parallel_size = 1
1026+
elif tensor_parallel_size == -1:
1027+
tensor_parallel_size = (
1028+
total_mesh_devices // (data_parallel_size * expert_parallel_size)
1029+
)
1030+
elif data_parallel_size == -1:
1031+
data_parallel_size = (
1032+
total_mesh_devices // (tensor_parallel_size * expert_parallel_size)
1033+
)
1034+
1035+
return tensor_parallel_size, data_parallel_size, expert_parallel_size
1036+
1037+
9941038
def verify_state_closeness(golden_state, state, atol=1e-2):
9951039
"""Check if the golden NNX state is close to the other NNX state.
9961040

tunix/generate/vllm_sampler.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import atexit
1818
import dataclasses
1919
from itertools import count
20-
import math
2120
import os
2221
from typing import Any, Dict, List, Optional, Tuple, Union
2322

@@ -64,6 +63,7 @@ class VllmConfig:
6463
mesh: jax.sharding.Mesh = None
6564
data_parallel_size: int = -1
6665
tensor_parallel_size: int = -1
66+
expert_parallel_size: int = 1
6767

6868
# vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc.
6969
engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
@@ -199,30 +199,10 @@ def load_checkpoint(self, path_or_weights: str | jaxtyping.PyTree):
199199
else:
200200
raise NotImplementedError("Only support in memory weight sync as of now.")
201201

202-
def _find_total_size(self, mesh: jax.sharding.Mesh) -> int:
203-
"""Finds the tensor parallel size from the mesh."""
204-
# since vllm doesn't support DP yet, simply return the total rank size.
205-
return math.prod(mesh.shape.values())
206-
207202
def _vllm_config(self, config: VllmConfig):
208203
"""Setup vllm config from Tunix Vllm config."""
209204
args = config._processed_engine_kwargs.copy()
210205

211-
tensor_parallel_size = config.tensor_parallel_size
212-
data_parallel_size = config.data_parallel_size
213-
total_mesh_devices = self._find_total_size(config.mesh)
214-
215-
if config.tensor_parallel_size == -1 and config.data_parallel_size == -1:
216-
tensor_parallel_size = total_mesh_devices
217-
data_parallel_size = 1
218-
elif config.tensor_parallel_size == -1:
219-
tensor_parallel_size = total_mesh_devices // data_parallel_size
220-
elif config.data_parallel_size == -1:
221-
data_parallel_size = total_mesh_devices // tensor_parallel_size
222-
223-
args["data_parallel_size"] = data_parallel_size
224-
args["tensor_parallel_size"] = tensor_parallel_size
225-
226206
# Init vLLM model with random weights to speed up bootstrap time, because
227207
# model weights are synced from trainer later on
228208
if config.init_with_random_weights:
@@ -235,10 +215,19 @@ def _vllm_config(self, config: VllmConfig):
235215
if config.lora_config is not None:
236216
args["additional_config"]["lora_config"] = config.lora_config
237217

238-
device_indexes = config.mesh.device_ids.flatten().tolist()
218+
tp, dp, ep = utils.resolve_parallelism_sizes(
219+
mesh=config.mesh,
220+
tensor_parallel_size=config.tensor_parallel_size,
221+
data_parallel_size=config.data_parallel_size,
222+
expert_parallel_size=config.expert_parallel_size,
223+
)
224+
args["tensor_parallel_size"] = tp
225+
args["data_parallel_size"] = dp
239226

227+
device_indexes = config.mesh.device_ids.flatten().tolist()
240228
args["additional_config"]["sharding"] = {
241229
"sharding_strategy": {
230+
"expert_parallelism": ep,
242231
"device_indexes": device_indexes,
243232
"enable_dp_attention": config.enable_dp_attention,
244233
}
@@ -414,7 +403,6 @@ def __call__(
414403
sampling_params.top_p = top_p
415404
if top_k is not None:
416405
sampling_params.top_k = top_k
417-
418406
if seed is not None:
419407
sampling_params.seed = seed
420408

tunix/rl/rollout/base_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class RolloutConfig:
111111
# Parallelism configs.
112112
tensor_parallel_size: int = -1
113113
data_parallel_size: int = -1
114+
expert_parallel_size: int = 1
114115

115116
# vLLM specific rollout configs.
116117

tunix/rl/rollout/vllm_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
),
6262
"tensor_parallel_size": rollout_config.tensor_parallel_size,
6363
"data_parallel_size": rollout_config.data_parallel_size,
64+
"expert_parallel_size": rollout_config.expert_parallel_size,
6465
"max_num_batched_tokens": (
6566
rollout_config.rollout_vllm_max_num_batched_tokens
6667
),

0 commit comments

Comments (0)