Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
StandardCombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.flashinfer import (
FlashinferCombineInput,
FlashinferDispatchOutput,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,6 +253,12 @@ def ensure_cutedsl_wrapper(layer: torch.nn.Module) -> None:
getattr(server_args, "cuda_graph_max_bs", None) or 512,
getattr(server_args, "chunked_prefill_size", None) or 8192,
)
# In DP attention + EP mode, MoE receives tokens from all DP ranks:
# - Standard path (allgather): dp_size * local_tokens
# - A2A path (dispatch): ep_size * runtime_max_per_rank
dp_size = getattr(server_args, "dp_size", 1) or 1
if dp_size > 1:
max_num_tokens *= dp_size
top_k = layer.top_k if layer.top_k is not None else layer.moe_runner_config.top_k
# inference_mode(False) ensures the wrapper's pre-allocated CUDA-graph
# buffers are normal tensors. This call typically happens inside
Expand Down Expand Up @@ -351,3 +361,74 @@ def fused_experts_none_to_flashinfer_cutedsl_fp4(
)

return StandardCombineInput(hidden_states=output)


@register_fused_func("flashinfer", "flashinfer_cutedsl")
def fused_experts_flashinfer_to_flashinfer_cutedsl_fp4(
    dispatch_output: FlashinferDispatchOutput,
    quant_info: CuteDslFp4MoeQuantInfo,
    runner_config: MoeRunnerConfig,
) -> FlashinferCombineInput:
    """CuteDSL fused func for the flashinfer alltoall dispatcher.

    Two cases depending on whether the dispatcher already did FP4 quantization:
    - bf16/fp16 input (SGLANG_MOE_NVFP4_DISPATCH=0): quantize here with
      cutedsl's input scale.
    - FP4 input (SGLANG_MOE_NVFP4_DISPATCH=1): pass through unchanged.

    Returns a FlashinferCombineInput holding the routed-expert output
    (shared experts are handled by the caller, not here).
    """
    from flashinfer import fp4_quantize

    from sglang.srt.layers.moe.token_dispatcher.flashinfer import (
        FlashinferCombineInput,
    )
    from sglang.srt.layers.moe.topk import TopKOutputChecker

    assert runner_config.activation == "silu", "Only silu is supported for CuteDSL MoE."

    hidden_states = dispatch_output.hidden_states
    x_sf = dispatch_output.hidden_states_scale
    topk_output = dispatch_output.topk_output
    assert TopKOutputChecker.format_is_standard(topk_output)

    topk_ids = topk_output.topk_ids
    topk_weights = topk_output.topk_weights
    # The CuteDSL kernel expects int32 expert ids.
    if topk_ids.dtype != torch.int32:
        topk_ids = topk_ids.to(torch.int32)

    if hidden_states.dtype in (torch.bfloat16, torch.float16):
        # Dispatcher sent high-precision activations (NVFP4 dispatch
        # disabled) — quantize ourselves with cutedsl's input_scale so the
        # scale semantics match the kernel.
        x_fp4, x_sf = fp4_quantize(
            hidden_states,
            quant_info.input_scale,
            sf_vec_size=_FP4_SF_VEC_SIZE,
            is_sf_swizzled_layout=False,
        )
    else:
        # Dispatcher already quantized to FP4 — use as-is.
        # NOTE(review): the dispatcher quantizes with global_scale and an
        # interleaved x_sf layout, while cutedsl expects input_scale and a
        # non-interleaved x_sf. This pass-through is unverified — TODO
        # confirm, and consider de-interleaving x_sf here or forcing
        # SGLANG_MOE_NVFP4_DISPATCH=0 when cutedsl + flashinfer a2a is
        # selected.
        x_fp4 = hidden_states

    output = quant_info.wrapper.run(
        x=x_fp4,
        x_sf=x_sf,
        token_selected_experts=topk_ids,
        token_final_scales=topk_weights,
        w1_weight=quant_info.w13_weight,
        w1_weight_sf=quant_info.w13_weight_sf,
        w1_alpha=quant_info.w1_alpha,
        fc2_input_scale=quant_info.fc2_input_scale,
        w2_weight=quant_info.w2_weight,
        w2_weight_sf=quant_info.w2_weight_sf,
        w2_alpha=quant_info.w2_alpha,
    )

    # Note: output contains routed expert results; shared_expert is handled
    # separately by the caller.

    # Write into the pre-allocated workspace buffer if available, so combine
    # can read the payload from the workspace (CUDA-graph friendly).
    if dispatch_output.moe_output is not None:
        dispatch_output.moe_output.copy_(output)
        output = dispatch_output.moe_output

    return FlashinferCombineInput(hidden_states=output)
