diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py
index 839f8b495b..d8ce910bb6 100644
--- a/benchmarks/routines/flashinfer_benchmark_utils.py
+++ b/benchmarks/routines/flashinfer_benchmark_utils.py
@@ -54,7 +54,8 @@
     "use_routing_bias",
     "use_routing_scales_on_input",
     "weight_dtype",
-    "gated_act",
+    "activation_type",
+    "fp4_mode",
     # CUTLASS fused MoE specific
     "cutlass_variant",
     "quantized_input",
diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py
index 4adabdebca..92ee4073bd 100644
--- a/benchmarks/routines/moe.py
+++ b/benchmarks/routines/moe.py
@@ -1,3 +1,4 @@
+import inspect
 from collections import defaultdict
 from typing import Optional
 
@@ -5,7 +6,13 @@
 import torch
 
 import flashinfer
-from flashinfer import ActivationType
+
+try:
+    from flashinfer import ActivationType
+except ImportError:
+    # ActivationType was not exported from the top-level package until 0.6.3
+    from flashinfer.fused_moe.core import ActivationType
+
 from flashinfer.autotuner import autotune
 from flashinfer.fused_moe import (
     trtllm_fp4_block_scale_moe,
@@ -15,7 +22,7 @@
     fused_topk_deepseek,
 )
 from flashinfer.fused_moe.core import RoutingMethodType
-from flashinfer import fp4_quantize
+from flashinfer import fp4_quantize, mxfp8_quantize
 from flashinfer.testing.utils import (
     bench_gpu_time,
 )
@@ -41,8 +48,32 @@
     create_moe_output_scale_scalars,
     FLOAT8_E4M3_MAX,
     FLOAT4_E2M1_MAX,
+    SF_VEC_SIZE,
 )
 
