2121 init_distributed_environment ,
2222 set_custom_all_reduce ,
2323 set_mscclpp_all_reduce ,
24+ set_torch_symm_mem_all_reduce ,
2425)
2526from sglang .srt .layers .dp_attention import (
2627 get_attention_tp_group ,
@@ -118,6 +119,8 @@ def init_torch_distributed(self):
118119 backend = "gloo"
119120 elif self .device == "npu" :
120121 backend = "hccl"
122+ else :
123+ backend = "gloo"
121124
122125 before_avail_memory = get_available_gpu_memory (self .device , self .gpu_id )
123126 if not self .server_args .enable_p2p_check :
@@ -129,6 +132,7 @@ def init_torch_distributed(self):
129132 dist_init_method = f"tcp://127.0.0.1:{ self .dist_port } "
130133 set_custom_all_reduce (not self .server_args .disable_custom_all_reduce )
131134 set_mscclpp_all_reduce (self .server_args .enable_mscclpp )
135+ set_torch_symm_mem_all_reduce (self .server_args .enable_torch_symm_mem )
132136
133137 if not self .is_draft_worker :
134138 if self .device == "cpu" :
@@ -153,14 +157,21 @@ def init_torch_distributed(self):
153157 local_rank = self .gpu_id ,
154158 distributed_init_method = dist_init_method ,
155159 timeout = self .server_args .dist_timeout ,
160+ moe_a2a_backend = self .server_args .moe_a2a_backend ,
161+ recovered_rank = self .server_args .elastic_ep_rejoin ,
156162 )
157163
158164 # Use monkey patch modified function
159165 sglang .srt .distributed .parallel_state .initialize_model_parallel (
160166 tensor_model_parallel_size = self .tp_size ,
161167 pipeline_model_parallel_size = self .pp_size ,
162168 expert_model_parallel_size = self .moe_ep_size ,
169+ attention_data_parallel_size = self .dp_size ,
170+ attention_context_model_parallel_size = self .attn_cp_size ,
171+ moe_data_model_parallel_size = self .moe_dp_size ,
163172 duplicate_tp_group = self .server_args .enable_pdmux ,
173+ enable_symm_mem = self .server_args .enable_symm_mem ,
174+ recovered_rank = self .server_args .elastic_ep_rejoin ,
164175 pp_start_layer = self .pp_start_layer ,
165176 pp_end_layer = self .pp_end_layer ,
166177 hidden_layers = self .model_config .num_hidden_layers ,
@@ -225,6 +236,7 @@ def form_sgl_server_args(
225236 lora_eviction_policy : Optional [str ] = "lru" ,
226237 lora_backend : Optional [str ] = "triton" ,
227238 max_lora_chunk_size : Optional [int ] = 128 ,
239+ max_num_tokens_per_batch : int = 16384 ,
228240):
229241 """Creates a SGL ServerArgs object"""
230242 sgl_server_args = ServerArgs (
@@ -247,6 +259,7 @@ def form_sgl_server_args(
247259 lora_backend = lora_backend ,
248260 max_lora_chunk_size = max_lora_chunk_size ,
249261 dp_size = dp_size ,
262+ max_total_tokens = max_num_tokens_per_batch ,
250263 )
251264 return sgl_server_args
252265
@@ -338,6 +351,7 @@ def initialize_sgl_model_runner(
338351 lora_eviction_policy ,
339352 lora_backend ,
340353 max_lora_chunk_size ,
354+ max_num_tokens_per_batch = max_num_tokens_per_batch ,
341355 )
342356 initialize_moe_config (server_args )
343357 quant_method = None
0 commit comments