23 changes: 8 additions & 15 deletions python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(

# TODO: Can this be a server arg and shared with deepep/mooncakeep?
self.max_num_tokens = (
get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 1024)
get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 8192)
* self.ep_size
)

Expand Down Expand Up @@ -178,15 +178,11 @@ def dispatch(
topk_ids = topk_output.topk_ids
topk_weights = topk_output.topk_weights

# Handle case where there are no tokens on this DP worker
# moe_a2a.dispatch requires at least one token
self.has_dummy_token = False
if x.shape[0] == 0:
logger.warning("No tokens on this DP worker, using dummy token")
self.has_dummy_token = True
x = self.dummy_x
topk_ids = self.dummy_topk_ids
topk_weights = self.dummy_topk_weights
# Track if this DP worker has no tokens (idle rank).
# Unlike the old dummy-token approach, we pass 0-size tensors directly
# to the alltoall kernel, which handles local_num_tokens=0 natively
# (same as TRT-LLM). The kernel keeps 1 thread alive for sync.
self.has_dummy_token = x.shape[0] == 0
Copy link
Copy Markdown
Contributor

@YAMY1234 YAMY1234 Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could rename it, something like is_idle_rank for better clarity?


global_scale = self.quant_config.get("input_global_scale", None)
if global_scale is not None:
Expand Down Expand Up @@ -216,9 +212,10 @@ def dispatch(
else x.shape[0]
)
recv_tensors = self.moe_a2a.dispatch(
self.dummy_topk_ids_current_rank if self.has_dummy_token else topk_ids,
topk_ids,
payloads,
self.runtime_max_tokens_per_rank,
invalid_token_expert_id=self.num_experts,
expert_id_payload_index=expert_id_payload_index,
)
if x_sf is not None:
Expand Down Expand Up @@ -257,10 +254,6 @@ def combine(self, combine_input: FlashinferCombineInput) -> torch.Tensor:
payload_in_workspace=self.payload_in_workspace,
)

# Remove dummy token if it was added in dispatch
if self.has_dummy_token:
hidden_states = hidden_states[1:, :]

