Skip to content

Commit b114d23

Browse files
committed
expert parallelism config
1 parent efb4913 commit b114d23

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

tunix/generate/vllm_sampler.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class VllmConfig:
6464
mesh: jax.sharding.Mesh = None
6565
data_parallel_size: int = -1
6666
tensor_parallel_size: int = -1
67+
expert_parallel_size: int = 1
6768

6869
# vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc.
6970
engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
@@ -204,25 +205,49 @@ def _find_total_size(self, mesh: jax.sharding.Mesh) -> int:
204205
# since vllm doesn't support DP yet, simply return the total rank size.
205206
return math.prod(mesh.shape.values())
206207

207-
def _vllm_config(self, config: VllmConfig):
208-
"""Setup vllm config from Tunix Vllm config."""
209-
args = config._processed_engine_kwargs.copy()
210-
208+
def _configure_sharding(
209+
self, config: VllmConfig, args: Dict[str, Any]
210+
) -> None:
211+
"""Resolves parallelism sizes and sets the sharding config in args."""
211212
tensor_parallel_size = config.tensor_parallel_size
212213
data_parallel_size = config.data_parallel_size
214+
expert_parallel_size = config.expert_parallel_size
213215
total_mesh_devices = self._find_total_size(config.mesh)
214216

217+
if total_mesh_devices % expert_parallel_size != 0:
218+
raise ValueError(
219+
f"Total mesh devices ({total_mesh_devices}) must be divisible by"
220+
f" expert_parallel_size ({expert_parallel_size})."
221+
)
222+
215223
if config.tensor_parallel_size == -1 and config.data_parallel_size == -1:
216-
tensor_parallel_size = total_mesh_devices
224+
tensor_parallel_size = total_mesh_devices // expert_parallel_size
217225
data_parallel_size = 1
218226
elif config.tensor_parallel_size == -1:
219-
tensor_parallel_size = total_mesh_devices // data_parallel_size
227+
tensor_parallel_size = (
228+
total_mesh_devices // (data_parallel_size * expert_parallel_size)
229+
)
220230
elif config.data_parallel_size == -1:
221-
data_parallel_size = total_mesh_devices // tensor_parallel_size
231+
data_parallel_size = (
232+
total_mesh_devices // (tensor_parallel_size * expert_parallel_size)
233+
)
222234

223235
args["data_parallel_size"] = data_parallel_size
224236
args["tensor_parallel_size"] = tensor_parallel_size
225237

238+
device_indexes = config.mesh.device_ids.flatten().tolist()
239+
args["additional_config"]["sharding"] = {
240+
"sharding_strategy": {
241+
"expert_parallelism": expert_parallel_size,
242+
"device_indexes": device_indexes,
243+
"enable_dp_attention": config.enable_dp_attention,
244+
}
245+
}
246+
247+
def _vllm_config(self, config: VllmConfig):
248+
"""Setup vllm config from Tunix Vllm config."""
249+
args = config._processed_engine_kwargs.copy()
250+
226251
# Init vLLM model with random weights to speed up bootstrap time, because
227252
# model weights are synced from trainer later on
228253
if config.init_with_random_weights:
@@ -235,14 +260,7 @@ def _vllm_config(self, config: VllmConfig):
235260
if config.lora_config is not None:
236261
args["additional_config"]["lora_config"] = config.lora_config
237262

238-
device_indexes = config.mesh.device_ids.flatten().tolist()
239-
240-
args["additional_config"]["sharding"] = {
241-
"sharding_strategy": {
242-
"device_indexes": device_indexes,
243-
"enable_dp_attention": config.enable_dp_attention,
244-
}
245-
}
263+
self._configure_sharding(config, args)
246264

247265
return args
248266

tunix/rl/rollout/base_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class RolloutConfig:
111111
# Parallelism configs.
112112
tensor_parallel_size: int = -1
113113
data_parallel_size: int = -1
114+
expert_parallel_size: int = 1
114115

115116
# vLLM specific rollout configs.
116117

tunix/rl/rollout/vllm_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
tpu_backend_type=rollout_config.rollout_vllm_tpu_backend_type,
5050
additional_config=rollout_config.rollout_vllm_additional_config,
5151
enable_dp_attention=rollout_config.rollout_vllm_enable_dp_attention,
52+
expert_parallel_size=rollout_config.expert_parallel_size,
5253
hbm_utilization=rollout_config.rollout_vllm_hbm_utilization,
5354
lora_config=rollout_config.rollout_vllm_lora_config,
5455
mesh=mesh,

0 commit comments

Comments (0)