feat: Support flashinfer_cutedsl MoE runner with flashinfer alltoall backend #22669

Open

samuellees wants to merge 1 commit into sgl-project:main from samuellees:feat/enable-fp4cutedslmoe+a2a

Conversation

@samuellees
Contributor

@samuellees samuellees commented Apr 13, 2026

Depends on flashinfer-ai/flashinfer#3021 for topk=10.

Summary

  • Enable CuteDSL FP4 MoE runner (--moe-runner-backend flashinfer_cutedsl) to work with FlashInfer one-sided alltoall dispatch (--moe-a2a-backend flashinfer) for DP attention + EP configurations
  • Fixes multiple issues discovered when combining CuteDSL with alltoall in DP mode where idle ranks have 0 tokens

Changes

server_args.py

  • Add "flashinfer" to cutedsl's moe_a2a_backend whitelist
  • Add "flashinfer_cutedsl" to flashinfer a2a's moe_runner_backend whitelist
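The two whitelist additions can be pictured as a simple compatibility check. This is a hypothetical sketch, not the actual server_args.py code; the whitelist contents beyond the two backends named above (e.g. the other allowed entries) are illustrative assumptions:

```python
# Hypothetical sketch of the backend-compatibility validation; the real
# server_args.py logic and the full option sets may differ. The entries
# other than "flashinfer" / "flashinfer_cutedsl" are illustrative.
CUTEDSL_A2A_WHITELIST = {"none", "deepep", "flashinfer"}  # "flashinfer" newly added
FLASHINFER_A2A_RUNNER_WHITELIST = {"flashinfer_cutedsl"}  # newly added


def validate_moe_backends(moe_runner_backend: str, moe_a2a_backend: str) -> None:
    # Reject combinations outside the whitelists at server startup.
    if (moe_runner_backend == "flashinfer_cutedsl"
            and moe_a2a_backend not in CUTEDSL_A2A_WHITELIST):
        raise ValueError(
            f"moe_a2a_backend={moe_a2a_backend!r} is not supported "
            f"with the flashinfer_cutedsl runner"
        )
    if (moe_a2a_backend == "flashinfer"
            and moe_runner_backend not in FLASHINFER_A2A_RUNNER_WHITELIST):
        raise ValueError(
            f"moe_runner_backend={moe_runner_backend!r} is not supported "
            f"with the flashinfer a2a backend"
        )
```

With this PR, `validate_moe_backends("flashinfer_cutedsl", "flashinfer")` passes, whereas before it would have been rejected.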

qwen2_moe.py

  • M=0 guard: idle DP ranks (0 tokens) skip shared_expert and gate (the FP4 GEMM cannot handle empty tensors), but still call self.experts() so they participate in the alltoall collective
  • Skip TP allreduce in a2a mode: the alltoall combine already aggregates routed expert results across ranks; an additional tensor_model_parallel_all_reduce causes an NCCL size mismatch between active and idle ranks
  • shared_expert tp_size=1: in a2a mode, shared_expert must not be TP-sharded (same as deepep), because the TP allreduce on the shared_expert output fails when idle DP ranks have 0 tokens
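The three fixes above can be sketched together in a simplified forward pass. This is illustrative only: `gate`, `shared_expert`, and `experts` stand in for the real qwen2_moe modules, the function signature is invented for the sketch, and the example runs on plain CPU tensors:

```python
import torch


def moe_forward(hidden_states, gate, shared_expert, experts, num_experts):
    """Sketch of the M=0 guard: dense GEMMs are skipped on idle ranks,
    but `experts` (which wraps the alltoall dispatch/combine) always runs."""
    if hidden_states.shape[0] > 0:
        router_logits = gate(hidden_states)
        shared_out = shared_expert(hidden_states)
    else:
        # Idle DP rank: the FP4 GEMM rejects empty inputs, so skip the
        # gate/shared_expert GEMMs and fabricate empty routing logits.
        router_logits = hidden_states.new_zeros((0, num_experts))
        shared_out = None

    # Every rank must call this: alltoall dispatch/combine is a collective,
    # and ranks with 0 tokens still have to participate or NCCL hangs.
    routed_out = experts(hidden_states, router_logits)

    # The alltoall combine already aggregates routed results across ranks,
    # so there is no extra tensor_model_parallel_all_reduce in a2a mode.
    if shared_out is not None:
        routed_out = routed_out + shared_out
    return routed_out
```

The key invariant is that active and idle ranks issue exactly the same sequence of collectives, just with different local token counts.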

flashinfer.py (token dispatcher)

  • Remove dummy token mechanism: Pass 0-size tensors directly to alltoall kernel (matching TRT-LLM behavior). The kernel natively handles local_num_tokens=0
  • Add invalid_token_expert_id: Mark padding slots with invalid expert ID so MoE kernels skip them
  • Increase default max dispatch tokens per rank from 1024 to 8192 to support longer sequences
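The padding-sanitization idea behind invalid_token_expert_id can be sketched as follows. The helper name and buffer layout are hypothetical; the only grounded detail is that an expert id of -1 is ignored by the flashinfer cutlass MoE kernels, as the removed dummy-token code noted:

```python
import torch

# Expert id stamped into padding slots; -1 is ignored by the flashinfer
# cutlass MoE kernels, so padded rows contribute nothing to the GEMMs.
INVALID_TOKEN_EXPERT_ID = -1


def mark_padding_slots(topk_ids: torch.Tensor, num_valid_tokens: int) -> torch.Tensor:
    """topk_ids: [max_dispatch_tokens, topk]. Rows at index >= num_valid_tokens
    are padding left over from the fixed-size post-dispatch buffers; stamp
    them with the invalid expert id so downstream kernels skip them."""
    topk_ids = topk_ids.clone()
    topk_ids[num_valid_tokens:] = INVALID_TOKEN_EXPERT_ID
    return topk_ids
```

This replaces the old approach of routing a real dummy token, which required extra bookkeeping on every rank.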

flashinfer_cutedsl.py (MoE runner)

  • Register fused func ("flashinfer", "flashinfer_cutedsl") for flashinfer alltoall dispatcher
  • Scale wrapper max_num_tokens by ep_size in a2a mode, since wrapper receives ep_size * runtime_max tokens after alltoall dispatch
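The sizing rule in the last bullet follows from the worst case of one-sided alltoall: every peer rank may route all of its tokens to this rank, so the post-dispatch buffer must hold ep_size × runtime_max tokens. A minimal sketch (hypothetical helper, not the actual wrapper code):

```python
def wrapper_max_num_tokens(runtime_max_tokens_per_rank: int,
                           ep_size: int,
                           use_a2a: bool) -> int:
    """Worst-case capacity of the CuteDSL wrapper's token buffer.

    In a2a mode, after alltoall dispatch a rank can receive up to
    ep_size * runtime_max tokens (every peer sends everything here);
    without a2a the rank only ever sees its own local tokens."""
    if use_a2a:
        return ep_size * runtime_max_tokens_per_rank
    return runtime_max_tokens_per_rank
```

For the B200 × 4 setup below (ep_size=4, 8192 max dispatch tokens per rank), the wrapper would be sized for 32768 tokens.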

Reproduce

Hardware: B200 × 4

Baseline (cutedsl + DP, no a2a)

python3 -m sglang.launch_server \
  --model-path nvidia/Qwen3.5-397B-A17B-NVFP4 \
  --trust-remote-code --quantization modelopt_fp4 \
  --tp 4 --ep-size 4 --dp 4 \
  --enable-dp-attention --enable-dp-lm-head \
  --moe-runner-backend flashinfer_cutedsl \
  --mem-fraction-static 0.85 --max-running-requests 256 \
  --stream-interval 5 --attention-backend triton \
  --cuda-graph-bs 1 2 4 8 16 32 64 128 256 \
  --disable-radix-cache --disable-flashinfer-autotune \
  --watchdog-timeout 900 --port 8001

This PR (cutedsl + DP + flashinfer a2a)

SGLANG_MOE_NVFP4_DISPATCH=0 python3 -m sglang.launch_server \
  --model-path nvidia/Qwen3.5-397B-A17B-NVFP4 \
  --trust-remote-code --quantization modelopt_fp4 \
  --tp 4 --ep-size 4 --dp 4 \
  --enable-dp-attention --enable-dp-lm-head \
  --moe-runner-backend flashinfer_cutedsl \
  --moe-a2a-backend flashinfer \
  --mem-fraction-static 0.85 --max-running-requests 256 \
  --stream-interval 5 --attention-backend triton \
  --cuda-graph-bs 1 2 4 8 16 32 64 128 256 \
  --disable-radix-cache --disable-flashinfer-autotune \
  --watchdog-timeout 900 --port 8001

GSM8K eval

python3 -m sglang.test.run_eval --port 8001 --eval-name gsm8k \
  --num-examples 1319 --max-tokens 10240 --repeat 1 \
  --num-threads 1319 --num-shots 8 \
  --temperature 0.6 --top-p 0.95 --top-k 20

GPQA eval

python3 -m sglang.test.run_eval --port 8001 --eval-name gpqa \
  --num-examples 198 --max-tokens 81920 --repeat 8 \
  --temperature 0.6 --top-p 0.95 --top-k 20

Accuracy

Config                            GSM8K (1319 examples)   GPQA (198×8)
cutedsl + DP (baseline, no a2a)   97.6%
cutedsl + DP + flashinfer a2a     97.6%                   86.2%

Accuracy is identical between the baseline and a2a runs (97.6% on GSM8K).

Test plan

  • Unit test: test/registered/moe/test_cutedsl_a2a.py — GSM8K > 90%
  • Smoke test: "The capital of France is" → "Paris."
  • GPQA: 86.2% (8 repeats: 0.859-0.874)
  • GSM8K: 97.6% (matches baseline)
  • Hardware: B200 × 4, EP=4 DP=4

…backend

Enable CuteDSL FP4 MoE runner to work with FlashInfer one-sided
alltoall (NVLink) dispatch for DP attention + EP configurations.

Changes:
- server_args.py: Add flashinfer_cutedsl to flashinfer a2a whitelist
- qwen2_moe.py: M=0 guard for idle DP ranks (skip shared_expert/gate),
  skip TP allreduce in a2a mode (combine already aggregates),
  shared_expert tp_size=1 for a2a (matching deepep behavior)
- flashinfer.py: Remove dummy token mechanism, pass 0-size tensors
  directly to alltoall kernel (matching TRT-LLM), add
  invalid_token_expert_id for padding sanitization, increase
  default max dispatch tokens per rank
- flashinfer_cutedsl.py: Register fused func for flashinfer a2a,
  scale wrapper max_num_tokens by ep_size for a2a layout

Tested: Qwen3.5-397B-A17B-NVFP4, B200x4, EP=4 DP=4
- Output verified correct (Paris.)
- GPQA accuracy: 86.2% (8 repeats, 198 examples)

@samuellees
Contributor Author

samuellees commented Apr 13, 2026

# Unlike the old dummy-token approach, we pass 0-size tensors directly
# to the alltoall kernel, which handles local_num_tokens=0 natively
# (same as TRT-LLM). The kernel keeps 1 thread alive for sync.
self.has_dummy_token = x.shape[0] == 0
Contributor

@YAMY1234 YAMY1234 Apr 13, 2026

Maybe we could rename it to something like is_idle_rank for better clarity?

Comment on lines 151 to 169
self.dummy_x = torch.empty(
    (1, hidden_size),
    dtype=torch.bfloat16,
    device="cuda",
)
# -1 will be ignored by flashinfer cutlass moe
self.dummy_topk_ids = torch.full(
    (1, self.router_topk), -1, dtype=torch.int32, device="cuda"
)
# Hack for dispatch with dummy token: route the dummy token to this
# rank so it doesn't require any transfer.
self.dummy_topk_ids_current_rank = torch.full(
    (1, self.router_topk),
    self.ep_rank * self.num_local_experts,
    dtype=torch.int32,
    device="cuda",
)
self.dummy_topk_weights = torch.zeros(
    (1, self.router_topk), dtype=torch.float32, device="cuda"
)
Contributor

@YAMY1234 YAMY1234 Apr 13, 2026

Will these still be needed elsewhere, given that we are removing the usage of these variables in this file?

# cutedsl uses input_scale and non-interleaved x_sf.
# These may differ. For now pass through; if kernel crashes,
# need to de-interleave x_sf or disable NVFP4 dispatch.
x_fp4 = hidden_states
Contributor

Optional: are we able to verify this path? If it is not supported, could we consider forcibly disabling NVFP4_DISPATCH in server_args when cutedsl + flashinfer a2a is detected, and emitting a warning?

