Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
Optional<TensorView> fc1_expert_biases, TensorView fc2_expert_weights,
Optional<TensorView> fc2_expert_biases, Optional<Array<Tensor>> quant_scales,
Optional<TensorView> input_sf, Optional<TensorView> swiglu_alpha,
Optional<TensorView> swiglu_beta, Optional<TensorView> swiglu_limit, int64_t tp_size,
int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size,
int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode,
Optional<Array<int64_t>> profile_ids, bool enable_pdl,
Optional<TensorView> swiglu_beta, Optional<TensorView> swiglu_limit,
bool swizzled_input_sf, int64_t tp_size, int64_t tp_rank, int64_t ep_size,
int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall,
bool min_latency_mode, Optional<Array<int64_t>> profile_ids, bool enable_pdl,
ActivationType base_activation_type = ActivationType::Swiglu) {
std::lock_guard<std::mutex> lock(mMutex);

Expand Down Expand Up @@ -382,7 +382,6 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf = true; // Assume input_sf is swizzled by default
int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
bool const use_lora = false; // No lora support yet
#ifdef USING_OSS_CUTLASS_MOE_GEMM
Expand Down Expand Up @@ -428,12 +427,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
Optional<TensorView> fc2_expert_biases,
Optional<Array<Tensor>> quant_scales, Optional<TensorView> input_sf,
Optional<TensorView> swiglu_alpha, Optional<TensorView> swiglu_beta,
Optional<TensorView> swiglu_limit, TensorView num_active_experts_per_node,
TensorView experts_to_token_score, TensorView active_expert_global_ids,
int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank,
int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall,
bool min_latency_mode, Optional<Array<int64_t>> profile_ids,
bool enable_pdl,
Optional<TensorView> swiglu_limit, bool swizzled_input_sf,
TensorView num_active_experts_per_node, TensorView experts_to_token_score,
TensorView active_expert_global_ids, int64_t tp_size, int64_t tp_rank,
int64_t ep_size, int64_t ep_rank, int64_t cluster_size,
int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode,
Optional<Array<int64_t>> profile_ids, bool enable_pdl,
ActivationType base_activation_type = ActivationType::Swiglu) {
std::lock_guard<std::mutex> lock(mMutex);

Expand Down Expand Up @@ -569,13 +568,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf_ml = true; // Assume input_sf is swizzled by default
int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default
bool const use_lora_ml = false; // No lora support yet
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
Expand All @@ -592,7 +590,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
#else
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
Expand Down Expand Up @@ -730,15 +728,15 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
Optional<TensorView> fc2_expert_biases, Optional<Array<Tensor>> quant_scales,
Optional<TensorView> input_sf, Optional<TensorView> swiglu_alpha,
Optional<TensorView> swiglu_beta, Optional<TensorView> swiglu_limit,
int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank,
int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall,
bool swizzled_input_sf, int64_t tp_size, int64_t tp_rank, int64_t ep_size,
int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall,
bool min_latency_mode, Optional<Array<int64_t>> profile_ids, bool enable_pdl,
int64_t base_activation_type) {
runMoe(output, input, token_selected_experts, token_final_scales, fc1_expert_weights,
fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf,
swiglu_alpha, swiglu_beta, swiglu_limit, tp_size, tp_rank, ep_size, ep_rank,
cluster_size, cluster_rank, enable_alltoall, min_latency_mode, profile_ids,
enable_pdl, static_cast<ActivationType>(base_activation_type));
swiglu_alpha, swiglu_beta, swiglu_limit, swizzled_input_sf, tp_size, tp_rank,
ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode,
profile_ids, enable_pdl, static_cast<ActivationType>(base_activation_type));
});
} else if (name == "run_moe_min_latency") {
return Function::FromTyped(
Expand All @@ -748,18 +746,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
Optional<TensorView> fc2_expert_biases, Optional<Array<Tensor>> quant_scales,
Optional<TensorView> input_sf, Optional<TensorView> swiglu_alpha,
Optional<TensorView> swiglu_beta, Optional<TensorView> swiglu_limit,
TensorView num_active_experts_per_node, TensorView experts_to_token_score,
TensorView active_expert_global_ids, int64_t tp_size, int64_t tp_rank,
int64_t ep_size, int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank,
bool enable_alltoall, bool min_latency_mode, Optional<Array<int64_t>> profile_ids,
bool enable_pdl, int64_t base_activation_type) {
runMoeMinLantency(
output, input, token_selected_experts, token_final_scales, fc1_expert_weights,
fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf,
swiglu_alpha, swiglu_beta, swiglu_limit, num_active_experts_per_node,
experts_to_token_score, active_expert_global_ids, tp_size, tp_rank, ep_size,
ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, profile_ids,
enable_pdl, static_cast<ActivationType>(base_activation_type));
bool swizzled_input_sf, TensorView num_active_experts_per_node,
TensorView experts_to_token_score, TensorView active_expert_global_ids,
int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank,
int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall,
bool min_latency_mode, Optional<Array<int64_t>> profile_ids, bool enable_pdl,
int64_t base_activation_type) {
runMoeMinLantency(output, input, token_selected_experts, token_final_scales,
fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
fc2_expert_biases, quant_scales, input_sf, swiglu_alpha, swiglu_beta,
swiglu_limit, swizzled_input_sf, num_active_experts_per_node,
experts_to_token_score, active_expert_global_ids, tp_size, tp_rank,
ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall,
min_latency_mode, profile_ids, enable_pdl,
static_cast<ActivationType>(base_activation_type));
});
} else {
return Function(nullptr);
Expand Down
11 changes: 11 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def cutlass_fused_moe(
swiglu_alpha: Optional[torch.Tensor] = None,
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
swizzled_input_sf: bool = True,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for cutlass_fused_moe (starting on line 741) is missing documentation for the new swizzled_input_sf parameter. Please add it for completeness and to inform users about this new option.

tp_size: int = 1,
tp_rank: int = 0,
ep_size: int = 1,
Expand Down Expand Up @@ -501,6 +502,7 @@ def cutlass_fused_moe(
swiglu_alpha,
swiglu_beta,
swiglu_limit,
swizzled_input_sf,
*min_latency_output,
tp_size,
tp_rank,
Expand Down Expand Up @@ -542,6 +544,7 @@ def _fake_cutlass_fused_moe(
swiglu_alpha: Optional[torch.Tensor] = None,
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
swizzled_input_sf: bool = True,
tp_size: int = 1,
tp_rank: int = 0,
ep_size: int = 1,
Expand Down Expand Up @@ -612,6 +615,7 @@ def cutlass_fused_moe(
tune_max_num_tokens: int = 8192,
enable_pdl: Optional[bool] = None,
activation_type: ActivationType = ActivationType.Swiglu,
swizzled_input_sf: bool = True,
) -> torch.Tensor:
"""Compute a Mixture of Experts (MoE) layer using CUTLASS backend.

Expand Down Expand Up @@ -722,6 +726,12 @@ def cutlass_fused_moe(
activation_type: ActivationType = ActivationType.Swiglu
Activation to apply on for GEMM1, note that Relu2 means non-gated GEMM1

swizzled_input_sf : bool = True
Whether the input scaling factor (input_sf) is in swizzled layout. Defaults to True.
Set to False when input_sf is in linear layout, e.g. after FP4 allgather/alltoall
communication where the scaling factors are received in linear (non-swizzled) format.
Only relevant when input_sf is not None.

Returns
-------
out: torch.Tensor
Expand Down Expand Up @@ -788,6 +798,7 @@ def cutlass_fused_moe(
swiglu_alpha,
swiglu_beta,
swiglu_limit,
swizzled_input_sf,
tp_size,
tp_rank,
ep_size,
Expand Down
124 changes: 124 additions & 0 deletions tests/moe/test_trtllm_cutlass_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.nn import functional as F

import flashinfer.fused_moe as fused_moe
from flashinfer.utils import is_sm100a_supported
from flashinfer import (
autotune,
fp4_quantize,
Expand Down Expand Up @@ -1796,5 +1797,128 @@ def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)


@pytest.mark.skipif(
not is_sm100a_supported(torch.device("cuda")),
reason="NVFP4 is only supported on SM100+",
)
# NOTE(review): a prior review thread on this skip condition was resolved.
def test_moe_nvfp4_unswizzled_input_sf():
    """Test cutlass_fused_moe with swizzled_input_sf=False (linear layout input_sf).

    In FP4 allgather/alltoall scenarios, the input scaling factors received after
    communication are in linear layout (not swizzled). This test verifies that
    passing swizzled_input_sf=False produces the same output as first swizzling
    the input_sf and passing swizzled_input_sf=True.
    """
    # Fixed seed so the randomly generated weights/activations are reproducible.
    torch.manual_seed(42)
    batch_size = 32
    hidden_size = 128
    intermediate_size = 128
    num_experts = 4
    top_k = 2
    otype = torch.float16
    # NVFP4 block scaling granularity: one FP8 scale per 16 FP4 elements.
    quant_blocksize = 16

    def round_up(x, y):
        # Round x up to the nearest multiple of y.
        return (x + y - 1) // y * y

    # Short aliases matching the usual GEMM naming (e experts, m tokens,
    # n intermediate dim, k hidden dim).
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    w1_n = 2 * n  # Swiglu: gated activation doubles GEMM1's output width.

    # Random expert weights, scaled down to keep values in a quantization-friendly range.
    w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10

    # Block-scale tensor dims padded to the kernel's alignment requirements
    # (rows to 128, scale columns to 4) — presumably matching the CUTLASS
    # NVFP4 scale-factor layout; TODO confirm against fp4_quantize output shapes.
    sf_w1_2n = round_up(w1_n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)

    # Buffers for quantized weights (uint8 holds two packed FP4 values, hence
    # the halved last dim), their FP8 block scales, and per-expert global scales.
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)

    # Quantize each expert's weights with a per-expert global scale chosen so
    # the absolute max maps onto the representable FP8*FP4 range.
    for expert in range(e):
        w1_amax = torch.abs(w1[expert]).max().to(torch.float32)
        w2_amax = torch.abs(w2[expert]).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
        w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
        w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])

    # Random input activations and routing (top_k experts per token).
    x = torch.randn(m, k, dtype=otype).cuda()
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    # Scale list in the order the NVFP4 cutlass_fused_moe path expects —
    # TODO confirm ordering against the quant_scales contract in core.py.
    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]

    # Quantize input with swizzled layout (default)
    hidden_states_swizzled, input_sf_swizzled = fp4_quantize(
        x, a1_gs, is_sf_swizzled_layout=True
    )
    # Quantize input with linear layout (as received after allgather/alltoall)
    hidden_states_linear, input_sf_linear = fp4_quantize(
        x, a1_gs, is_sf_swizzled_layout=False
    )

    # Both quantizations should produce the same quantized values
    assert torch.equal(hidden_states_swizzled, hidden_states_linear)
    # The SF buffers must differ — otherwise the test would pass trivially
    # even if fp4_quantize ignored is_sf_swizzled_layout
    assert not torch.equal(input_sf_swizzled, input_sf_linear), (
        "input_sf_swizzled and input_sf_linear should have different layouts"
    )

    output_swizzled = torch.zeros(m, k, dtype=otype, device="cuda")
    output_linear = torch.zeros(m, k, dtype=otype, device="cuda")

    # swizzled_input_sf=True with swizzled input_sf (default behavior)
    fused_moe.cutlass_fused_moe(
        hidden_states_swizzled,
        selected_experts.to(torch.int),
        routing_weights,
        w1_q.contiguous().view(torch.long),
        w2_q.contiguous().view(torch.long),
        otype,
        quant_scales=quant_scales,
        input_sf=input_sf_swizzled,
        swizzled_input_sf=True,
        output=output_swizzled,
    )

    # swizzled_input_sf=False with linear input_sf (post-allgather scenario)
    fused_moe.cutlass_fused_moe(
        hidden_states_linear,
        selected_experts.to(torch.int),
        routing_weights,
        w1_q.contiguous().view(torch.long),
        w2_q.contiguous().view(torch.long),
        otype,
        quant_scales=quant_scales,
        input_sf=input_sf_linear,
        swizzled_input_sf=False,
        output=output_linear,
    )

    # The two layouts must yield numerically matching MoE outputs.
    torch.testing.assert_close(output_swizzled, output_linear, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    # Allow running this test module directly, with verbose pytest output.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
Loading