
fix: skip w13 swap in FlashInfer CUTLASS path for non-gated MoE #17

Open
djmmoss wants to merge 1 commit into TomerBN-Nvidia:ultra-rl-v0.17 from djmmoss:dmoss/fix-flashinfer-cutlass-non-gated-moe

Conversation


@djmmoss djmmoss commented Apr 25, 2026

Summary

convert_to_unquantized_kernel_format() in vllm/model_executor/layers/fused_moe/oracle/unquantized.py unconditionally calls swap_w13_to_w31(layer.w13_weight) when the FLASHINFER_CUTLASS backend is selected. That helper assumes the second-to-last dim of w13 is [w1; w3] (the gate/up projections in a SwiGLU-style gated MoE) and flips the two halves so the kernel sees [w3; w1].

For non-gated MoE (is_act_and_mul=False), w13_weight has shape [num_experts, intermediate_size, hidden_size] — there is only one logical weight, no halves to swap. swap_w13_to_w31 reshapes with // 2 and flips, silently splitting the single weight matrix down the middle and swapping the halves, which scrambles the weights.
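
To make the failure mode concrete, here is a minimal, self-contained sketch of the reshape-and-flip the swap performs. The function body, names, and shapes are illustrative assumptions for this PR description, not the exact vLLM implementation:

```python
import torch

def swap_w13_to_w31_sketch(w13: torch.Tensor) -> torch.Tensor:
    # Illustrative only. Assumes w13 is [num_experts, 2 * intermediate, hidden],
    # laid out as [w1; w3] along dim 1, and returns the [w3; w1] layout that
    # the FlashInfer CUTLASS kernel expects.
    e, d, h = w13.shape
    halves = w13.reshape(e, 2, d // 2, h)          # split dim 1 into two halves
    return torch.flip(halves, dims=[1]).reshape(e, d, h)  # swap and re-fuse

# Gated MoE: the two halves really are w1 and w3, so the flip is correct.
gated = torch.randn(4, 2 * 128, 256)

# Non-gated MoE: a single logical weight. The same reshape still "succeeds"
# whenever intermediate_size is even -- it just cuts the matrix in half and
# swaps the pieces, silently corrupting the weights.
non_gated = torch.randn(4, 128, 256)
assert not torch.equal(swap_w13_to_w31_sketch(non_gated), non_gated)
```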

The model still runs, but produces nonsense output: at temperature=0, generation degenerates into prompt repetition, finish_reason is always length, and content is empty.

This patch gates the swap on layer.moe_config.is_act_and_mul. For non-gated layouts, the weight is left untouched.
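
A minimal sketch of the guarded logic, reusing the `swap_w13_to_w31_sketch` helper from the example above. Names and the helper shape mirror the description, not the literal diff in `convert_to_unquantized_kernel_format()`:

```python
import torch

def maybe_swap_w13(w13: torch.Tensor, is_act_and_mul: bool) -> torch.Tensor:
    # Hypothetical helper illustrating the guard; the real code mutates
    # layer.w13_weight in place on the FLASHINFER_CUTLASS path.
    if is_act_and_mul:
        # Gated MoE: [w1; w3] -> [w3; w1] for the FlashInfer CUTLASS kernel.
        return swap_w13_to_w31_sketch(w13)
    # Non-gated MoE: single weight of shape
    # [num_experts, intermediate_size, hidden_size]; leave it untouched.
    return w13
```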

Test plan

  • Reproduced the issue with a non-gated MoE model under VLLM_USE_FLASHINFER_MOE_FP16=1 VLLM_FLASHINFER_MOE_BACKEND=throughput (CUTLASS path) — 0/10 prompts produced coherent output before the fix.
  • After the fix, CUTLASS produces correct, coherent reasoning that matches the Triton baseline on math/algebra/sequence/deduction prompts at temperature=0. Numerical drift causes minor wording differences but final answers match.
  • Throughput benefit preserved: +19.7% output throughput / -17.9% mean TPOT vs Triton on a 500-prompt input=8000/output=1000 benchmark (TP=8, multi-node).

🤖 Generated with Claude Code

convert_to_unquantized_kernel_format() unconditionally calls
swap_w13_to_w31(layer.w13_weight) when the FLASHINFER_CUTLASS backend
is selected. That helper assumes the second-to-last dim of w13 is
[w1; w3] and flips the two halves.

For non-gated MoE (is_act_and_mul=False, e.g. NemotronH), w13_weight has
shape [num_experts, intermediate_size, hidden_size] — there is only one
logical weight, no halves to swap. Reshaping with `// 2` and flipping
silently scrambles the weights, producing nonsense output (the model
echoes the prompt or hits length limit without producing real content).

Symptom on NemotronH ultra_v3: at temperature=0, the model's reasoning
becomes pure prompt repetition; finish_reason is always 'length' and
content is empty.

Fix: gate the swap on layer.moe_config.is_act_and_mul. For non-gated
layouts, leave the weight untouched.

Verified: with the fix, CUTLASS produces correct, coherent reasoning
that matches Triton on math/algebra/sequence/deduction prompts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@djmmoss djmmoss force-pushed the dmoss/fix-flashinfer-cutlass-non-gated-moe branch from 68d108d to 37fbdbf on April 25, 2026 at 23:33
