@@ -123,13 +123,15 @@ def gen_config_file(config_path: str,
123123 model_path : str ,
124124 num_ctx_servers : int ,
125125 ctx_tp_size : int ,
126+ ctx_pp_size : int ,
126127 ctx_batch_size : int ,
127128 ctx_max_num_tokens : int ,
128129 ctx_max_seq_len : int ,
129130 ctx_free_gpu_memory_fraction : float ,
130131 ctx_enable_attention_dp : bool ,
131132 num_gen_servers : int ,
132133 gen_tp_size : int ,
134+ gen_pp_size : int ,
133135 gen_batch_size : int ,
134136 gen_max_num_tokens : int ,
135137 gen_max_seq_len : int ,
@@ -148,13 +150,15 @@ def gen_config_file(config_path: str,
148150 model_path: Path to the model
149151 num_ctx_servers: Number of context servers
150152 ctx_tp_size: Tensor parallel size for context servers
153+ ctx_pp_size: Pipeline parallel size for context servers
151154 ctx_batch_size: Batch size for context servers
152155 ctx_max_num_tokens: Max number of tokens for context servers
153156 ctx_max_seq_len: Max sequence length for context servers
154157 ctx_free_gpu_memory_fraction: Free GPU memory fraction for context servers
155158 ctx_enable_attention_dp: Enable attention DP for context servers
156159 num_gen_servers: Number of generation servers
157160 gen_tp_size: Tensor parallel size for generation servers
161+ gen_pp_size: Pipeline parallel size for generation servers
158162 gen_batch_size: Batch size for generation servers
159163 gen_max_num_tokens: Max number of tokens for generation servers
160164 gen_enable_attention_dp: Enable attention DP for generation servers
@@ -187,7 +191,7 @@ def gen_config_file(config_path: str,
187191 'tensor_parallel_size' : ctx_tp_size ,
188192 'moe_expert_parallel_size' : ctx_tp_size ,
189193 'enable_attention_dp' : ctx_enable_attention_dp ,
190- 'pipeline_parallel_size' : 1 ,
194+ 'pipeline_parallel_size' : ctx_pp_size ,
191195 'print_iter_log' : True ,
192196 'disable_overlap_scheduler' : True ,
193197 'kv_cache_config' : {
@@ -205,7 +209,7 @@ def gen_config_file(config_path: str,
205209 'tensor_parallel_size' : gen_tp_size ,
206210 'moe_expert_parallel_size' : gen_tp_size ,
207211 'enable_attention_dp' : gen_enable_attention_dp ,
208- 'pipeline_parallel_size' : 1 ,
212+ 'pipeline_parallel_size' : gen_pp_size ,
209213 'max_batch_size' : gen_batch_size ,
210214 'max_num_tokens' : gen_max_num_tokens ,
211215 'max_seq_len' : gen_max_seq_len ,
@@ -237,15 +241,15 @@ def gen_config_file(config_path: str,
237241
238242 # Generate URLs for context and generation servers
239243 ctx_urls , task_nodes_offset = generate_urls ("ctx" , num_ctx_servers ,
240- ctx_tp_size , 1 ,
244+ ctx_tp_size , ctx_pp_size ,
241245 max_tasks_per_node , nodes ,
242246 task_nodes , node_ports )
243247 if num_ctx_servers > 0 :
244248 config ['context_servers' ]['urls' ] = ctx_urls
245249
246- gen_urls , _ = generate_urls ("gen" , num_gen_servers , gen_tp_size , 1 ,
247- max_tasks_per_node , nodes , task_nodes ,
248- node_ports , task_nodes_offset )
250+ gen_urls , _ = generate_urls ("gen" , num_gen_servers , gen_tp_size ,
251+ gen_pp_size , max_tasks_per_node , nodes ,
252+ task_nodes , node_ports , task_nodes_offset )
249253 config ['generation_servers' ]['urls' ] = gen_urls
250254
251255 # set the hostname to the first node
@@ -300,6 +304,10 @@ def gen_config_file(config_path: str,
300304 type = int ,
301305 required = True ,
302306 help = "Tensor parallel size for context servers" )
307+ parser .add_argument ("--ctx_pp_size" ,
308+ type = int ,
309+ default = 1 ,
310+ help = "Pipeline parallel size for context servers" )
303311 parser .add_argument ("--ctx_batch_size" ,
304312 type = int ,
305313 required = True ,
@@ -328,6 +336,10 @@ def gen_config_file(config_path: str,
328336 type = int ,
329337 required = True ,
330338 help = "Tensor parallel size for generation servers" )
339+ parser .add_argument ("--gen_pp_size" ,
340+ type = int ,
341+ default = 1 ,
342+ help = "Pipeline parallel size for generation servers" )
331343 parser .add_argument ("--gen_batch_size" ,
332344 type = int ,
333345 required = True ,
@@ -372,11 +384,11 @@ def gen_config_file(config_path: str,
372384 args = parser .parse_args ()
373385
374386 gen_config_file (args .config , args .model , args .num_ctx_servers ,
375- args .ctx_tp_size , args .ctx_batch_size ,
387+ args .ctx_tp_size , args .ctx_pp_size , args .ctx_batch_size ,
376388 args .ctx_max_num_tokens , args .ctx_max_seq_len ,
377389 args .ctx_free_gpu_memory_fraction ,
378390 args .ctx_enable_attention_dp , args .num_gen_servers ,
379- args .gen_tp_size , args .gen_batch_size ,
391+ args .gen_tp_size , args .gen_pp_size , args .gen_batch_size ,
380392 args .gen_max_num_tokens , args .gen_max_seq_len ,
381393 args .gen_enable_attention_dp , args .gen_gpu_memory_fraction ,
382394 args .eplb_num_slots , args .mtp_size , args .worker_start_port ,
0 commit comments