diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py index a9213d83bfee..592bdc5836a6 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py @@ -18,6 +18,10 @@ StandardCombineInput, StandardDispatchOutput, ) + from sglang.srt.layers.moe.token_dispatcher.flashinfer import ( + FlashinferCombineInput, + FlashinferDispatchOutput, + ) logger = logging.getLogger(__name__) @@ -249,6 +253,12 @@ def ensure_cutedsl_wrapper(layer: torch.nn.Module) -> None: getattr(server_args, "cuda_graph_max_bs", None) or 512, getattr(server_args, "chunked_prefill_size", None) or 8192, ) + # In DP attention + EP mode, MoE receives tokens from all DP ranks: + # - Standard path (allgather): dp_size * local_tokens + # - A2A path (dispatch): ep_size * runtime_max_per_rank + dp_size = getattr(server_args, "dp_size", 1) or 1 + if dp_size > 1: + max_num_tokens *= dp_size top_k = layer.top_k if layer.top_k is not None else layer.moe_runner_config.top_k # inference_mode(False) ensures the wrapper's pre-allocated CUDA-graph # buffers are normal tensors. This call typically happens inside @@ -351,3 +361,74 @@ def fused_experts_none_to_flashinfer_cutedsl_fp4( ) return StandardCombineInput(hidden_states=output) + + +@register_fused_func("flashinfer", "flashinfer_cutedsl") +def fused_experts_flashinfer_to_flashinfer_cutedsl_fp4( + dispatch_output: FlashinferDispatchOutput, + quant_info: CuteDslFp4MoeQuantInfo, + runner_config: MoeRunnerConfig, +) -> FlashinferCombineInput: + """CuteDSL fused func for flashinfer alltoall dispatcher. + + Two cases depending on whether the dispatcher did FP4 quantization: + - bf16 input (SGLANG_MOE_NVFP4_DISPATCH=0): quantize with cutedsl's scale + - FP4 input (SGLANG_MOE_NVFP4_DISPATCH=1): pass through (same fp4_quantize params) + """ + from flashinfer import fp4_quantize + + from sglang.srt.layers.moe.token_dispatcher.flashinfer import ( + FlashinferCombineInput, + ) + from sglang.srt.layers.moe.topk import TopKOutputChecker + + assert runner_config.activation == "silu", "Only silu is supported for CuteDSL MoE." + + hidden_states = dispatch_output.hidden_states + x_sf = dispatch_output.hidden_states_scale + topk_output = dispatch_output.topk_output + assert TopKOutputChecker.format_is_standard(topk_output) + + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + if topk_ids.dtype != torch.int32: + topk_ids = topk_ids.to(torch.int32) + + if hidden_states.dtype == torch.bfloat16 or hidden_states.dtype == torch.float16: + # Dispatcher sent bf16 (NVFP4 dispatch disabled) — quantize ourselves + x_fp4, x_sf = fp4_quantize( + hidden_states, + quant_info.input_scale, + sf_vec_size=_FP4_SF_VEC_SIZE, + is_sf_swizzled_layout=False, + ) + else: + # Dispatcher already quantized to FP4 — use as-is + # NOTE: dispatcher uses global_scale and interleaves x_sf; + # cutedsl uses input_scale and non-interleaved x_sf. + # These may differ. For now pass through; if kernel crashes, + # need to de-interleave x_sf or disable NVFP4 dispatch. + x_fp4 = hidden_states + + output = quant_info.wrapper.run( + x=x_fp4, + x_sf=x_sf, + token_selected_experts=topk_ids, + token_final_scales=topk_weights, + w1_weight=quant_info.w13_weight, + w1_weight_sf=quant_info.w13_weight_sf, + w1_alpha=quant_info.w1_alpha, + fc2_input_scale=quant_info.fc2_input_scale, + w2_weight=quant_info.w2_weight, + w2_weight_sf=quant_info.w2_weight_sf, + w2_alpha=quant_info.w2_alpha, + ) + + # Note: output contains routed expert results; shared_expert is handled separately + + # Write into pre-allocated workspace buffer if available + if dispatch_output.moe_output is not None: + dispatch_output.moe_output.copy_(output) + output = dispatch_output.moe_output + + return FlashinferCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py index 7b5080bb860f..5d8c7798742a 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py @@ -100,7 +100,7 @@ def __init__( # TODO: Can this be a server arg and shared with deepep/mooncakeep? self.max_num_tokens = ( - get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 1024) + get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 8192) * self.ep_size ) @@ -178,15 +178,11 @@ def dispatch( topk_ids = topk_output.topk_ids topk_weights = topk_output.topk_weights - # Handle case where there are no tokens on this DP worker - # moe_a2a.dispatch requires at least one token - self.has_dummy_token = False - if x.shape[0] == 0: - logger.warning("No tokens on this DP worker, using dummy token") - self.has_dummy_token = True - x = self.dummy_x - topk_ids = self.dummy_topk_ids - topk_weights = self.dummy_topk_weights + # Track if this DP worker has no tokens (idle rank). + # Unlike the old dummy-token approach, we pass 0-size tensors directly + # to the alltoall kernel, which handles local_num_tokens=0 natively + # (same as TRT-LLM). The kernel keeps 1 thread alive for sync. + self.has_dummy_token = x.shape[0] == 0 global_scale = self.quant_config.get("input_global_scale", None) if global_scale is not None: @@ -216,9 +212,10 @@ def dispatch( else x.shape[0] ) recv_tensors = self.moe_a2a.dispatch( - self.dummy_topk_ids_current_rank if self.has_dummy_token else topk_ids, + topk_ids, payloads, self.runtime_max_tokens_per_rank, + invalid_token_expert_id=self.num_experts, expert_id_payload_index=expert_id_payload_index, ) if x_sf is not None: @@ -257,10 +254,6 @@ def combine(self, combine_input: FlashinferCombineInput) -> torch.Tensor: payload_in_workspace=self.payload_in_workspace, ) - # Remove dummy token if it was added in dispatch - if self.has_dummy_token: - hidden_states = hidden_states[1:, :] - del self.runtime_max_tokens_per_rank del self.has_dummy_token return hidden_states diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 8f3475a24323..d1add59e9c7b 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -208,7 +208,8 @@ def __init__( prefix=add_prefix("shared_expert", prefix), **( dict(tp_rank=0, tp_size=1) - if get_moe_a2a_backend().is_deepep() + if (get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_flashinfer()) else {} ), ) @@ -329,9 +330,15 @@ def forward( if get_moe_a2a_backend().is_deepep(): return self._forward_deepep(hidden_states, forward_batch) - if ( + if hidden_states.shape[0] == 0: + # M=0 guard for idle DP ranks: skip shared_experts and gate + # (which crash on empty tensors in FP4 GEMM), but still call + # self.experts() to participate in alltoall collective. + shared_output = None + topk_output = self.topk.empty_topk_output(hidden_states.device) + final_hidden_states = self.experts(hidden_states, topk_output) + elif ( self.alt_stream is not None - and hidden_states.shape[0] > 0 and get_is_capture_mode() ): final_hidden_states, shared_output = self.forward_normal_dual_stream( @@ -342,19 +349,18 @@ def forward( final_hidden_states = self._forward_router_experts(hidden_states) if shared_output is not None: - # In-place add is required to keep final_hidden_states in the - # symmetric memory pool (when --enable-symm-mem is used). - # An out-of-place add would allocate a new tensor outside symm - # memory, breaking subsequent symmetric collective operations. final_hidden_states += shared_output if ( self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter and not should_use_flashinfer_cutlass_moe_fp4_allgather() + and not get_moe_a2a_backend().is_flashinfer() ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + # Debug removed - was causing issues during CUDA graph capture + return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b6c5676cb20a..a3aaf0468282 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2744,9 +2744,10 @@ def _handle_moe_kernel_config(self): assert self.moe_a2a_backend in [ "none", "deepep", + "flashinfer", ], ( - f"flashinfer_cutedsl supports moe_a2a_backend='none' (standard path) " - f"or 'deepep' (DeepEP low-latency path), got '{self.moe_a2a_backend}'." + f"flashinfer_cutedsl supports moe_a2a_backend='none', 'deepep', or 'flashinfer', " + f"got '{self.moe_a2a_backend}'." ) self.disable_shared_experts_fusion = True logger.warning( @@ -2865,8 +2866,9 @@ def _handle_a2a_moe(self): "SGLANG_MOE_NVFP4_DISPATCH is set to True for Flashinfer MoE A2A" ) assert self.moe_runner_backend in [ - "flashinfer_cutlass" - ], "Flashinfer MoE A2A is only supported with flashinfer_cutlass moe runner backend" + "flashinfer_cutlass", + "flashinfer_cutedsl", + ], "Flashinfer MoE A2A is only supported with flashinfer_cutlass or flashinfer_cutedsl moe runner backend" if self.moe_a2a_backend == "mori": self.ep_size = self.tp_size diff --git a/test/registered/moe/test_cutedsl_a2a.py b/test/registered/moe/test_cutedsl_a2a.py new file mode 100644 index 000000000000..e0919539190e --- /dev/null +++ b/test/registered/moe/test_cutedsl_a2a.py @@ -0,0 +1,91 @@ +"""Test CuteDSL FP4 MoE + FlashInfer alltoall on B200 with DP attention. + +Config: Qwen3.5-397B-A17B-NVFP4, B200x4, EP=4 DP=4, cutedsl + flashinfer a2a. +""" + +import unittest +from types import SimpleNamespace + +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=600, suite="stage-c-test-4-gpu-b200") + +MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4" + +SKIP_TEST = torch.cuda.get_device_capability() < (10, 0) +SKIP_REASON = "Requires Blackwell (B200, sm_100a) or above." + + +@unittest.skipIf(SKIP_TEST, SKIP_REASON) +class TestCuteDslFlashinferA2A(CustomTestCase): + """CuteDSL FP4 MoE + FlashInfer one-sided alltoall + DP4 EP4 on B200.""" + + @classmethod + def setUpClass(cls): + cls.model = MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 3, + other_args=[ + "--trust-remote-code", + "--quantization", + "modelopt_fp4", + "--tp", + "4", + "--ep-size", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--enable-dp-lm-head", + "--moe-runner-backend", + "flashinfer_cutedsl", + "--moe-a2a-backend", + "flashinfer", + "--disable-radix-cache", + "--disable-flashinfer-autotune", + "--watchdog-timeout", + "900", + ], + env={ + "FLASHINFER_DISABLE_VERSION_CHECK": "1", + "SGLANG_MOE_NVFP4_DISPATCH": "0", + }, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + eval_name="gsm8k", + num_examples=1319, + max_tokens=10240, + repeat=1, + num_threads=1319, + num_shots=8, + temperature=0.6, + top_p=0.95, + top_k=20, + ) + metrics = run_eval(args) + print(metrics) + self.assertGreater(metrics["score"], 0.90) + + +if __name__ == "__main__": + unittest.main()