Merged
Changes from 4 commits
3 changes: 2 additions & 1 deletion benchmarks/routines/flashinfer_benchmark_utils.py
@@ -54,7 +54,8 @@
"use_routing_bias",
"use_routing_scales_on_input",
"weight_dtype",
"gated_act",
"activation_type",
"fp4_mode",
# CUTLASS fused MoE specific
"cutlass_variant",
"quantized_input",
145 changes: 105 additions & 40 deletions benchmarks/routines/moe.py
@@ -1,11 +1,18 @@
import inspect
from collections import defaultdict
from typing import Optional

import numpy as np
import torch

import flashinfer
from flashinfer import ActivationType

try:
from flashinfer import ActivationType
except ImportError:
# ActivationType was not exported from the top-level package until 0.6.3
from flashinfer.fused_moe.core import ActivationType

from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
trtllm_fp4_block_scale_moe,
@@ -15,7 +22,7 @@
fused_topk_deepseek,
)
from flashinfer.fused_moe.core import RoutingMethodType
from flashinfer import fp4_quantize
from flashinfer import fp4_quantize, mxfp8_quantize
from flashinfer.testing.utils import (
bench_gpu_time,
)
@@ -43,6 +50,29 @@
FLOAT4_E2M1_MAX,
)

# Before 0.6.3, MoE APIs used "gated_act_type" (SwiGlu=0, GeGlu=1) instead of
# "activation_type" (Swiglu=3, Geglu=4). Some prior APIs omit the parameter entirely.
_ACTIVATION_TO_GATED_ACT = {
ActivationType.Swiglu: 0,
ActivationType.Geglu: 1,
}


def _activation_kwarg(fn, activation_type: ActivationType) -> dict:
"""Return the correct activation keyword argument for *fn* in the installed version."""
sig = inspect.signature(fn)
if "activation_type" in sig.parameters:
return {"activation_type": activation_type.value}
if "gated_act_type" in sig.parameters:
if activation_type not in _ACTIVATION_TO_GATED_ACT:
raise ValueError(
f"Activation type {activation_type.name} is not supported by the "
f"installed flashinfer version (pre-0.6.3 only supports "
f"{[k.name for k in _ACTIVATION_TO_GATED_ACT]})"
)
return {"gated_act_type": _ACTIVATION_TO_GATED_ACT[activation_type]}
return {}
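# Editorial note, not part of this diff: the shim above keeps call sites
# identical across flashinfer versions. For example (using the names already
# imported at the top of this file):
#
#     kwargs = _activation_kwarg(trtllm_fp4_block_scale_moe, ActivationType.Swiglu)
#     # -> {"activation_type": 3} on >= 0.6.3, {"gated_act_type": 0} on older
#     #    releases, or {} for APIs that take no activation argument at all.
#     output = trtllm_fp4_block_scale_moe(..., **kwargs)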


def run_moe_test(args):
"""
@@ -187,6 +217,20 @@ def parse_moe_args(line, parser):
"Enable autotuner warmup for supported routines (trtllm_fp4_block_scale_moe and cutlass_fused_moe)."
),
)
parser.add_argument(
"--fp4_mode",
type=str,
required=False,
default="nvfp4",
choices=["nvfp4", "mxfp4_mxfp8", "mxfp4_bf16"],
help=(
"FP4 quantization mode for trtllm_fp4_block_scale_moe: "
"nvfp4 (NvFP4 weights + NvFP4 hidden states, block_size=16), "
Collaborator:
If these block_size are hardcoded numbers, we can create a utility function or a dict to store the block_size, e.g.:

sf_vec_size = {
    "nvfp4": 16,
    "mxfp4": 32,
}

Collaborator Author:
Addressed in latest commit d149911
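
As an editorial illustration of the suggestion above (not the actual change in commit d149911), a minimal sketch of such a shared lookup; the helper name get_sf_vec_size and its placement are hypothetical:

# Hypothetical shared lookup so the 16/32 block sizes are defined once
# instead of being hardcoded in each routine.
SF_VEC_SIZE = {
    "nvfp4": 16,  # NvFP4: one e4m3 scale per 16 elements
    "mxfp4": 32,  # MXFP4: one ue8m0 scale per 32 elements
}

def get_sf_vec_size(weight_format: str) -> int:
    # Fail loudly on unknown formats rather than silently picking a block size.
    if weight_format not in SF_VEC_SIZE:
        raise ValueError(f"Unknown FP4 weight format: {weight_format!r}")
    return SF_VEC_SIZE[weight_format]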

"mxfp4_mxfp8 (MXFP4 weights + MXFP8 hidden states, block_size=32), "
"mxfp4_bf16 (MXFP4 weights + BF16 hidden states, block_size=32). "
"Default: nvfp4"
),
)

# CUTLASS fused MoE specific
parser.add_argument(
Expand Down Expand Up @@ -482,11 +526,16 @@ def testTrtllmFp4BlockScaleMoe(args):
routed_scaling_factor=routed_scaling_factor,
)

# For FP4, we need to properly quantize weights and create scales
use_ue8m0 = False
# Determine FP4 quantization mode
fp4_mode = getattr(args, "fp4_mode", "nvfp4")
is_mxfp4 = fp4_mode in ("mxfp4_mxfp8", "mxfp4_bf16")
use_ue8m0 = is_mxfp4
sf_vec_size = 32 if is_mxfp4 else 16

# Calculate global scale factor for hidden states
hidden_states_scale_global = calculate_fp4_global_scale(hidden_states)
if args.verbose >= 1:
print(
f"[INFO] FP4 mode: {fp4_mode} (use_ue8m0={use_ue8m0}, sf_vec_size={sf_vec_size})"
)

# Quantize weights using proper FP4 quantization
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
Expand All @@ -496,52 +545,63 @@ def testTrtllmFp4BlockScaleMoe(args):
quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
)

# Quantize hidden states
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quantize_fp4(
hidden_states, hidden_states_scale_global, use_ue8m0, True
)

# Reshape hidden states for the kernel (pack 2 FP4 values into 1 byte)
# Keep as uint8 format for FP4 packed data
hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape(
hidden_states.shape[0], hidden_states.shape[1] // 2
)
# Hidden-states scale for FP4 must be 2D: [num_tokens, hidden_size // 16]
hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view(
torch.float8_e4m3fn
)
# Ensure expected shape (16 elements per hidden value for NvFP4)
expected_scale_elems = (num_tokens * hidden_size) // 16
if hidden_states_scale_linear_fp4.numel() != expected_scale_elems:
if args.verbose >= 1:
# Prepare hidden states and scale based on fp4_mode
if fp4_mode == "mxfp4_bf16":
hidden_states_fp4 = hidden_states.to(torch.bfloat16)
hidden_states_scale_linear_fp4 = None
elif fp4_mode == "mxfp4_mxfp8":
if num_tokens % 128 != 0:
raise ValueError(
f"mxfp4_mxfp8 mode requires num_tokens to be a multiple of 128 "
f"(got {num_tokens}) because mxfp8_quantize with swizzled scale "
f"layout pads rows to 128-element boundaries."
)
hs_quant, hs_scale = mxfp8_quantize(hidden_states, True)
hidden_states_fp4 = hs_quant
hidden_states_scale_linear_fp4 = hs_scale.view(torch.float8_e4m3fn).reshape(
num_tokens, -1
)
else:
hidden_states_scale_global = calculate_fp4_global_scale(hidden_states)
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quantize_fp4(
hidden_states, hidden_states_scale_global, use_ue8m0, True
)
hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape(
num_tokens, hidden_size // 2
)
hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view(
torch.float8_e4m3fn
)
expected_scale_elems = (num_tokens * hidden_size) // sf_vec_size
if hidden_states_scale_linear_fp4.numel() != expected_scale_elems:
print(
f"[INFO] Adjusting FP4 hidden_states_scale from {hidden_states_scale_linear_fp4.numel()} to {expected_scale_elems} elements"
f"[WARNING] FP4 hidden_states_scale element count mismatch "
f"({hidden_states_scale_linear_fp4.numel()} vs expected {expected_scale_elems}), "
f"substituting all-ones scale tensor. This is likely caused by swizzled "
f"layout padding when num_tokens ({num_tokens}) is not a multiple of 128."
)
hidden_states_scale_linear_fp4 = torch.ones(
expected_scale_elems, device=device, dtype=torch.float8_e4m3fn
)
hidden_states_scale_linear_fp4 = torch.ones(
expected_scale_elems, device=device, dtype=torch.float8_e4m3fn
hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape(
num_tokens, hidden_size // sf_vec_size
)
hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape(
num_tokens, hidden_size // 16
)

# Prepare weights for kernel
# For FP4 weights, keep them as uint8 (packed format) - don't convert to float8_e4m3fn
# Prepare weights for kernel (packed uint8 format, scale as float8_e4m3fn)
gemm1_weights_fp4 = gemm1_weights_fp4_bytes.view(torch.uint8).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
)
# Scale factors should be viewed as float8_e4m3fn
gemm1_weights_scale = gemm1_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
)

gemm2_weights_fp4 = gemm2_weights_fp4_bytes.view(torch.uint8).reshape(
num_experts, hidden_size, intermediate_size // 2
)
gemm2_weights_scale = gemm2_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 16
num_experts, hidden_size, intermediate_size // sf_vec_size
)

# Optional parameters for FP4 (using None for simplicity in benchmarking)
gemm1_bias = None
gemm1_alpha = None
gemm1_beta = None
@@ -598,8 +658,8 @@ def run_fp4_moe(
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
do_finalize=True,
**_activation_kwarg(trtllm_fp4_block_scale_moe, activation_type),
)

backend = "trtllm"
@@ -660,6 +720,10 @@ def run_fp4_moe(
tflops = calculate_moe_tflops(
num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time
)
input_format_str = {"nvfp4": "nvfp4", "mxfp4_mxfp8": "mxfp8", "mxfp4_bf16": "bf16"}[
fp4_mode
]
weight_format_str = "mxfp4" if is_mxfp4 else "nvfp4"
tb_per_sec = calculate_moe_kernel_bandwidth(
num_tokens,
hidden_size,
@@ -669,8 +733,8 @@ def run_fp4_moe(
median_time,
input_dtype,
weight_dtype,
input_format="nvfp4",
weight_format="nvfp4",
input_format=input_format_str,
weight_format=weight_format_str,
routing_logits_dtype=routing_logits.dtype,
active_experts=int(selected_experts.unique().numel()),
verbose=args.verbose,
@@ -704,6 +768,7 @@ def run_fp4_moe(
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["activation_type"] = args.activation_type.name
cur_res["fp4_mode"] = fp4_mode
res.append(cur_res)

return res
@@ -1482,7 +1547,7 @@ def run_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
**_activation_kwarg(trtllm_fp8_per_tensor_scale_moe, activation_type),
)

# Benchmark timing
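Taken together, the moe.py changes map each --fp4_mode choice onto a hidden-state format, a weight format, and a scale-factor block size. The following condensed sketch of that mapping is editorial; the dataclass and the FP4_MODE_CONFIGS name are illustrative, not identifiers from the PR:

from dataclasses import dataclass

@dataclass(frozen=True)
class Fp4ModeConfig:
    input_format: str    # how hidden states are fed to the kernel
    weight_format: str   # how expert weights are quantized
    sf_vec_size: int     # elements covered by one block scale factor
    use_ue8m0: bool      # ue8m0 block scales (MX formats) vs e4m3 (NvFP4)

# Mirrors the branches in testTrtllmFp4BlockScaleMoe and the
# input_format_str / weight_format_str strings used for bandwidth accounting.
FP4_MODE_CONFIGS = {
    "nvfp4": Fp4ModeConfig("nvfp4", "nvfp4", 16, False),
    "mxfp4_mxfp8": Fp4ModeConfig("mxfp8", "mxfp4", 32, True),
    "mxfp4_bf16": Fp4ModeConfig("bf16", "mxfp4", 32, True),
}
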
5 changes: 4 additions & 1 deletion benchmarks/routines/moe_utils.py
@@ -113,7 +113,7 @@ def quantize_fp4(
- block_scale_factors: float8_e4m3fn tensor
- global_scale_factor: float32 scalar tensor
"""
sf_vec_size = 16
sf_vec_size = 32 if use_ue8m0 else 16

if global_scale is None:
global_scale = calculate_fp4_global_scale(tensor)
@@ -542,6 +542,9 @@ def get_effective_bytes(
elif fmt == "mxfp4":
# 1 e2m1 + 1 ue8m0 scale per 32-element block
return 0.5 + 1 / 32
elif fmt == "mxfp8":
# 1 e4m3 + 1 ue8m0 scale per 32-element block
return 1.0 + 1 / 32
Collaborator:
ditto

Collaborator Author:
Ditto; addressed in latest commit d149911

elif fmt == "fp8":
# 1 e4m3
return 1.0
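
As a sanity check on the per-element byte counts above, a small standalone sketch of the same arithmetic. effective_bytes_per_element is a simplified stand-in for get_effective_bytes, and the nvfp4 branch (not shown in this diff) is assumed by analogy with the documented 16-element block size:

def effective_bytes_per_element(fmt: str) -> float:
    # Payload bytes plus amortized block-scale bytes per logical element.
    if fmt == "nvfp4":
        return 0.5 + 1 / 16  # 4-bit e2m1 + one e4m3 scale per 16 elements (assumed)
    if fmt == "mxfp4":
        return 0.5 + 1 / 32  # 4-bit e2m1 + one ue8m0 scale per 32 elements
    if fmt == "mxfp8":
        return 1.0 + 1 / 32  # 8-bit e4m3 + one ue8m0 scale per 32 elements
    if fmt == "fp8":
        return 1.0           # 8-bit e4m3; per-tensor scale is negligible
    raise ValueError(f"Unknown format: {fmt!r}")

# Example: a 4096 x 4096 mxfp8 activation tensor moves about
# 4096 * 4096 * (1 + 1/32) bytes, roughly 17.3 MB, ~3% more than plain fp8.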