diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index a59204ef9a..c8caf3def9 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -246,10 +246,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional fc1_expert_biases, TensorView fc2_expert_weights, Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, - Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, - int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, - int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, - Optional> profile_ids, bool enable_pdl, + Optional swiglu_beta, Optional swiglu_limit, + bool swizzled_input_sf, int64_t tp_size, int64_t tp_rank, int64_t ep_size, + int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, + bool min_latency_mode, Optional> profile_ids, bool enable_pdl, ActivationType base_activation_type = ActivationType::Swiglu) { std::lock_guard lock(mMutex); @@ -382,7 +382,6 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; // HACK Define default values for parameters we don't have good values for - bool const swizzled_input_sf = true; // Assume input_sf is swizzled by default int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default bool const use_lora = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM @@ -428,12 +427,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, - Optional swiglu_limit, TensorView num_active_experts_per_node, - TensorView experts_to_token_score, TensorView active_expert_global_ids, - int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, - int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, - bool min_latency_mode, Optional> profile_ids, - bool enable_pdl, + Optional swiglu_limit, bool swizzled_input_sf, + TensorView num_active_experts_per_node, TensorView experts_to_token_score, + TensorView active_expert_global_ids, int64_t tp_size, int64_t tp_rank, + int64_t ep_size, int64_t ep_rank, int64_t cluster_size, + int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, + Optional> profile_ids, bool enable_pdl, ActivationType base_activation_type = ActivationType::Swiglu) { std::lock_guard lock(mMutex); @@ -569,13 +568,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; // HACK Define default values for parameters we don't have good values for - bool const swizzled_input_sf_ml = true; // Assume input_sf is swizzled by default int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default bool const use_lora_ml = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - swizzled_input_sf_ml, reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -592,7 +590,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - swizzled_input_sf_ml, reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -730,15 +728,15 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, - int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, - int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, + bool swizzled_input_sf, int64_t tp_size, int64_t tp_rank, int64_t ep_size, + int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, Optional> profile_ids, bool enable_pdl, int64_t base_activation_type) { runMoe(output, input, token_selected_experts, token_final_scales, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf, - swiglu_alpha, swiglu_beta, swiglu_limit, tp_size, tp_rank, ep_size, ep_rank, - cluster_size, cluster_rank, enable_alltoall, min_latency_mode, profile_ids, - enable_pdl, static_cast(base_activation_type)); + swiglu_alpha, swiglu_beta, swiglu_limit, swizzled_input_sf, tp_size, tp_rank, + ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, + profile_ids, enable_pdl, static_cast(base_activation_type)); }); } else if (name == "run_moe_min_latency") { return Function::FromTyped( @@ -748,18 +746,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, - TensorView num_active_experts_per_node, TensorView experts_to_token_score, - TensorView active_expert_global_ids, int64_t tp_size, int64_t tp_rank, - int64_t ep_size, int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, - bool enable_alltoall, bool min_latency_mode, Optional> profile_ids, - bool enable_pdl, int64_t base_activation_type) { - runMoeMinLantency( - output, input, token_selected_experts, token_final_scales, fc1_expert_weights, - fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf, - swiglu_alpha, swiglu_beta, swiglu_limit, num_active_experts_per_node, - experts_to_token_score, active_expert_global_ids, tp_size, tp_rank, ep_size, - ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, profile_ids, - enable_pdl, static_cast(base_activation_type)); + bool swizzled_input_sf, TensorView num_active_experts_per_node, + TensorView experts_to_token_score, TensorView active_expert_global_ids, + int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, + int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, + bool min_latency_mode, Optional> profile_ids, bool enable_pdl, + int64_t base_activation_type) { + runMoeMinLantency(output, input, token_selected_experts, token_final_scales, + fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, + fc2_expert_biases, quant_scales, input_sf, swiglu_alpha, swiglu_beta, + swiglu_limit, swizzled_input_sf, num_active_experts_per_node, + experts_to_token_score, active_expert_global_ids, tp_size, tp_rank, + ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, + min_latency_mode, profile_ids, enable_pdl, + static_cast(base_activation_type)); }); } else { return Function(nullptr); diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index c0fa47962c..d5fd1826a1 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -385,6 +385,7 @@ def cutlass_fused_moe( swiglu_alpha: Optional[torch.Tensor] = None, swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -501,6 +502,7 @@ def cutlass_fused_moe( swiglu_alpha, swiglu_beta, swiglu_limit, + swizzled_input_sf, *min_latency_output, tp_size, tp_rank, @@ -542,6 +544,7 @@ def _fake_cutlass_fused_moe( swiglu_alpha: Optional[torch.Tensor] = None, swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -612,6 +615,7 @@ def cutlass_fused_moe( tune_max_num_tokens: int = 8192, enable_pdl: Optional[bool] = None, activation_type: ActivationType = ActivationType.Swiglu, + swizzled_input_sf: bool = True, ) -> torch.Tensor: """Compute a Mixture of Experts (MoE) layer using CUTLASS backend. @@ -722,6 +726,12 @@ def cutlass_fused_moe( activation_type: ActivationType = ActivationType.Swiglu Activation to apply on for GEMM1, note that Relu2 means non-gated GEMM1 + swizzled_input_sf : bool = True + Whether the input scaling factor (input_sf) is in swizzled layout. Defaults to True. + Set to False when input_sf is in linear layout, e.g. after FP4 allgather/alltoall + communication where the scaling factors are received in linear (non-swizzled) format. + Only relevant when input_sf is not None. + Returns ------- out: torch.Tensor @@ -788,6 +798,7 @@ def cutlass_fused_moe( swiglu_alpha, swiglu_beta, swiglu_limit, + swizzled_input_sf, tp_size, tp_rank, ep_size, diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index 9c8a547583..ee2952dbc8 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -22,6 +22,7 @@ from torch.nn import functional as F import flashinfer.fused_moe as fused_moe +from flashinfer.utils import is_sm100a_supported from flashinfer import ( autotune, fp4_quantize, @@ -1796,5 +1797,128 @@ def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor: torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1) +@pytest.mark.skipif( + not is_sm100a_supported(torch.device("cuda")), + reason="NVFP4 is only supported on SM100+", +) +def test_moe_nvfp4_unswizzled_input_sf(): + """Test cutlass_fused_moe with swizzled_input_sf=False (linear layout input_sf). + + In FP4 allgather/alltoall scenarios, the input scaling factors received after + communication are in linear layout (not swizzled). This test verifies that + passing swizzled_input_sf=False produces the same output as first swizzling + the input_sf and passing swizzled_input_sf=True. + """ + torch.manual_seed(42) + batch_size = 32 + hidden_size = 128 + intermediate_size = 128 + num_experts = 4 + top_k = 2 + otype = torch.float16 + quant_blocksize = 16 + + def round_up(x, y): + return (x + y - 1) // y * y + + e = num_experts + m = batch_size + n = intermediate_size + k = hidden_size + w1_n = 2 * n # Swiglu + + w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 + + sf_w1_2n = round_up(w1_n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + + w1_blockscale = torch.empty( + (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + w2_blockscale = torch.empty( + (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + ) + w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1[expert]).max().to(torch.float32) + w2_amax = torch.abs(w2[expert]).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert]) + w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert]) + + x = torch.randn(m, k, dtype=otype).cuda() + a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + router_logits = torch.randn(m, e, dtype=otype).cuda() + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + quant_scales = [ + a1_gs, + w1_blockscale.view(torch.int32), + 1.0 / (a1_gs * w1_gs), + a2_gs, + w2_blockscale.view(torch.int32), + 1.0 / (a2_gs * w2_gs), + ] + + # Quantize input with swizzled layout (default) + hidden_states_swizzled, input_sf_swizzled = fp4_quantize( + x, a1_gs, is_sf_swizzled_layout=True + ) + # Quantize input with linear layout (as received after allgather/alltoall) + hidden_states_linear, input_sf_linear = fp4_quantize( + x, a1_gs, is_sf_swizzled_layout=False + ) + + # Both quantizations should produce the same quantized values + assert torch.equal(hidden_states_swizzled, hidden_states_linear) + # The SF buffers must differ — otherwise the test would pass trivially + # even if fp4_quantize ignored is_sf_swizzled_layout + assert not torch.equal(input_sf_swizzled, input_sf_linear), ( + "input_sf_swizzled and input_sf_linear should have different layouts" + ) + + output_swizzled = torch.zeros(m, k, dtype=otype, device="cuda") + output_linear = torch.zeros(m, k, dtype=otype, device="cuda") + + # swizzled_input_sf=True with swizzled input_sf (default behavior) + fused_moe.cutlass_fused_moe( + hidden_states_swizzled, + selected_experts.to(torch.int), + routing_weights, + w1_q.contiguous().view(torch.long), + w2_q.contiguous().view(torch.long), + otype, + quant_scales=quant_scales, + input_sf=input_sf_swizzled, + swizzled_input_sf=True, + output=output_swizzled, + ) + + # swizzled_input_sf=False with linear input_sf (post-allgather scenario) + fused_moe.cutlass_fused_moe( + hidden_states_linear, + selected_experts.to(torch.int), + routing_weights, + w1_q.contiguous().view(torch.long), + w2_q.contiguous().view(torch.long), + otype, + quant_scales=quant_scales, + input_sf=input_sf_linear, + swizzled_input_sf=False, + output=output_linear, + ) + + torch.testing.assert_close(output_swizzled, output_linear, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": pytest.main([__file__, "-v"])