@@ -64,6 +64,7 @@ class VllmConfig:
   mesh: jax.sharding.Mesh = None
   data_parallel_size: int = -1
   tensor_parallel_size: int = -1
+  expert_parallel_size: int = 1
 
   # vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc.
   engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
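The first hunk adds an `expert_parallel_size` field to `VllmConfig`, defaulting to 1 (no expert parallelism). A minimal usage sketch, assuming an 8-device host, a 2x4 mesh, and that the remaining fields keep their defaults; the mesh construction is illustrative, not part of this change:

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical setup: a 2x4 mesh over 8 local devices.
devices = np.asarray(jax.devices()).reshape(2, 4)
config = VllmConfig(
    mesh=Mesh(devices, axis_names=("data", "model")),
    expert_parallel_size=2,  # new field; tp/dp stay at -1 and are resolved later
)
```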
@@ -204,25 +205,49 @@ def _find_total_size(self, mesh: jax.sharding.Mesh) -> int:
     # since vllm doesn't support DP yet, simply return the total rank size.
     return math.prod(mesh.shape.values())
 
-  def _vllm_config(self, config: VllmConfig):
-    """Setup vllm config from Tunix Vllm config."""
-    args = config._processed_engine_kwargs.copy()
-
+  def _configure_sharding(
+      self, config: VllmConfig, args: Dict[str, Any]
+  ) -> None:
+    """Resolves parallelism sizes and sets the sharding config in args."""
     tensor_parallel_size = config.tensor_parallel_size
     data_parallel_size = config.data_parallel_size
+    expert_parallel_size = config.expert_parallel_size
     total_mesh_devices = self._find_total_size(config.mesh)
 
+    if total_mesh_devices % expert_parallel_size != 0:
+      raise ValueError(
+          f"Total mesh devices ({total_mesh_devices}) must be divisible by"
+          f" expert_parallel_size ({expert_parallel_size})."
+      )
+
     if config.tensor_parallel_size == -1 and config.data_parallel_size == -1:
-      tensor_parallel_size = total_mesh_devices
+      tensor_parallel_size = total_mesh_devices // expert_parallel_size
       data_parallel_size = 1
     elif config.tensor_parallel_size == -1:
-      tensor_parallel_size = total_mesh_devices // data_parallel_size
+      tensor_parallel_size = (
+          total_mesh_devices // (data_parallel_size * expert_parallel_size)
+      )
     elif config.data_parallel_size == -1:
-      data_parallel_size = total_mesh_devices // tensor_parallel_size
+      data_parallel_size = (
+          total_mesh_devices // (tensor_parallel_size * expert_parallel_size)
+      )
 
     args["data_parallel_size"] = data_parallel_size
     args["tensor_parallel_size"] = tensor_parallel_size
 
+    device_indexes = config.mesh.device_ids.flatten().tolist()
+    args["additional_config"]["sharding"] = {
+        "sharding_strategy": {
+            "expert_parallelism": expert_parallel_size,
+            "device_indexes": device_indexes,
+            "enable_dp_attention": config.enable_dp_attention,
+        }
+    }
+
+  def _vllm_config(self, config: VllmConfig):
+    """Setup vllm config from Tunix Vllm config."""
+    args = config._processed_engine_kwargs.copy()
+
     # Init vLLM model with random weights to speed up bootstrap time, because
     # model weights are synced from trainer later on
     if config.init_with_random_weights:
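To make the size-resolution arithmetic in `_configure_sharding` concrete, here is a minimal standalone sketch of the same logic; the function name and the asserts below are illustrative, not part of the change:

```python
def resolve_parallelism(total_devices: int, tp: int = -1, dp: int = -1, ep: int = 1):
  """Mirrors the parallelism resolution in _configure_sharding above."""
  if total_devices % ep != 0:
    raise ValueError(
        f"Total mesh devices ({total_devices}) must be divisible by"
        f" expert_parallel_size ({ep})."
    )
  if tp == -1 and dp == -1:
    tp, dp = total_devices // ep, 1  # give all remaining devices to TP
  elif tp == -1:
    tp = total_devices // (dp * ep)
  elif dp == -1:
    dp = total_devices // (tp * ep)
  return tp, dp

# 8 devices, ep=2, both sizes unset: the 4 remaining devices go to TP.
assert resolve_parallelism(8, ep=2) == (4, 1)
# 8 devices, ep=2, dp pinned to 2: tp = 8 // (2 * 2) = 2.
assert resolve_parallelism(8, dp=2, ep=2) == (2, 2)
```

Note that sizes are only adjusted when left at -1; explicitly set values pass through, so a mismatch between `tp * dp * ep` and the mesh size is not validated in this helper.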
@@ -235,14 +260,7 @@ def _vllm_config(self, config: VllmConfig):
     if config.lora_config is not None:
       args["additional_config"]["lora_config"] = config.lora_config
 
-    device_indexes = config.mesh.device_ids.flatten().tolist()
-
-    args["additional_config"]["sharding"] = {
-        "sharding_strategy": {
-            "device_indexes": device_indexes,
-            "enable_dp_attention": config.enable_dp_attention,
-        }
-    }
+    self._configure_sharding(config, args)
 
     return args
 
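After `_configure_sharding` runs, `args` carries the resolved sizes plus the sharding payload that was previously assembled inline in `_vllm_config`. For the 8-device, `expert_parallel_size=2` sketch above (tp and dp left at -1), the relevant entries would look roughly as follows; the concrete `device_indexes` follow `mesh.device_ids`, and the `enable_dp_attention` value shown is an assumption:

```python
# Assumed shape of the resolved engine args for the 8-device, ep=2 example.
expected_args_subset = {
    "data_parallel_size": 1,
    "tensor_parallel_size": 4,  # 8 devices // ep=2
    "additional_config": {
        "sharding": {
            "sharding_strategy": {
                "expert_parallelism": 2,
                "device_indexes": [0, 1, 2, 3, 4, 5, 6, 7],  # mesh.device_ids, flattened
                "enable_dp_attention": False,  # assumed default
            }
        }
    },
}
```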