diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index af02f824a52a..e9833a8a5020 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -311,7 +311,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Argument | Description | Defaults | Options |
 | --- | --- | --- | --- |
 | `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep`|
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`|
 | `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
 | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
 | `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
diff --git a/python/sglang/srt/batch_overlap/operations_strategy.py b/python/sglang/srt/batch_overlap/operations_strategy.py
index 41f40275eb73..d39ad838577d 100644
--- a/python/sglang/srt/batch_overlap/operations_strategy.py
+++ b/python/sglang/srt/batch_overlap/operations_strategy.py
@@ -7,6 +7,9 @@
 from sglang.srt.batch_overlap.operations import Operation
 from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @dataclass
@@ -91,7 +94,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
 def _compute_moe_deepseek_blog_prefill(layer):
     device_properties = torch.cuda.get_device_properties(device="cuda")
     total_num_sms = device_properties.multi_processor_count
-    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+    deep_gemm_num_sms = None
+    if not _is_hip:
+        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
 
     return OperationsStrategy(
         deep_gemm_num_sms=deep_gemm_num_sms,
@@ -168,7 +173,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
 def _compute_moe_qwen3_prefill(layer):
     device_properties = torch.cuda.get_device_properties(device="cuda")
     total_num_sms = device_properties.multi_processor_count
-    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+    deep_gemm_num_sms = None
+    if not _is_hip:
+        deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
 
     return OperationsStrategy(
         deep_gemm_num_sms=deep_gemm_num_sms,
diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py
index cfd2a54ed132..e2840fee0dde 100644
--- a/python/sglang/srt/batch_overlap/two_batch_overlap.py
+++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py
@@ -30,6 +30,7 @@
 from sglang.srt.layers.moe.token_dispatcher import (
     DeepEPDispatcher,
     MooncakeEPDispatcher,
+    MoriEPDispatcher,
 )
 from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -1027,6 +1028,10 @@ def __init__(self, **kwargs):
             self._inners = [
                 MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
             ]
+        elif get_moe_a2a_backend().is_mori():
+            self._inners = [
+                MoriEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+            ]
 
     def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
         return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index d851040cfc37..0a2e579645b0 100644
---
a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -431,7 +431,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # num_kv_splits_indptr = None if forward_batch.forward_mode.is_decode_or_idle(): - if spec_info is None: + if spec_info is None or forward_batch.forward_mode.is_idle(): kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( @@ -1074,6 +1074,17 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], ): + num_kv_splits = None + # num_kv_splits_indptr = None + + work_metadata = None + work_info_set = None + work_indptr = None + + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + if forward_mode.is_decode_or_idle(): kv_indptr = self.kv_indptr kv_indices = self.cuda_graph_kv_indices @@ -1093,6 +1104,58 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + if self.use_mla: + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + if _use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + # num_kv_splits_indptr=num_kv_splits_indptr, + ) + elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -1117,7 +1180,57 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = self.num_draft_tokens + + # if self.kv_cache_dtype == fp8_dtype: + if _use_mla_ps_kernel: + + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + work_metadata=work_metadata, + 
work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + # num_kv_splits_indptr=num_kv_splits_indptr, + ) + elif forward_mode.is_draft_extend(): + num_tokens_per_bs = self.speculative_num_steps + 1 seq_lens = seq_lens[:bs] accept_lens = spec_info.accept_length[:bs] qo_indptr = self.qo_indptr[: bs + 1] @@ -1135,6 +1248,54 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = num_tokens_per_bs + + if _use_mla_ps_kernel: + + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + # num_kv_splits_indptr=num_kv_splits_indptr, + ) + else: raise ValueError("Invalid forward mode") @@ -1366,23 +1527,6 @@ def forward_extend( num_kv_splits = self.forward_metadata.num_kv_splits - if layer.layer_id == 0 and _use_mla_ps_kernel: - self.make_mla_meta_data( - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_last_page_len, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - self.forward_metadata.max_q_len, - fast_mode=fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=intra_batch_mode, - ) - mla_decode_fwd( q, K_Buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1418,23 +1562,6 @@ def forward_extend( num_kv_splits = self.forward_metadata.num_kv_splits - if layer.layer_id == 0 and _use_mla_ps_kernel: - self.make_mla_meta_data( - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_last_page_len, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - self.forward_metadata.max_q_len, - fast_mode=fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=intra_batch_mode, - ) - if self.forward_metadata.run_graph is not True: bs, q_pad, q_mask = pad_sequence_with_mask( @@ -1577,23 +1704,6 @@ def forward_decode( num_kv_splits = self.forward_metadata.num_kv_splits - if layer.layer_id == 0 and _use_mla_ps_kernel: - self.make_mla_meta_data( - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_last_page_len, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - self.forward_metadata.max_q_len, - fast_mode=fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=intra_batch_mode, - ) - mla_decode_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), 
k_buffer.view(-1, 1, 1, layer.qk_head_dim), diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ebcc696ecf0d..da41c667017e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -24,7 +24,10 @@ DeepEPLLCombineInput, DeepEPNormalCombineInput, ) -from sglang.srt.layers.moe.token_dispatcher.moriep import MoriEPNormalCombineInput +from sglang.srt.layers.moe.token_dispatcher.moriep import ( + MoriEPLLCombineInput, + MoriEPNormalCombineInput, +) from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -130,13 +133,14 @@ def __init__( if ( self.deepep_mode.enable_low_latency() and not _is_npu + and not _is_hip and not ( get_moe_runner_backend().is_flashinfer_cutedsl() and self.quant_config.get_name() == "modelopt_fp4" ) ): - # NPU supports low_latency deepep without deepgemm - # FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm + # AMD HIP, NPU supports low_latency deepep without deepgemm + # NV FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm assert ( deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM ), f"DeepEP {self.deepep_mode} mode requires deep_gemm" @@ -246,6 +250,7 @@ def run_moe_core( if DispatchOutputChecker.format_is_deepep_normal(dispatch_output) else DeepEPLLCombineInput ) + return combine_input_wrapper( hidden_states=output, topk_ids=dispatch_output.topk_ids, @@ -275,8 +280,10 @@ def forward_aiter( dispatch_output.topk_ids, dispatch_output.topk_weights, ) + if hidden_states.shape[0] == 0: return hidden_states + # in original deepep, idx == -1 meaning invalid and will not be processed. 
# aiter does not accept -1, we use a expert mask to make these idx invalid # (idx == num_local_experts) meaning not used in aiter fused_moe @@ -592,28 +599,46 @@ def forward( self, hidden_states: torch.Tensor, topk_output: TopKOutput, - forward_shared_experts=None, - alt_stream=None, - disable_sbo=False, ): num_token = hidden_states.shape[0] - output_dtype = hidden_states.dtype + dispatch_output = self.dispatcher.dispatch( + hidden_states=hidden_states, topk_output=topk_output + ) + combine_input = self.run_moe_core(dispatch_output) + hidden_states = self.dispatcher.combine( + combine_input=combine_input, + ) + + return hidden_states[:num_token] + + def run_moe_core( + self, + dispatch_output: DispatchOutput, + ): + # TODO(billishyahao): check aiter path + # billishyahao: for now, fused_moe only support torch.bfloat16 + output_dtype = torch.bfloat16 scale = None is_fp8_quant = isinstance(self.quant_method, Fp8MoEMethod) is_quark_w4a4 = isinstance(self.scheme, QuarkW4A4MXFp4MoE) - # dispatch - dispatch_output = self.dispatcher.dispatch( - hidden_states, topk_output - ) # , scale=scale) - ( dispatch_a1, dispatch_scale, dispatch_ids, dispatch_weights, dispatch_recv_token_num, - ) = dispatch_output + origin_topk_ids, + origin_topk_weights, + ) = ( + dispatch_output.hidden_states, + dispatch_output.hidden_states_scale, + dispatch_output.topk_ids, + dispatch_output.topk_weights, + dispatch_output.num_recv_tokens_per_expert, + dispatch_output.origin_topk_ids, + dispatch_output.origin_topk_weights, + ) w13_weight = self.w13_weight w2_weight = self.w2_weight @@ -670,17 +695,19 @@ def forward( dtype=output_dtype, ) - combine_input_wrapper = MoriEPNormalCombineInput - combine_input = combine_input_wrapper( - hidden_states=hidden_states, - topk_ids=topk_output.topk_ids, - topk_weights=topk_output.topk_weights, - ) + from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker - # combine - result = self.dispatcher.combine(combine_input) + combine_input_wrapper = ( + MoriEPNormalCombineInput + if DispatchOutputChecker.format_is_deepep_normal(dispatch_output) + else MoriEPLLCombineInput + ) - return result[:num_token] + return combine_input_wrapper( + hidden_states=hidden_states, + topk_ids=dispatch_output.origin_topk_ids, + topk_weights=dispatch_output.origin_topk_weights, + ) def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index de8a07ab3000..85b52d4caf32 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -96,9 +96,13 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: a2a_backend = get_moe_a2a_backend() if a2a_backend.is_none(): return StandardDispatcher(moe_runner_config) - elif a2a_backend.is_deepep() or a2a_backend.is_mooncake(): + elif a2a_backend.is_deepep() or a2a_backend.is_mooncake() or a2a_backend.is_mori(): return MaybeTboDeepEPDispatcher( - group=get_tp_group().device_group, + group=( + get_tp_group().device_group + if not a2a_backend.is_mori() + else get_tp_group() + ), router_topk=moe_runner_config.top_k, permute_fusion=True, num_experts=moe_runner_config.num_experts, @@ -121,19 +125,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: hidden_size=moe_runner_config.hidden_size, params_dtype=moe_runner_config.params_dtype, ) - elif a2a_backend.is_mori(): - from 
sglang.srt.layers.moe.token_dispatcher import MoriEPDispatcher - return MoriEPDispatcher( - group=get_tp_group(), - router_topk=moe_runner_config.top_k, - permute_fusion=True, - num_experts=moe_runner_config.num_experts, - num_local_experts=moe_runner_config.num_local_experts, - hidden_size=moe_runner_config.hidden_size, - params_dtype=moe_runner_config.params_dtype, - deepep_mode=get_deepep_mode(), - ) elif a2a_backend.is_flashinfer(): return FlashinferDispatcher( group=get_tp_group().device_group, diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 209570073b04..dd40a8d98548 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -28,6 +28,8 @@ ) from sglang.srt.layers.moe.token_dispatcher.moriep import ( MoriEPDispatcher, + MoriEPLLCombineInput, + MoriEPLLDispatchOutput, MoriEPNormalCombineInput, MoriEPNormalDispatchOutput, ) @@ -53,6 +55,8 @@ "MooncakeEPDispatcher", "MoriEPNormalDispatchOutput", "MoriEPNormalCombineInput", + "MoriEPLLDispatchOutput", + "MoriEPLLCombineInput", "MoriEPDispatcher", "StandardDispatcher", "StandardDispatchOutput", diff --git a/python/sglang/srt/layers/moe/token_dispatcher/moriep.py b/python/sglang/srt/layers/moe/token_dispatcher/moriep.py index 6ee2443b4058..dfb6e38368a5 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/moriep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/moriep.py @@ -12,6 +12,7 @@ DispatchOutput, DispatchOutputFormat, ) +from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPPDispatchHooks from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip @@ -40,21 +41,47 @@ logger = logging.getLogger(__name__) +class MoriEPPDispatchHooks(DeepEPPDispatchHooks): + + def __call__(self, dispatcher: BaseDispatcher): + for hook_fun in self.hook_dict.values(): + hook_fun(dispatcher) + + class MoriEPNormalDispatchOutput(NamedTuple): - """Mori EP dispatch output.""" + """Mori EP normal dispatch output.""" hidden_states: torch.Tensor hidden_states_scale: Optional[torch.Tensor] topk_ids: torch.Tensor topk_weights: torch.Tensor num_recv_tokens_per_expert: List[int] + origin_topk_ids: torch.Tensor + origin_topk_weights: torch.Tensor @property def format(self) -> DispatchOutputFormat: return DispatchOutputFormat.DEEPEP_NORMAL +class MoriEPLLDispatchOutput(NamedTuple): + """Mori EP low latency dispatch output.""" + + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_ids: torch.Tensor + topk_weights: torch.Tensor + num_recv_tokens_per_expert: List[int] + origin_topk_ids: torch.Tensor + origin_topk_weights: torch.Tensor + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.DEEPEP_LL + + assert isinstance(MoriEPNormalDispatchOutput, DispatchOutput) +assert isinstance(MoriEPLLDispatchOutput, DispatchOutput) class MoriEPNormalCombineInput(NamedTuple): @@ -69,12 +96,26 @@ def format(self) -> CombineInputFormat: return CombineInputFormat.DEEPEP_NORMAL +class MoriEPLLCombineInput(NamedTuple): + """Mori EP combine input.""" + + hidden_states: torch.Tensor + topk_ids: torch.Tensor + topk_weights: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_LL + + assert isinstance(MoriEPNormalCombineInput, CombineInput) +assert isinstance(MoriEPLLCombineInput, 
CombineInput) class EpMode(Enum): INTRA_NODE = "intra_node" INTER_NODE = "inter_node" + LOW_LATENCY = "low_latency" @dataclass(frozen=True) @@ -113,12 +154,19 @@ def get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank: int = 4096): block_num=64, rdma_block_num=32, ), + # TODO(billishyahao): may need to set different configs for intra node async + EpMode.LOW_LATENCY: EpDispatchConfig( + kernel_type=mori.ops.EpDispatchCombineKernelType.AsyncLL, + warp_num_per_block=8, + block_num=64, + rdma_block_num=32, + ), } # init_mori_op only needs do once in model initial stage # use lru_cache to reuse the same mori_op instance to avoid the init overhead for mori -@lru_cache(maxsize=1) +@lru_cache(maxsize=2) def init_mori_op( group, router_topk, @@ -137,11 +185,16 @@ def init_mori_op( cpu_group = group.cpu_group torch._C._distributed_c10d._register_process_group("mori", cpu_group) mori.shmem.shmem_torch_process_group_init("mori") + + mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE + async_mode = get_bool_env_var("SGLANG_MORI_ASYNC_MODE", "false") + if async_mode: + mode = EpMode.LOW_LATENCY + logger.info( - f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=}" + f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=} {mode=}" ) - mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE cfg = get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank)[mode] kernel_type = cfg.kernel_type @@ -174,6 +227,28 @@ def init_mori_op( return mori_op +class CommStreamPool: + _streams = {} # key -> torch.cuda.Stream + + @classmethod + def _make_key(cls, group): + return (torch.cuda.current_device(), id(group)) + + @classmethod + def getStreamFromPool(cls, group) -> torch.cuda.Stream: + key = cls._make_key(group) + stream = cls._streams.get(key) + if stream is None: + stream = torch.cuda.Stream(priority=0) + cls._streams[key] = stream + return stream + + @classmethod + def clear_group(cls, group): + key = (torch.cuda.current_device(), id(group)) + cls._streams.pop(key, None) + + class _MoriEPDispatcherImplBase: def __init__( self, @@ -184,7 +259,6 @@ def __init__( num_local_experts: int, hidden_size: int, params_dtype: torch.dtype, - return_recv_hook: bool, deepep_mode: DeepEPMode, ): try: @@ -198,7 +272,6 @@ def __init__( self.num_local_experts = num_local_experts self.hidden_size = hidden_size self.params_dtype = params_dtype - self.return_recv_hook = return_recv_hook self.deepep_mode = deepep_mode self.num_max_dispatch_tokens_per_rank = get_int_env_var( @@ -215,6 +288,11 @@ def __init__( num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank, ) + self.quant_config: Optional[dict] = None + + self.overlap_args: Optional[CombineOverlapArgs] = None + self.meta_overlap_args: Optional[dict] = None + def dispatch_a( self, hidden_states: torch.Tensor, @@ -230,23 +308,46 @@ def combine_a( hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, - overlap_args: Optional[CombineOverlapArgs] = None, ): raise NotImplementedError def combine_b(self, *args, **kwargs): raise NotImplementedError - def _get_buffer(self): - raise NotImplementedError + def set_quant_config(self, quant_config: dict) -> None: + self.quant_config = quant_config + + def set_overlap_args( + self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict + ) -> None: + self.overlap_args = combine_overlap_args + 
self.meta_overlap_args = meta_overlap_args + + def clear_overlap_args(self) -> None: + self.overlap_args = None + self.meta_overlap_args = None class _MoriEPDispatcherImplNormal(_MoriEPDispatcherImplBase): - def __init__(self, **kwargs): + def __init__(self, async_finish: bool, **kwargs): super().__init__(**kwargs) + + self.async_finish = async_finish self.quant_config = {} # [kk TODO] need to support mxfp4 type self.quant_func = get_hip_quant(QuantType.per_1x128) + self.enable_dual_stream = get_bool_env_var("SGLANG_MORI_DUAL_STREAM", "false") + self._comm_stream = None + if self.enable_dual_stream: + self._comm_stream = CommStreamPool.getStreamFromPool(self.group) + + def _capture_event_if_async(self) -> Optional[torch.cuda.Event]: + assert self.enable_dual_stream, "dual stream must be enabled" + if not self.async_finish: + return None + ev = torch.cuda.Event(blocking=False, interprocess=False) + ev.record(torch.cuda.current_stream()) + return ev def dispatch_a( self, @@ -255,17 +356,16 @@ def dispatch_a( ): topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids - return ( - hidden_states, - topk_weights, - topk_ids, - ) + previous_event = self._capture_event_if_async() if self._comm_stream else None + + return (hidden_states, topk_weights, topk_ids, previous_event) def dispatch_b( self, hidden_states, topk_weights, topk_ids, + previous_event, ): num_token = hidden_states.shape[0] scale = None @@ -295,14 +395,254 @@ def dispatch_b( recv_scales, recv_topk_ids, packed_recv_count, - ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale) + done_event, + ) = self._dispatch_core( + hidden_states, + topk_weights, + topk_ids, + scale=scale, + previous_event=previous_event, + ) + + if self._comm_stream and self.async_finish and done_event is not None: + torch.cuda.current_stream().wait_event(done_event) return MoriEPNormalDispatchOutput( + hidden_states=packed_recv_hidden, + hidden_states_scale=recv_scales, + topk_ids=recv_topk_ids, + topk_weights=recv_topk_weights, + num_recv_tokens_per_expert=packed_recv_count, + origin_topk_ids=topk_ids, + origin_topk_weights=topk_weights, + ) + + def _dispatch_core( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + scale: Optional[torch.Tensor] = None, + previous_event: Optional[torch.cuda.Event] = None, + ): + done_event: Optional[torch.cuda.Event] = None + + if self._comm_stream: + compute_stream = torch.cuda.current_stream() + comm_stream = self._comm_stream # comm stream + + for t in (hidden_states, topk_weights, topk_ids): + t.record_stream(comm_stream) + if scale is not None: + scale.record_stream(comm_stream) + + with torch.cuda.stream(comm_stream): + # if (previous_event) stream_wait(comm_stream, previous_event) + # else stream_wait(comm_stream, compute_stream) + + if previous_event is not None: + comm_stream.wait_event(previous_event) + else: + comm_stream.wait_stream(compute_stream) + + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids) + + if self.async_finish: + done_event = torch.cuda.Event(blocking=False, interprocess=False) + done_event.record(comm_stream) + else: + compute_stream.wait_stream(comm_stream) + + for t in ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + ): + if t is not None: + t.record_stream(comm_stream) + else: + + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + 
packed_recv_count, + ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids) + + # TODO(billishyahao): EPLB + # get_global_expert_distribution_recorder().on_deepep_dispatch_normal( + + return ( packed_recv_hidden, + recv_topk_weights, recv_scales, recv_topk_ids, + packed_recv_count, + done_event, + ) + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + ): + previous_event = self._capture_event_if_async() if self._comm_stream else None + return hidden_states, topk_ids, topk_weights, previous_event + + def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): + + hidden_states, done_event = self._combine_core( + hidden_states, topk_ids, topk_weights, previous_event + ) + + if self._comm_stream and self.async_finish and done_event is not None: + torch.cuda.current_stream().wait_event(done_event) + + return hidden_states + + def _combine_core( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + previous_event: Optional[torch.cuda.Event], + ): + done_event: Optional[torch.cuda.Event] = None + + if self._comm_stream: + compute_stream = torch.cuda.current_stream() + comm_stream = self._comm_stream + + for t in (hidden_states, topk_ids, topk_weights): + t.record_stream(comm_stream) + + with torch.cuda.stream(comm_stream): + if previous_event is not None: + comm_stream.wait_event(previous_event) + else: + comm_stream.wait_stream(compute_stream) + + combined_hidden_states = self.mori_op.combine( + hidden_states, None, topk_ids + )[0] + + if self.async_finish: + done_event = torch.cuda.Event(blocking=False, interprocess=False) + done_event.record(comm_stream) + else: + compute_stream.wait_stream(comm_stream) + + combined_hidden_states.record_stream(comm_stream) + + else: + combined_hidden_states = self.mori_op.combine( + hidden_states, None, topk_ids + )[0] + + return combined_hidden_states, done_event + + def set_quant_config(self, quant_config: dict): + self.quant_config = quant_config + + +class _MoriEPDispatcherImplLowLatency(_MoriEPDispatcherImplBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.quant_config = {} + self.quant_func = get_hip_quant(QuantType.per_1x128) + + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ): + import mori + + assert ( + self.mori_op.config.kernel_type + is mori.ops.EpDispatchCombineKernelType.AsyncLL + ), "mori asyncll mismatch" + + num_tokens = hidden_states.shape[0] + scale = None + + fp8_dispatch = get_bool_env_var("SGLANG_MORI_FP8_DISP", "False") + + if fp8_dispatch: + # FP8 quant + if num_tokens > 0: + # NOTE: aiter is able to handle token=0 case in UT. But for some reason it failed at e2e case. Root cause TBD. 
+ hidden_states, scale = self.quant_func( + hidden_states, quant_dtype=fp8_dtype + ) + else: + hidden_states = torch.empty( + hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device + ) + scale = torch.empty( + (0, self.hidden_size // 128), + dtype=torch.float32, + device=hidden_states.device, + ) + + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + ( + packed_recv_hidden, recv_topk_weights, + recv_scales, + recv_topk_ids, packed_recv_count, + ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale=scale) + + return ( + packed_recv_hidden, + recv_topk_weights, + recv_topk_ids, + recv_scales, + packed_recv_count, + topk_weights, + topk_ids, + ) + + def dispatch_b( + self, + hidden_states, + recv_topk_weights, + recv_topk_ids, + recv_scales, + packed_recv_count, + topk_weights, + topk_ids, + ): + + ##TODO(billishyahao): add assertion here to check async + import mori + + assert ( + self.mori_op.config.kernel_type + is mori.ops.EpDispatchCombineKernelType.AsyncLL + ), "mori asyncll mismatch" + + self.mori_op.dispatch_recv() + + return MoriEPLLDispatchOutput( + hidden_states=hidden_states, + hidden_states_scale=recv_scales, + topk_ids=recv_topk_ids, + topk_weights=recv_topk_weights, + num_recv_tokens_per_expert=packed_recv_count, + origin_topk_ids=topk_ids, + origin_topk_weights=topk_weights, ) def _dispatch_core( @@ -312,16 +652,15 @@ def _dispatch_core( topk_ids: torch.Tensor, scale: Optional[torch.Tensor] = None, ): + ##TODO(billishyahao): add assertion here to check async + ( packed_recv_hidden, recv_topk_weights, recv_scales, recv_topk_ids, packed_recv_count, - ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids) - - # TODO(billishyahao): EPLB - # get_global_expert_distribution_recorder().on_deepep_dispatch_normal( + ) = self.mori_op.dispatch_send(hidden_states, topk_weights, scale, topk_ids) return ( packed_recv_hidden, @@ -338,21 +677,32 @@ def combine_a( topk_weights: torch.Tensor, overlap_args: Optional[CombineOverlapArgs] = None, ): - previous_event = None - return hidden_states, topk_ids, topk_weights, previous_event + hidden_states = self._combine_core( + hidden_states, + topk_ids, + topk_weights, + overlap_args=overlap_args, + ) + return hidden_states, topk_ids, topk_weights, overlap_args def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): - hidden_states = self._combine_core(hidden_states, topk_ids, topk_weights) - return hidden_states + + self.mori_op.combine_recv() + + return hidden_states[0] def _combine_core( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, + overlap_args: Optional[CombineOverlapArgs] = None, ): - combined_hidden_states = self.mori_op.combine(hidden_states, None, topk_ids) - return combined_hidden_states[0] + combined_hidden_states = self.mori_op.combine_send( + hidden_states, None, topk_ids + ) + + return combined_hidden_states def set_quant_config(self, quant_config: dict): self.quant_config = quant_config @@ -380,27 +730,43 @@ def __init__( async_finish: bool = False, return_recv_hook: bool = False, ): + super().__init__() + self.deepep_mode = deepep_mode + common_kwargs = dict( + group=group, + router_topk=router_topk, + permute_fusion=permute_fusion, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=params_dtype, + deepep_mode=deepep_mode, + ) + + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher = _MoriEPDispatcherImplLowLatency( + 
**common_kwargs, + ) + if self.deepep_mode.enable_normal(): self._normal_dispatcher = _MoriEPDispatcherImplNormal( - group=group, - router_topk=router_topk, - permute_fusion=permute_fusion, - num_experts=num_experts, - num_local_experts=num_local_experts, - hidden_size=hidden_size, - params_dtype=params_dtype, - return_recv_hook=return_recv_hook, - deepep_mode=deepep_mode, + async_finish=async_finish, + **common_kwargs, ) - if self.deepep_mode.enable_low_latency(): - raise NotImplementedError self._stage = _Stage.INITIAL + self._deepep_dispatch_hooks = MoriEPPDispatchHooks() - def dispatch(self, *args, **kwargs) -> DispatchOutput: - self.dispatch_a(*args, **kwargs) + def dispatch( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ) -> DispatchOutput: + self.dispatch_a(hidden_states, topk_output) + if self._deepep_dispatch_hooks is not None: + self._deepep_dispatch_hooks(self) ret = self.dispatch_b() return ret @@ -425,16 +791,14 @@ def dispatch_b(self): def combine( self, combine_input: CombineInput, - overlap_args: Optional[CombineOverlapArgs] = None, ) -> Tuple: - self.combine_a(combine_input, overlap_args) + self.combine_a(combine_input) ret = self.combine_b() return ret def combine_a( self, combine_input: CombineInput, - overlap_args: Optional[CombineOverlapArgs] = None, ): hidden_states, topk_ids, topk_weights = combine_input self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) @@ -442,7 +806,6 @@ def combine_a( hidden_states=hidden_states, topk_ids=topk_ids, topk_weights=topk_weights, - overlap_args=overlap_args, ) self._combine_intermediate_state = inner_state @@ -458,7 +821,7 @@ def _get_impl(self) -> _MoriEPDispatcherImplBase: if resolved_deepep_mode == DeepEPMode.NORMAL: return self._normal_dispatcher elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: - raise NotImplementedError + return self._low_latency_dispatcher else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") @@ -467,7 +830,31 @@ def _update_stage(self, old_stage, new_stage): self._stage = new_stage def set_quant_config(self, quant_config: dict): + super().set_quant_config(quant_config) if self.deepep_mode.enable_low_latency(): - raise NotImplementedError + self._low_latency_dispatcher.set_quant_config(quant_config) if self.deepep_mode.enable_normal(): self._normal_dispatcher.set_quant_config(quant_config) + + def set_overlap_args( + self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict + ): + super().set_overlap_args(combine_overlap_args, meta_overlap_args) + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.set_overlap_args( + combine_overlap_args, meta_overlap_args + ) + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.set_overlap_args( + combine_overlap_args, meta_overlap_args + ) + + def clear_overlap_args(self): + super().clear_overlap_args() + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.clear_overlap_args() + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.clear_overlap_args() + + def register_deepep_dispatch_hook(self, hook): + return self._deepep_dispatch_hooks.register_hook(hook) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 1583dd78804f..d88c6ecb5207 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -951,6 +951,7 @@ def _post_combine_hook( and self.alt_stream is not None ): torch.cuda.current_stream().wait_event(shared_event) + if shared_output is not None: x 
= shared_output # aiter moe call will handle routed_scaling_factor in the function @@ -1054,11 +1055,18 @@ def op_combine_b(self, state): def op_output(self, state): final_hidden_states = state.pop("hidden_states_after_combine") + if get_moe_a2a_backend().is_mori(): + num_tokens = state.pop("num_tokens") + final_hidden_states = final_hidden_states[:num_tokens] + if (shared_output := state.pop("shared_output")) is not None: x = shared_output - x.add_(final_hidden_states, alpha=self.routed_scaling_factor) + if _use_aiter: + x.add_(final_hidden_states) + else: + x.add_(final_hidden_states, alpha=self.routed_scaling_factor) final_hidden_states = x - else: + elif not _use_aiter: final_hidden_states *= self.routed_scaling_factor state.hidden_states_mlp_output = final_hidden_states @@ -2449,6 +2457,7 @@ def op_comm_prepare_attn( state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = ( self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch) ) + state.num_tokens = hidden_states.shape[0] state.update( dict( forward_batch=forward_batch, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 88cbcf1d5ab0..be89a0fe70de 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2216,15 +2216,17 @@ def _handle_a2a_moe(self): if self.moe_a2a_backend == "mori": self.ep_size = self.tp_size - self.deepep_mode = "normal" - logger.warning("auto set deepep_mode=`normal` for MORI EP") logger.warning( f"MoRI MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - assert (self.chunked_prefill_size) <= get_int_env_var( - "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096 - ), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size" + # Check chunked prefill for mori + # Skip validation if chunked prefill is disabled (i.e., size <= 0). + # Skip validation if disaggregation mode is decode. 
+ if self.chunked_prefill_size > 0 and self.disaggregation_mode != "decode": + assert (self.chunked_prefill_size) <= get_int_env_var( + "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096 + ), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size" def _handle_eplb_and_dispatch(self): if self.enable_eplb and (self.expert_distribution_recorder_mode is None): diff --git a/python/sglang/test/bench_one_batch_server_internal.py b/python/sglang/test/bench_one_batch_server_internal.py index 7f2f62065532..e2f3ed09b48a 100644 --- a/python/sglang/test/bench_one_batch_server_internal.py +++ b/python/sglang/test/bench_one_batch_server_internal.py @@ -749,22 +749,34 @@ def run_benchmark_internal( else: tokenizer = get_tokenizer(tokenizer_path) - # Get token capacity internal_state = server_info.get("internal_states", [{}]) - skip_token_capacity_threshold = ( - internal_state[0].get("memory_usage", {}).get("token_capacity", 1000000000) - ) + dp_size = internal_state[0].get("dp_size", None) or 1 # Get effective max running requests max_running_requests_per_dp = internal_state[0].get( "effective_max_running_requests_per_dp", -1 ) - dp_size = server_info.get("dp_size", None) or 1 + + # Get token capacity + skip_token_capacity_threshold = 0 + + for i in range(dp_size): + skip_token_capacity_threshold += ( + internal_state[i] + .get("memory_usage", {}) + .get("token_capacity", 1000000000) + ) + assert ( max_running_requests_per_dp > 0 ), f"effective_max_running_requests_per_dp is not set, {max_running_requests_per_dp=}" skip_max_running_requests_threshold = max_running_requests_per_dp * dp_size + print(f"{max_running_requests_per_dp=}") + print(f"{dp_size=}") + print(f"{skip_max_running_requests_threshold=}") + print(f"{skip_token_capacity_threshold=}") + # Warmup if not bench_args.skip_warmup: print("=" * 8 + " Warmup Begin " + "=" * 8)
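
Note on the dual-stream dispatch introduced in moriep.py: the snippet below is a minimal, standalone sketch of the event/stream handshake that `_MoriEPDispatcherImplNormal._dispatch_core` uses when `SGLANG_MORI_DUAL_STREAM` is enabled, reduced to a toy tensor so the synchronization order is easy to follow. `dual_stream_dispatch` and `fake_all_to_all` are hypothetical names used only for illustration; the real code launches `mori_op.dispatch(...)` on the comm stream and carries quantization scales and top-k metadata alongside the hidden states.

import torch


def fake_all_to_all(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the communication kernel that the real dispatcher
    # launches on the comm stream (mori_op.dispatch in the patch above).
    return x * 2


def dual_stream_dispatch(
    x: torch.Tensor, comm_stream: torch.cuda.Stream, async_finish: bool
) -> torch.Tensor:
    compute_stream = torch.cuda.current_stream()

    # Keep the input alive while the comm stream may still read it.
    x.record_stream(comm_stream)

    # "dispatch_a": mark the point where the compute stream finished producing x.
    previous_event = torch.cuda.Event(blocking=False, interprocess=False)
    previous_event.record(compute_stream)

    with torch.cuda.stream(comm_stream):
        # Communication must not start before x is ready.
        comm_stream.wait_event(previous_event)
        out = fake_all_to_all(x)
        if async_finish:
            # Record completion so the caller can overlap other compute work
            # and join later ("dispatch_b" in the patch).
            done_event = torch.cuda.Event(blocking=False, interprocess=False)
            done_event.record(comm_stream)
        else:
            done_event = None
            compute_stream.wait_stream(comm_stream)
        out.record_stream(comm_stream)

    # "dispatch_b": re-join the compute stream only when the result is consumed.
    if done_event is not None:
        compute_stream.wait_event(done_event)
    return out


if __name__ == "__main__":
    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream(priority=0)
        t = torch.ones(8, device="cuda")
        print(dual_stream_dispatch(t, comm_stream, async_finish=True))

The combine path in the patch (`_combine_core`) reuses the same handshake in the opposite direction: record an event on the producing stream, have the comm stream wait on it, run the combine, then either record a done event (async) or block the compute stream (sync). When `SGLANG_MORI_DUAL_STREAM` is unset, `_comm_stream` stays `None` and the dispatcher simply calls `mori_op.dispatch`/`combine` on the current stream.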