Skip to content

Commit 6945926

Browse files
committed
expert parallelism config
1 parent b121235 commit 6945926

File tree

3 files changed: +14 −3 lines changed

tunix/generate/vllm_sampler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class VllmConfig:
   mesh: jax.sharding.Mesh = None
   data_parallel_size: int = -1
   tensor_parallel_size: int = -1
+  expert_parallel_size: int = 1
 
   # vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc.
   engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
@@ -210,15 +211,20 @@ def _vllm_config(self, config: VllmConfig):
 
     tensor_parallel_size = config.tensor_parallel_size
     data_parallel_size = config.data_parallel_size
+    expert_parallel_size = config.expert_parallel_size
     total_mesh_devices = self._find_total_size(config.mesh)
 
     if config.tensor_parallel_size == -1 and config.data_parallel_size == -1:
-      tensor_parallel_size = total_mesh_devices
+      tensor_parallel_size = total_mesh_devices // expert_parallel_size
       data_parallel_size = 1
     elif config.tensor_parallel_size == -1:
-      tensor_parallel_size = total_mesh_devices // data_parallel_size
+      tensor_parallel_size = (
+          total_mesh_devices // (data_parallel_size * expert_parallel_size)
+      )
     elif config.data_parallel_size == -1:
-      data_parallel_size = total_mesh_devices // tensor_parallel_size
+      data_parallel_size = (
+          total_mesh_devices // (tensor_parallel_size * expert_parallel_size)
+      )
 
     args["data_parallel_size"] = data_parallel_size
     args["tensor_parallel_size"] = tensor_parallel_size

tunix/rl/rollout/base_rollout.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class RolloutConfig:
   # Parallelism configs.
   tensor_parallel_size: int = -1
   data_parallel_size: int = -1
+  expert_parallel_size: int = 1
 
   # vLLM specific rollout configs.
@@ -149,6 +150,9 @@ class RolloutConfig:
   # axes, which can help reduce memory usage for large models with few KV heads.
   rollout_vllm_enable_dp_attention: bool = False
 
+  # Whether to enable expert parallelism for vLLM rollout engine.
+  rollout_vllm_enable_expert_parallelism: bool = False
+
   # Maximum number of batched tokens allowed in vLLM. This allows for pending prefill requests
   # to be batched along with decode requests if enough tokens are available. Only used when
   # chunked prefill is enabled.

tunix/rl/rollout/vllm_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
         ),
         "tensor_parallel_size": rollout_config.tensor_parallel_size,
         "data_parallel_size": rollout_config.data_parallel_size,
+        "expert_parallel_size": rollout_config.expert_parallel_size,
        "max_num_batched_tokens": (
             rollout_config.rollout_vllm_max_num_batched_tokens
         ),

0 commit comments

Comments (0)