Skip to content
221 changes: 181 additions & 40 deletions benchmarks/bench_mxfp4_quantize_backend_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
Benchmark: MXFP4 Quantization Backend Comparison (CUDA vs CuTe-DSL)

Compares the performance of CUDA and CuTe-DSL backends for MXFP4 quantization
across different M and K dimensions. Each configuration is verified for
correctness before timing. Generates heatmaps showing relative performance
(speedup of CuTe-DSL over CUDA).
across different M and K dimensions. Supports both swizzled 128x4 and linear
scale factor layouts. Each configuration is verified for correctness before
timing. Generates heatmaps showing relative performance (speedup of CuTe-DSL
over CUDA).

Can also measure achieved memory bandwidth in TB/s for the CuTe-DSL backend.

Expand Down Expand Up @@ -55,6 +56,7 @@ def verify_mxfp4_correctness(
m: int,
k: int,
dtype: torch.dtype,
is_sf_swizzled_layout: bool,
) -> Tuple[bool, str, float, float]:
"""
Verify that both backends produce correct outputs via roundtrip test.
Expand All @@ -63,19 +65,51 @@ def verify_mxfp4_correctness(
Tuple of (success, message, quant_match_pct, scale_match_pct)
On failure, quant_match_pct and scale_match_pct are 0.0
"""
import flashinfer
from flashinfer.quantization.fp4_quantization import (
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
)

torch.manual_seed(42)
x = torch.randn(m, k, device="cuda", dtype=dtype)
global_sf = ((448 * 6) / x.float().abs().nan_to_num().max()).cuda()

try:
# Test CUDA backend
quant_cuda, scale_cuda = flashinfer.mxfp4_quantize(x, backend="cuda")
dq_cuda = flashinfer.mxfp4_dequantize(quant_cuda, scale_cuda)
quant_cuda, scale_cuda = fp4_quantize(
x,
global_sf,
sf_vec_size=32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=is_sf_swizzled_layout,
backend="cuda",
)
dq_cuda = e2m1_and_ufp8sf_scale_to_float(
quant_cuda.cpu().view(torch.uint8),
scale_cuda.cpu().view(torch.uint8).reshape(-1),
torch.tensor([1.0]),
32,
0,
is_sf_swizzled_layout,
)

# Test CuTe-DSL backend
quant_cute, scale_cute = flashinfer.mxfp4_quantize(x, backend="cute-dsl")
dq_cute = flashinfer.mxfp4_dequantize(quant_cute, scale_cute)
quant_cute, scale_cute = fp4_quantize(
x,
global_sf,
sf_vec_size=32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=is_sf_swizzled_layout,
backend="cute-dsl",
)
dq_cute = e2m1_and_ufp8sf_scale_to_float(
quant_cute.cpu().view(torch.uint8),
scale_cute.cpu().view(torch.uint8).reshape(-1),
torch.tensor([1.0]),
32,
0,
is_sf_swizzled_layout,
)

# Check shapes match
if quant_cuda.shape != quant_cute.shape:
Expand Down Expand Up @@ -131,6 +165,7 @@ def bench_mxfp4_quantize(
m: int,
k: int,
dtype: torch.dtype,
is_sf_swizzled_layout: bool,
backend: str,
) -> float:
"""
Expand All @@ -140,22 +175,38 @@ def bench_mxfp4_quantize(
m: Number of rows
k: Number of columns
dtype: Input dtype (torch.float16 or torch.bfloat16)
is_sf_swizzled_layout: Whether to use swizzled scale factor layout
backend: "cuda" or "cute-dsl"

Returns:
Median execution time in milliseconds
"""
import flashinfer
from flashinfer.quantization.fp4_quantization import fp4_quantize

# Create input tensor
x = torch.randn(m, k, device="cuda", dtype=dtype)

# Warmup and get output shapes
_ = flashinfer.mxfp4_quantize(x, backend=backend)
global_sf = ((448 * 6) / x.float().abs().nan_to_num().max()).cuda()

# Warmup
_ = fp4_quantize(
x,
global_sf,
sf_vec_size=32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=is_sf_swizzled_layout,
backend=backend,
)

# Benchmark
def run_kernel():
flashinfer.mxfp4_quantize(x, backend=backend)
fp4_quantize(
x,
global_sf,
sf_vec_size=32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=is_sf_swizzled_layout,
backend=backend,
)

