[PERF] FP8 Blockwise GEMM Performance #2146

@vadiklyutiy

Description

FP8 Blockwise GEMM Performance Lower Than Expected on B200

Summary

FP8 blockwise GEMM operations (gemm_fp8_nt_groupwise) are slower than FP16 linear operations in most test configurations on an NVIDIA B200 GPU, which is unexpected given FP8's theoretical performance advantages.

Environment

  • GPU: NVIDIA Blackwell B200
  • FlashInfer: Latest version
  • Precision:
    • FP16: torch.float16 for linear operations
    • FP8: torch.float8_e4m3fn with blockwise quantization (128x128 blocks)
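
For reference, the precision setup above translates into the following tensor and scale shapes. This is a minimal sketch with a hypothetical shape; in the full script below, quantize_fp8 from flashinfer.testing.utils produces the actual scales:

import torch

m, n, k, block = 1024, 2048, 1024, 128        # example Batch, Out, In

a = torch.randn(m, k, dtype=torch.bfloat16)   # activations
b = torch.randn(n, k, dtype=torch.bfloat16)   # weights (NT layout)

# MN scale mode: per-token 1x128 tiles for A, 128x128 blocks for B
a_scale_shape = (k // block, m)               # one scale per token per K-block
b_scale_shape = (k // block, n // block)      # one scale per weight block

a_fp8 = a.to(torch.float8_e4m3fn)             # 1 byte/element storage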

Problem Description

Benchmarking FP16 torch.nn.functional.linear against FP8 gemm_fp8_nt_groupwise with both Cutlass and TRT-LLM backends shows that:

  1. FP16 is faster in most cases: a "Ratio to FP16" below 1.0 indicates FP8 is slower
  2. Cutlass backend: Consistently slower than FP16 (ratios 0.58-0.97)
  3. TRT-LLM backend: Shows better performance in some configurations but is still slower in others

Benchmark Results

Benchmark methodology:

  • Measurement: Median latency over 100 iterations with CUPTI profiling
  • L2 flush: Enabled for both FP16 and FP8
  • Scale mode: MN (per-token activation, blockwise weight)
  • Block size: 128x128

Full Results Table

===============================================================================================================================
SUMMARY TABLE
===============================================================================================================================
  Out    In  Batch |        FP16        |            FP8 Cutlass            |            FP8 TRT-LLM
                   | Median(us)  Std(%) | Median(us)  Std(%)  Ratio to FP16 | Median(us)  Std(%)  Ratio to FP16
-------------------------------------------------------------------------------------------------------------------------------
 2048  1024    128 |       5.02    2.48 |       7.39    2.42           0.68 |       4.90    1.75           1.03
 2048  1024    256 |       5.33    3.38 |       7.52    1.40           0.71 |       4.99    1.96           1.07
 2048  1024    512 |       5.98    2.53 |       7.65    1.26           0.78 |       8.70    1.59           0.69
 2048  1024   1024 |       7.33    1.74 |       8.03    1.40           0.91 |      16.03    0.68           0.46
 2048   128    128 |       3.15    4.35 |       5.47    1.98           0.58 |          -       -              -
 2048   128    256 |       3.46    3.34 |       5.62    1.62           0.62 |          -       -              -
 2048   128    512 |       3.86    2.77 |       5.60    1.56           0.69 |          -       -              -
 2048   128   1024 |       3.90    3.17 |       5.73    1.88           0.68 |          -       -              -
  256  2048    128 |       8.64    5.11 |       8.91    2.63           0.97 |       5.84    3.00           1.48
  256  2048    256 |       5.54    2.51 |       8.90    1.81           0.62 |       5.95    1.52           0.93
  256  2048    512 |       5.76    2.08 |       9.02    1.44           0.64 |       6.13    1.85           0.94
  256  2048   1024 |       6.08    2.23 |       9.31    1.54           0.65 |       6.37    1.62           0.95
 2560  2048    128 |       7.07    1.49 |       9.76    1.46           0.72 |       6.83    1.37           1.04
 2560  2048    256 |       7.39    1.72 |       9.95    1.37           0.74 |      11.23    0.66           0.66
 2560  2048    512 |       8.86    1.60 |      10.22    1.21           0.87 |      16.54    0.84           0.54
 2560  2048   1024 |      11.23    1.39 |      15.74    1.52           0.71 |      27.14    0.76           0.41
 3072  2048    128 |       7.23    1.98 |      10.08    1.46           0.72 |       7.14    1.39           1.01
 3072  2048    256 |       7.62    2.14 |      10.29    1.42           0.74 |      11.87    1.00           0.64
 3072  2048    512 |       9.54    1.54 |      10.59    1.26           0.90 |      17.41    0.83           0.55
 3072  2048   1024 |      13.54    1.18 |      16.40    1.65           0.83 |      31.30    0.67           0.43
===============================================================================================================================

Note: "Ratio to FP16" = FP16_time / FP8_time. Values > 1.0 indicate FP8 is faster; < 1.0 indicates FP8 is slower.

Key Observations

  • Cutlass backend: Consistently slower than FP16 (0.58x - 0.97x)
  • TRT-LLM backend: Mixed results; faster only for small batch sizes (up to 1.48x) and degrading significantly with larger batches (down to 0.41x)
  • TRT-LLM limitation: Does not support K=128 configurations

Expected Behavior

FP8 operations should be faster than FP16 on B200 hardware due to:

  • Higher compute throughput for FP8 operations
  • Reduced memory bandwidth requirements (1 byte vs 2 bytes per element)
  • B200's dedicated FP8 tensor cores

Expected speedup: ~1.5-2x over FP16 for compute-bound workloads.
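
To make the bandwidth point concrete, a rough byte count for the largest benchmarked shape (my own back-of-the-envelope numbers: I assume fp32 scale factors and ignore output traffic):

m, n, k, block = 1024, 3072, 2048, 128        # Batch, Out, In

fp16_bytes = (m * k + n * k) * 2              # A and B at 2 bytes/element
fp8_bytes = (m * k + n * k) * 1               # A and B at 1 byte/element
fp8_bytes += (k // block) * m * 4             # per-token A scales (assumed fp32)
fp8_bytes += (k // block) * (n // block) * 4  # 128x128 B scales (assumed fp32)

print(fp16_bytes / fp8_bytes)                 # ~1.98x less input traffic for FP8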

Actual Behavior

FP8 operations are slower than FP16 in most configurations.

Code to Reproduce

Full benchmark script: bench_linear_fp8_fp16.py
"""
Combined performance benchmark comparing:
- FP16 torch.nn.functional.linear
- FP8 gemm_fp8_nt_groupwise (cutlass backend)
- FP8 gemm_fp8_nt_groupwise (trtllm backend)

Tests various shapes and batch sizes with unified reporting.
"""

import os
import sys

import numpy as np
import torch

# Skip FlashInfer's version check for this ad-hoc benchmark
os.environ['FLASHINFER_DISABLE_VERSION_CHECK'] = '1'

# Make the local FlashInfer checkout importable, then import from it
sys.path.insert(0, '/home/scratch.vgimpelson_ent/flashinfer')
from flashinfer.gemm import gemm_fp8_nt_groupwise
from flashinfer.testing.utils import bench_gpu_time_with_cupti, quantize_fp8


def create_fp8_tensors(m, n, k, scale_major_mode="MN"):
    """Create FP8 tensors with proper scaling for gemm_fp8_nt_groupwise."""
    block_size = 128
    
    # Create input tensors in bfloat16
    a_bf16 = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
    b_bf16 = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)
    
    # Activation scales: one per 1x128 tile (per-token along M, blockwise along K)
    a_scale_shape = (k // block_size, m)
    a_tile_shape = (1, block_size)

    # Weight scales: one per 128x128 block
    b_scale_shape = (k // block_size, n // block_size)
    b_tile_shape = (block_size, block_size)
    
    a_fp8, a_scale = quantize_fp8(a_bf16, a_scale_shape, a_tile_shape, scale_major_mode)
    b_fp8, b_scale = quantize_fp8(b_bf16, b_scale_shape, b_tile_shape, scale_major_mode)
    
    # Create output tensor
    out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
    
    return a_fp8, b_fp8, a_scale, b_scale, out


def benchmark_fp16_linear(batch_size, out_features, in_features):
    """Benchmark FP16 torch.nn.functional.linear."""
    device = 'cuda'
    dtype = torch.float16
    
    # Create tensors
    weight = torch.randn(out_features, in_features, device=device, dtype=dtype)
    bias = torch.randn(out_features, device=device, dtype=dtype)
    input_tensor = torch.randn(batch_size, in_features, device=device, dtype=dtype)
    
    # Warmup
    for _ in range(5):
        _ = torch.nn.functional.linear(input_tensor, weight, bias)
    torch.cuda.synchronize()
    
    # Benchmark function
    def benchmark_fn():
        return torch.nn.functional.linear(input_tensor, weight, bias)
    
    # Run benchmark with CUPTI
    times = bench_gpu_time_with_cupti(
        benchmark_fn,
        l2_flush=True,
        repeat_iters=100,
    )
    
    # Calculate statistics (convert ms to us)
    times_np = np.array(times) * 1000  # ms to us
    median_time = np.median(times_np)
    std_time = np.std(times_np)
    std_pct = (std_time / median_time) * 100 if median_time > 0 else 0
    
    # Calculate FLOPS
    flops = 2 * batch_size * in_features * out_features + batch_size * out_features
    tflops_per_sec = flops / (median_time * 1e-6) / 1e12
    
    return {
        'median_us': median_time,
        'std_pct': std_pct,
        'tflops_per_sec': tflops_per_sec,
    }


def benchmark_fp8_gemm(batch_size, out_features, in_features, backend='cutlass'):
    """Benchmark FP8 gemm_fp8_nt_groupwise."""
    m, n, k = batch_size, out_features, in_features
    scale_major_mode = 'MN'
    
    # Create tensors
    a_fp8, b_fp8, a_scale, b_scale, out = create_fp8_tensors(m, n, k, scale_major_mode)
    
    # Warmup
    for _ in range(5):
        gemm_fp8_nt_groupwise(
            a_fp8, b_fp8, a_scale, b_scale,
            scale_major_mode=scale_major_mode,
            mma_sm=1,
            out=out,
            backend=backend,
        )
    torch.cuda.synchronize()
    
    # Benchmark function
    def benchmark_fn():
        return gemm_fp8_nt_groupwise(
            a_fp8, b_fp8, a_scale, b_scale,
            scale_major_mode=scale_major_mode,
            mma_sm=1,
            out=out,
            backend=backend,
        )
    
    # Run benchmark with CUPTI
    times = bench_gpu_time_with_cupti(
        benchmark_fn,
        l2_flush=True,
        repeat_iters=100,
    )
    
    # Calculate statistics (convert ms to us)
    times_np = np.array(times) * 1000  # ms to us
    median_time = np.median(times_np)
    std_time = np.std(times_np)
    std_pct = (std_time / median_time) * 100 if median_time > 0 else 0
    
    # Calculate FLOPS
    flops = 2 * m * n * k
    tflops_per_sec = flops / (median_time * 1e-6) / 1e12
    
    return {
        'median_us': median_time,
        'std_pct': std_pct,
        'tflops_per_sec': tflops_per_sec,
    }


def run_combined_benchmark():
    """Run combined benchmark for all configurations."""
    
    print("=" * 127)
    print("Combined Performance Benchmark: FP16 Linear vs FP8 GEMM (Cutlass vs TRT-LLM)")
    print("=" * 127)
    print("Configurations:")
    print("  - FP16: torch.nn.functional.linear")
    print("  - FP8 Cutlass: gemm_fp8_nt_groupwise (cutlass backend)")
    print("  - FP8 TRT-LLM: gemm_fp8_nt_groupwise (trtllm backend)")
    print("  - Benchmarking with CUPTI, L2 flush enabled")
    print("=" * 127)
    print()
    
    # Test configurations
    shapes = [
        [2048, 1024],
        [2048, 128],
        [256, 2048],
        [2560, 2048],
        [3072, 2048],
    ]
    
    batch_sizes = [128, 256, 512, 1024]
    
    results = []
    
    for out_features, in_features in shapes:
        for batch_size in batch_sizes:
            print(f"\nTesting: Batch={batch_size}, Out={out_features}, In={in_features}")
            
            result = {
                'batch': batch_size,
                'out': out_features,
                'in': in_features,
            }
            
            # Benchmark FP16
            try:
                print("  Running FP16 linear...")
                fp16_result = benchmark_fp16_linear(batch_size, out_features, in_features)
                result['fp16_median'] = fp16_result['median_us']
                result['fp16_std_pct'] = fp16_result['std_pct']
                result['fp16_tflops'] = fp16_result['tflops_per_sec']
            except Exception as e:
                print(f"  FP16 ERROR: {e}")
                result['fp16_median'] = None
                result['fp16_std_pct'] = None
                result['fp16_tflops'] = None
            
            # Benchmark FP8 Cutlass
            try:
                print("  Running FP8 Cutlass...")
                cutlass_result = benchmark_fp8_gemm(batch_size, out_features, in_features, backend='cutlass')
                result['cutlass_median'] = cutlass_result['median_us']
                result['cutlass_std_pct'] = cutlass_result['std_pct']
                result['cutlass_tflops'] = cutlass_result['tflops_per_sec']
                
                # Calculate ratio
                if result['fp16_median'] is not None:
                    result['cutlass_ratio'] = result['fp16_median'] / result['cutlass_median']
                else:
                    result['cutlass_ratio'] = None
            except Exception as e:
                print(f"  Cutlass ERROR: {e}")
                result['cutlass_median'] = None
                result['cutlass_std_pct'] = None
                result['cutlass_tflops'] = None
                result['cutlass_ratio'] = None
            
            # Benchmark FP8 TRT-LLM
            try:
                print("  Running FP8 TRT-LLM...")
                trtllm_result = benchmark_fp8_gemm(batch_size, out_features, in_features, backend='trtllm')
                result['trtllm_median'] = trtllm_result['median_us']
                result['trtllm_std_pct'] = trtllm_result['std_pct']
                result['trtllm_tflops'] = trtllm_result['tflops_per_sec']
                
                # Calculate ratio
                if result['fp16_median'] is not None:
                    result['trtllm_ratio'] = result['fp16_median'] / result['trtllm_median']
                else:
                    result['trtllm_ratio'] = None
            except Exception as e:
                print(f"  TRT-LLM ERROR: {e}")
                result['trtllm_median'] = None
                result['trtllm_std_pct'] = None
                result['trtllm_tflops'] = None
                result['trtllm_ratio'] = None
            
            results.append(result)
    
    # Print summary table
    print("\n" + "=" * 127)
    print("SUMMARY TABLE")
    print("=" * 127)
    # First header line with column categories - must align "|" with data rows
    # Data format: {5} {5} {6} | {10} {7} | {10} {7} {14} | {10} {7} {14}
    # Section widths after "|": 18 | 33 | 33
    print(f"{'Out':>5} {'In':>5} {'Batch':>6} | "
          f"{'FP16':^18} | "
          f"{'FP8 Cutlass':^33} | "
          f"{'FP8 TRT-LLM':^33}")
    # Second header line with specific metrics - must match exact spacing of data rows
    print(f"{'':>5} {'':>5} {'':>6} | "
          f"{'Median(us)':>10} {'Std(%)':>7} | "
          f"{'Median(us)':>10} {'Std(%)':>7} {'Ratio to FP16':>14} | "
          f"{'Median(us)':>10} {'Std(%)':>7} {'Ratio to FP16':>14}")
    print("-" * 127)
    
    for r in results:
        # FP16 values
        fp16_med = f"{r['fp16_median']:10.2f}" if r['fp16_median'] is not None else "         -"
        fp16_std = f"{r['fp16_std_pct']:7.2f}" if r['fp16_std_pct'] is not None else "      -"
        
        # Cutlass values
        cutlass_med = f"{r['cutlass_median']:10.2f}" if r['cutlass_median'] is not None else "         -"
        cutlass_std = f"{r['cutlass_std_pct']:7.2f}" if r['cutlass_std_pct'] is not None else "      -"
        cutlass_ratio = f"{r['cutlass_ratio']:14.2f}" if r['cutlass_ratio'] is not None else "             -"
        
        # TRT-LLM values
        trtllm_med = f"{r['trtllm_median']:10.2f}" if r['trtllm_median'] is not None else "         -"
        trtllm_std = f"{r['trtllm_std_pct']:7.2f}" if r['trtllm_std_pct'] is not None else "      -"
        trtllm_ratio = f"{r['trtllm_ratio']:14.2f}" if r['trtllm_ratio'] is not None else "             -"
        
        print(f"{r['out']:5d} {r['in']:5d} {r['batch']:6d} | "
              f"{fp16_med} {fp16_std} | "
              f"{cutlass_med} {cutlass_std} {cutlass_ratio} | "
              f"{trtllm_med} {trtllm_std} {trtllm_ratio}")
    
    print("=" * 127)
    print("Notes:")
    print("  - Median: Median execution time in microseconds")
    print("  - Std(%): Standard deviation as percentage of median")
    print("  - Ratio to FP16: FP16 time / FP8 time (higher is better for FP8)")
    print("  - '-' indicates unsupported configuration")
    print("=" * 127)
    
    return results


if __name__ == "__main__":
    results = run_combined_benchmark()

Additional Information

  • Benchmark uses CUPTI profiling with L2 cache flushing for accurate measurements
  • All measurements are median of 100 iterations
  • FP16 uses PyTorch's default cuBLAS implementation
  • FP8 uses MN scale mode (per-token activation, blockwise weight quantization)
  • The shapes were taken from the Qwen3-Next model, but I expect similar behavior for other shapes.

Any help improving FP8 performance would be greatly appreciated!