del self.runtime_max_tokens_per_rank
del self.has_dummy_token
return hidden_states
20 changes: 13 additions & 7 deletions python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def __init__(
prefix=add_prefix("shared_expert", prefix),
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
if (get_moe_a2a_backend().is_deepep()
or get_moe_a2a_backend().is_flashinfer())
else {}
),
)
Expand Down Expand Up @@ -329,9 +330,15 @@ def forward(
if get_moe_a2a_backend().is_deepep():
return self._forward_deepep(hidden_states, forward_batch)

if (
if hidden_states.shape[0] == 0:
# M=0 guard for idle DP ranks: skip shared_experts and gate
# (which crash on empty tensors in FP4 GEMM), but still call
# self.experts() to participate in alltoall collective.
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
final_hidden_states = self.experts(hidden_states, topk_output)
elif (
self.alt_stream is not None
and hidden_states.shape[0] > 0
and get_is_capture_mode()
):
final_hidden_states, shared_output = self.forward_normal_dual_stream(
Expand All @@ -342,19 +349,18 @@ def forward(
final_hidden_states = self._forward_router_experts(hidden_states)

if shared_output is not None:
# In-place add is required to keep final_hidden_states in the
# symmetric memory pool (when --enable-symm-mem is used).
# An out-of-place add would allocate a new tensor outside symm
# memory, breaking subsequent symmetric collective operations.
final_hidden_states += shared_output
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
and not get_moe_a2a_backend().is_flashinfer()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

# Debug removed - was causing issues during CUDA graph capture

return final_hidden_states.view(num_tokens, hidden_dim)


Expand Down
10 changes: 6 additions & 4 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2744,9 +2744,10 @@ def _handle_moe_kernel_config(self):
assert self.moe_a2a_backend in [
"none",
"deepep",
"flashinfer",
], (
f"flashinfer_cutedsl supports moe_a2a_backend='none' (standard path) "
f"or 'deepep' (DeepEP low-latency path), got '{self.moe_a2a_backend}'."
f"flashinfer_cutedsl supports moe_a2a_backend='none', 'deepep', or 'flashinfer', "
f"got '{self.moe_a2a_backend}'."
)
self.disable_shared_experts_fusion = True
logger.warning(
Expand Down Expand Up @@ -2865,8 +2866,9 @@ def _handle_a2a_moe(self):
"SGLANG_MOE_NVFP4_DISPATCH is set to True for Flashinfer MoE A2A"
)
assert self.moe_runner_backend in [
"flashinfer_cutlass"
], "Flashinfer MoE A2A is only supported with flashinfer_cutlass moe runner backend"
"flashinfer_cutlass",
"flashinfer_cutedsl",
], "Flashinfer MoE A2A is only supported with flashinfer_cutlass or flashinfer_cutedsl moe runner backend"

if self.moe_a2a_backend == "mori":
self.ep_size = self.tp_size
Expand Down
91 changes: 91 additions & 0 deletions test/registered/moe/test_cutedsl_a2a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Test CuteDSL FP4 MoE + FlashInfer alltoall on B200 with DP attention.

Config: Qwen3.5-397B-A17B-NVFP4, B200x4, EP=4 DP=4, cutedsl + flashinfer a2a.
"""

import unittest
from types import SimpleNamespace

import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)

register_cuda_ci(est_time=600, suite="stage-c-test-4-gpu-b200")

MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"

# Skip unless a CUDA device with Blackwell (sm_100a, B200) or newer is present.
# Guard with is_available() first: calling get_device_capability() on a
# CUDA-less machine raises at import/collection time instead of skipping.
SKIP_TEST = (
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0)
)
SKIP_REASON = "Requires Blackwell (B200, sm_100a) or above."


@unittest.skipIf(SKIP_TEST, SKIP_REASON)
class TestCuteDslFlashinferA2A(CustomTestCase):
    """CuteDSL FP4 MoE + FlashInfer one-sided alltoall + DP4 EP4 on B200."""

    @classmethod
    def setUpClass(cls):
        cls.model = MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST
        # 4-GPU launch: TP4/EP4/DP4 with DP attention, cutedsl runner and
        # flashinfer alltoall backend.
        server_args = [
            "--trust-remote-code",
            "--quantization",
            "modelopt_fp4",
            "--tp",
            "4",
            "--ep-size",
            "4",
            "--dp",
            "4",
            "--enable-dp-attention",
            "--enable-dp-lm-head",
            "--moe-runner-backend",
            "flashinfer_cutedsl",
            "--moe-a2a-backend",
            "flashinfer",
            "--disable-radix-cache",
            "--disable-flashinfer-autotune",
            "--watchdog-timeout",
            "900",
        ]
        server_env = {
            "FLASHINFER_DISABLE_VERSION_CHECK": "1",
            "SGLANG_MOE_NVFP4_DISPATCH": "0",
        }
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 3,
            other_args=server_args,
            env=server_env,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        # Full GSM8K (1319 examples), 8-shot; expect > 0.90 accuracy.
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            eval_name="gsm8k",
            num_examples=1319,
            max_tokens=10240,
            repeat=1,
            num_threads=1319,
            num_shots=8,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
        )
        metrics = run_eval(eval_args)
        print(metrics)
        self.assertGreater(metrics["score"], 0.90)


# Allow running this test module directly (outside the CI harness).
if __name__ == "__main__":
    unittest.main()
Loading