times = bench_gpu_time(
fn=run_kernel,
Expand Down Expand Up @@ -210,6 +261,7 @@ def run_bandwidth_sweep(
m_values: List[int],
k_values: List[int],
dtype: torch.dtype,
is_sf_swizzled_layout: bool,
) -> Dict[Tuple[int, int], float]:
"""
Run bandwidth benchmark sweep for CuTe-DSL backend only.
Expand All @@ -222,7 +274,10 @@ def run_bandwidth_sweep(
total = len(m_values) * len(k_values)
current = 0

print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype} (CuTe-DSL bandwidth)")
layout_str = "swizzled" if is_sf_swizzled_layout else "linear"
print(
f"\nBenchmarking MXFP4 {layout_str} layout, dtype={dtype} (CuTe-DSL bandwidth)"
)
print("=" * 60)

for m in m_values:
Expand All @@ -231,7 +286,9 @@ def run_bandwidth_sweep(
print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... ", end="", flush=True)

# Benchmark CuTe-DSL backend only
time_ms = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl")
time_ms = bench_mxfp4_quantize(
m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl"
)

# Compute bandwidth
bandwidth = compute_bandwidth_tb_per_sec(m, k, dtype, time_ms)
Expand All @@ -246,6 +303,7 @@ def run_benchmark_sweep(
m_values: List[int],
k_values: List[int],
dtype: torch.dtype,
is_sf_swizzled_layout: bool,
) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]:
"""
Run benchmark sweep for both backends with inline correctness verification.
Expand All @@ -254,6 +312,7 @@ def run_benchmark_sweep(
m_values: List of M dimensions to benchmark
k_values: List of K dimensions to benchmark
dtype: Input dtype
is_sf_swizzled_layout: Whether to use swizzled scale factor layout

Returns:
Tuple of (cuda_times, cute_dsl_times) dictionaries
Expand All @@ -265,7 +324,8 @@ def run_benchmark_sweep(
total = len(m_values) * len(k_values)
current = 0

print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype}")
layout_str = "swizzled" if is_sf_swizzled_layout else "linear"
print(f"\nBenchmarking MXFP4 {layout_str} layout, dtype={dtype}")
print("=" * 95)
print(
f"{'Progress':<12} {'M':>5} {'K':>5} | "
Expand All @@ -285,19 +345,23 @@ def run_benchmark_sweep(

# Verify correctness first
success, verify_msg, quant_match, scale_match = verify_mxfp4_correctness(
m, k, dtype
m, k, dtype, is_sf_swizzled_layout
)
if not success:
failures.append((m, k, verify_msg))
print(f"[{current:3d}/{total}] {m:5d} {k:5d} | FAIL: {verify_msg}")
continue

# Benchmark CUDA backend
cuda_time = bench_mxfp4_quantize(m, k, dtype, backend="cuda")
cuda_time = bench_mxfp4_quantize(
m, k, dtype, is_sf_swizzled_layout, backend="cuda"
)
cuda_times[(m, k)] = cuda_time

# Benchmark CuTe-DSL backend
cute_dsl_time = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl")
cute_dsl_time = bench_mxfp4_quantize(
m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl"
)
cute_dsl_times[(m, k)] = cute_dsl_time

# Compute speedup
Expand Down Expand Up @@ -497,10 +561,11 @@ def print_bandwidth_summary_table(
m_values: List[int],
k_values: List[int],
bandwidth_results: Dict[Tuple[int, int], float],
layout_name: str = "Swizzled Layout",
):
"""Print a summary table of bandwidth results."""
print(f"\n{'=' * 80}")
print("Bandwidth Summary: MXFP4 Swizzled Layout (TB/s)")
print(f"Bandwidth Summary: MXFP4 {layout_name} (TB/s)")
print(f"{'=' * 80}")

# Header
Expand Down Expand Up @@ -537,10 +602,11 @@ def print_summary_table(
k_values: List[int],
cuda_times: Dict[Tuple[int, int], float],
cute_dsl_times: Dict[Tuple[int, int], float],
layout_name: str = "Swizzled Layout",
):
"""Print a summary table of results."""
print(f"\n{'=' * 80}")
print("Summary: MXFP4 Swizzled Layout (Speedup: CUDA time / CuTe-DSL time)")
print(f"Summary: MXFP4 {layout_name} (Speedup: CUDA time / CuTe-DSL time)")
print(f"{'=' * 80}")

# Header
Expand Down Expand Up @@ -618,12 +684,20 @@ def main():
print(f"Data type: {dtype}")

# Define sweep ranges (powers of 2 + common transformer hidden dimensions)
# Note: K must be a multiple of 128 for MXFP4 swizzled layout because:
# - SF vec size is 32, so K/32 gives number of SF blocks per row
# - Swizzled layout pads SF blocks to multiples of 4
# - The CUDA backend's reshape assumes unpadded SF dimensions
# So K/32 must already be a multiple of 4, i.e., K must be multiple of 128
# K constraints:
# - Linear layout: K must be a multiple of 32 (SF_VEC_SIZE)
# - Swizzled layout: K must be a multiple of 128 because K/32 (SF blocks
# per row) must be a multiple of 4 for the swizzled padding to work
# correctly with the CUDA backend's reshape
# We use K values that satisfy both constraints (multiples of 128)
m_values = [
1,
2,
4,
8,
16,
32,
64,
128,
256,
384,
Expand Down Expand Up @@ -667,30 +741,97 @@ def main():
print("BANDWIDTH MEASUREMENT MODE (CuTe-DSL only)")
print("=" * 80)

bandwidth_results = run_bandwidth_sweep(m_values, k_values, dtype)
print_bandwidth_summary_table(m_values, k_values, bandwidth_results)
# Benchmark linear layout
print("\n" + "=" * 80)
print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT - BANDWIDTH")
print("=" * 80)

# Generate bandwidth heatmap
bandwidth_linear = run_bandwidth_sweep(
m_values, k_values, dtype, is_sf_swizzled_layout=False
)
print_bandwidth_summary_table(
m_values, k_values, bandwidth_linear, "Linear Layout"
)
create_bandwidth_heatmap(
m_values,
k_values,
bandwidth_results,
f"MXFP4 Quantization CuTe-DSL Bandwidth ({args.dtype})",
f"{args.output_prefix}_bandwidth_{args.dtype}.png",
bandwidth_linear,
f"MXFP4 Quantization Bandwidth (CuTe-DSL) - Linear Layout - {args.dtype}",
f"{args.output_prefix}_bandwidth_linear_{args.dtype}.png",
)

# Benchmark swizzled layout
print("\n" + "=" * 80)
print("BENCHMARKING SWIZZLED LAYOUT - BANDWIDTH")
print("=" * 80)

bandwidth_swizzled = run_bandwidth_sweep(
m_values, k_values, dtype, is_sf_swizzled_layout=True
)
print_bandwidth_summary_table(
m_values, k_values, bandwidth_swizzled, "Swizzled Layout"
)
create_bandwidth_heatmap(
m_values,
k_values,
bandwidth_swizzled,
f"MXFP4 Quantization Bandwidth (CuTe-DSL) - Swizzled Layout - {args.dtype}",
f"{args.output_prefix}_bandwidth_swizzled_{args.dtype}.png",
)
else:
# Run comparison benchmark (with inline correctness verification)
cuda_times, cute_dsl_times = run_benchmark_sweep(m_values, k_values, dtype)
print_summary_table(m_values, k_values, cuda_times, cute_dsl_times)
# Speedup comparison mode: CUDA vs CuTe-DSL
# Benchmark linear layout (non-swizzled)
print("\n" + "=" * 80)
print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT")
print("=" * 80)

cuda_times_linear, cute_dsl_times_linear = run_benchmark_sweep(
m_values,
k_values,
dtype,
is_sf_swizzled_layout=False,
)
print_summary_table(
m_values,
k_values,
cuda_times_linear,
cute_dsl_times_linear,
"Linear Layout",
)
create_heatmap(
m_values,
k_values,
cuda_times_linear,
cute_dsl_times_linear,
f"MXFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Linear Layout - {args.dtype}",
f"{args.output_prefix}_comparison_linear_{args.dtype}.png",
)

# Benchmark swizzled layout
print("\n" + "=" * 80)
print("BENCHMARKING SWIZZLED LAYOUT")
print("=" * 80)

# Generate heatmap
cuda_times_swizzled, cute_dsl_times_swizzled = run_benchmark_sweep(
m_values,
k_values,
dtype,
is_sf_swizzled_layout=True,
)
print_summary_table(
m_values,
k_values,
cuda_times_swizzled,
cute_dsl_times_swizzled,
"Swizzled Layout",
)
create_heatmap(
m_values,
k_values,
cuda_times,
cute_dsl_times,
f"MXFP4 Quantization Backend Comparison ({args.dtype})",
f"{args.output_prefix}_comparison_{args.dtype}.png",
cuda_times_swizzled,
cute_dsl_times_swizzled,
f"MXFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Swizzled Layout - {args.dtype}",
f"{args.output_prefix}_comparison_swizzled_{args.dtype}.png",
)

print("\n" + "=" * 80)
Expand Down
Loading
Loading