+# Before 0.6.3, MoE APIs used "gated_act_type" (SwiGlu=0, GeGlu=1) instead of
+# "activation_type" (Swiglu=3, Geglu=4). Some prior APIs omit the parameter entirely.
+_ACTIVATION_TO_GATED_ACT = {
+    ActivationType.Swiglu: 0,
+    ActivationType.Geglu: 1,
+}
+
+
+def _activation_kwarg(fn, activation_type: ActivationType) -> dict:
+    """Return the correct activation keyword argument for *fn* in the installed version."""
+    sig = inspect.signature(fn)
+    if "activation_type" in sig.parameters:
+        return {"activation_type": activation_type.value}
+    if "gated_act_type" in sig.parameters:
+        if activation_type not in _ACTIVATION_TO_GATED_ACT:
+            raise ValueError(
+                f"Activation type {activation_type.name} is not supported by the "
+                f"installed flashinfer version (pre-0.6.3 only supports "
+                f"{[k.name for k in _ACTIVATION_TO_GATED_ACT]})"
+            )
+        return {"gated_act_type": _ACTIVATION_TO_GATED_ACT[activation_type]}
+    return {}
+
 
 def run_moe_test(args):
     """
@@ -187,6 +218,20 @@
             "Enable autotuner warmup for supported routines (trtllm_fp4_block_scale_moe and cutlass_fused_moe)."
         ),
     )
+    parser.add_argument(
+        "--fp4_mode",
+        type=str,
+        required=False,
+        default="nvfp4",
+        choices=["nvfp4", "mxfp4_mxfp8", "mxfp4_bf16"],
+        help=(
+            "FP4 quantization mode for trtllm_fp4_block_scale_moe: "
+            "nvfp4 (NvFP4 weights + NvFP4 hidden states, block_size=16), "
+            "mxfp4_mxfp8 (MXFP4 weights + MXFP8 hidden states, block_size=32), "
+            "mxfp4_bf16 (MXFP4 weights + BF16 hidden states, block_size=32). "
+            "Default: nvfp4"
+        ),
+    )
 
     # CUTLASS fused MoE specific
     parser.add_argument(
@@ -482,11 +527,16 @@ def testTrtllmFp4BlockScaleMoe(args):
         routed_scaling_factor=routed_scaling_factor,
     )
 
-    # For FP4, we need to properly quantize weights and create scales
-    use_ue8m0 = False
+    # Determine FP4 quantization mode
+    fp4_mode = getattr(args, "fp4_mode", "nvfp4")
+    is_mxfp4 = fp4_mode in ("mxfp4_mxfp8", "mxfp4_bf16")
+    use_ue8m0 = is_mxfp4
+    sf_vec_size = SF_VEC_SIZE["mxfp4" if is_mxfp4 else "nvfp4"]
 
-    # Calculate global scale factor for hidden states
-    hidden_states_scale_global = calculate_fp4_global_scale(hidden_states)
+    if args.verbose >= 1:
+        print(
+            f"[INFO] FP4 mode: {fp4_mode} (use_ue8m0={use_ue8m0}, sf_vec_size={sf_vec_size})"
+        )
 
     # Quantize weights using proper FP4 quantization
     gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
@@ -496,52 +546,63 @@ def testTrtllmFp4BlockScaleMoe(args):
         quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
     )
 
-    # Quantize hidden states
-    hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quantize_fp4(
-        hidden_states, hidden_states_scale_global, use_ue8m0, True
-    )
-
-    # Reshape hidden states for the kernel (pack 2 FP4 values into 1 byte)
-    # Keep as uint8 format for FP4 packed data
-    hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape(
-        hidden_states.shape[0], hidden_states.shape[1] // 2
-    )
-    # Hidden-states scale for FP4 must be 2D: [num_tokens, hidden_size // 16]
-    hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view(
-        torch.float8_e4m3fn
-    )
-    # Ensure expected shape (16 elements per hidden value for NvFP4)
-    expected_scale_elems = (num_tokens * hidden_size) // 16
-    if hidden_states_scale_linear_fp4.numel() != expected_scale_elems:
-        if args.verbose >= 1:
+    # Prepare hidden states and scale based on fp4_mode
+    if fp4_mode == "mxfp4_bf16":
+        hidden_states_fp4 = hidden_states.to(torch.bfloat16)
+        hidden_states_scale_linear_fp4 = None
+    elif fp4_mode == "mxfp4_mxfp8":
+        if num_tokens % 128 != 0:
+            raise ValueError(
+                f"mxfp4_mxfp8 mode requires num_tokens to be a multiple of 128 "
+                f"(got {num_tokens}) because mxfp8_quantize with swizzled scale "
+                f"layout pads rows to 128-element boundaries."
+            )
+        hs_quant, hs_scale = mxfp8_quantize(hidden_states, True)
+        hidden_states_fp4 = hs_quant
+        hidden_states_scale_linear_fp4 = hs_scale.view(torch.float8_e4m3fn).reshape(
+            num_tokens, -1
+        )
+    else:
+        hidden_states_scale_global = calculate_fp4_global_scale(hidden_states)
+        hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quantize_fp4(
+            hidden_states, hidden_states_scale_global, use_ue8m0, True
+        )
+        hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape(
+            num_tokens, hidden_size // 2
+        )
+        hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view(
+            torch.float8_e4m3fn
+        )
+        expected_scale_elems = (num_tokens * hidden_size) // sf_vec_size
+        if hidden_states_scale_linear_fp4.numel() != expected_scale_elems:
             print(
-                f"[INFO] Adjusting FP4 hidden_states_scale from {hidden_states_scale_linear_fp4.numel()} to {expected_scale_elems} elements"
+                f"[WARNING] FP4 hidden_states_scale element count mismatch "
+                f"({hidden_states_scale_linear_fp4.numel()} vs expected {expected_scale_elems}), "
+                f"substituting all-ones scale tensor. This is likely caused by swizzled "
+                f"layout padding when num_tokens ({num_tokens}) is not a multiple of 128."
+            )
+            hidden_states_scale_linear_fp4 = torch.ones(
+                expected_scale_elems, device=device, dtype=torch.float8_e4m3fn
             )
-        hidden_states_scale_linear_fp4 = torch.ones(
-            expected_scale_elems, device=device, dtype=torch.float8_e4m3fn
+        hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape(
+            num_tokens, hidden_size // sf_vec_size
         )
-    hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape(
-        num_tokens, hidden_size // 16
-    )
 
-    # Prepare weights for kernel
-    # For FP4 weights, keep them as uint8 (packed format) - don't convert to float8_e4m3fn
+    # Prepare weights for kernel (packed uint8 format, scale as float8_e4m3fn)
     gemm1_weights_fp4 = gemm1_weights_fp4_bytes.view(torch.uint8).reshape(
         num_experts, 2 * intermediate_size, hidden_size // 2
     )
-    # Scale factors should be viewed as float8_e4m3fn
     gemm1_weights_scale = gemm1_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
-        num_experts, 2 * intermediate_size, hidden_size // 16
+        num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
    )
 
     gemm2_weights_fp4 = gemm2_weights_fp4_bytes.view(torch.uint8).reshape(
         num_experts, hidden_size, intermediate_size // 2
     )
     gemm2_weights_scale = gemm2_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
-        num_experts, hidden_size, intermediate_size // 16
+        num_experts, hidden_size, intermediate_size // sf_vec_size
     )
 
-    # Optional parameters for FP4 (using None for simplicity in benchmarking)
     gemm1_bias = None
     gemm1_alpha = None
     gemm1_beta = None
@@ -598,8 +659,8 @@ def run_fp4_moe(
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
             routing_method_type=routing_method_type,
-            activation_type=activation_type.value,
             do_finalize=True,
+            **_activation_kwarg(trtllm_fp4_block_scale_moe, activation_type),
         )
 
         backend = "trtllm"
@@ -660,6 +721,10 @@
     tflops = calculate_moe_tflops(
         num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
     )
+    input_format_str = {"nvfp4": "nvfp4", "mxfp4_mxfp8": "mxfp8", "mxfp4_bf16": "bf16"}[
+        fp4_mode
+    ]
+    weight_format_str = "mxfp4" if is_mxfp4 else "nvfp4"
     tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
@@ -669,8 +734,8 @@
         median_time,
         input_dtype,
         weight_dtype,
-        input_format="nvfp4",
-        weight_format="nvfp4",
+        input_format=input_format_str,
+        weight_format=weight_format_str,
         routing_logits_dtype=routing_logits.dtype,
         active_experts=int(selected_experts.unique().numel()),
         verbose=args.verbose,
@@ -704,6 +769,7 @@
     cur_res["input_dtype"] = input_dtype
     cur_res["weight_dtype"] = weight_dtype
     cur_res["activation_type"] = args.activation_type.name
+    cur_res["fp4_mode"] = fp4_mode
     res.append(cur_res)
 
     return res
diff --git a/benchmarks/routines/moe_utils.py b/benchmarks/routines/moe_utils.py
index 64d9511145..4705b7e9e3 100644
--- a/benchmarks/routines/moe_utils.py
+++ b/benchmarks/routines/moe_utils.py
@@ -39,6 +39,12 @@
 FLOAT8_E4M3_MAX = 448.0
 FLOAT4_E2M1_MAX = 6.0
 
+SF_VEC_SIZE = {
+    "nvfp4": 16,
+    "mxfp4": 32,
+    "mxfp8": 32,
+}
+
 
 def generate_moe_weights(
     num_experts: int,
@@ -113,7 +119,7 @@
         - block_scale_factors: float8_e4m3fn tensor
         - global_scale_factor: float32 scalar tensor
     """
-    sf_vec_size = 16
+    sf_vec_size = SF_VEC_SIZE["mxfp4" if use_ue8m0 else "nvfp4"]
 
     if global_scale is None:
         global_scale = calculate_fp4_global_scale(tensor)
@@ -537,11 +543,11 @@ def get_effective_bytes(
     dtype: torch.dtype, fmt: Optional[str], is_weight: bool = False
 ) -> float:
     if fmt == "nvfp4":
-        # 1 e4m3 + 1 e4m3 scale per 16-element block
-        return 0.5 + 1 / 16
+        return 0.5 + 1 / SF_VEC_SIZE["nvfp4"]
     elif fmt == "mxfp4":
-        # 1 e2m1 + 1 ue8m0 scale per 32-element block
-        return 0.5 + 1 / 32
+        return 0.5 + 1 / SF_VEC_SIZE["mxfp4"]
+    elif fmt == "mxfp8":
+        return 1.0 + 1 / SF_VEC_SIZE["mxfp8"]
     elif fmt == "fp8":
         # 1 e4m3
         return 1.0