
GPU performance model significantly overestimates INT8 GEMM speedup for small/medium shapes #40680

@swjng


Summary

The GPU performance model (gpu_performance_model_base.cc) significantly overestimates INT8 GEMM speedup over FP32, especially for small and medium tensor shapes. Measured speedups on RTX 3080 show up to 576% prediction error at medium shapes, and INT8 being slower than FP32 at small shapes — a regime the model cannot predict at all.

Measured Data (RTX 3080, PyTorch torch._int_mm, 1330 shapes)

| Shape (M=K=N) | XLA Roofline Prediction | Measured Speedup | Prediction Error |
|---------------|-------------------------|------------------|------------------|
| 64   | 2.01x | 0.98x (INT8 slower!)      | 105% |
| 128  | 2.03x | 0.87x                     | 133% |
| 512  | 3.91x | 0.58x (INT8 much slower!) | 576% |
| 1024 | 3.95x | 2.09x                     | 89%  |
| 2048 | 3.97x | 2.30x                     | 73%  |
| 4096 | 3.99x | 2.75x                     | 45%  |

Mean absolute error across 1330 shapes: 147%
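For reference, the per-row errors in the table are consistent with defining error as predicted/measured − 1 (the 576% row presumably comes from unrounded measurements); a quick sanity check:

```python
# (prediction, measurement) pairs taken from the table above.
rows = [(2.01, 0.98), (2.03, 0.87), (3.91, 0.58),
        (3.95, 2.09), (3.97, 2.30), (3.99, 2.75)]
for pred, meas in rows:
    err = pred / meas - 1          # relative overestimation
    print(f"{pred}x predicted vs {meas}x measured -> {err:.0%}")
```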

Aggregated by FLOPs magnitude

| FLOPs range | # shapes | Avg measured speedup | XLA would predict |
|-------------|----------|----------------------|-------------------|
| <10M     | 149 | 0.93x (slower)    | ~2.0x |
| 10M-1B   | 708 | 0.93x (breakeven) | ~2-4x |
| 1B-10B   | 356 | 1.66x             | ~4x   |
| 10B-100B | 113 | 2.48x             | ~4x   |
| >100B    | 4   | 2.76x             | ~4x   |

Root Causes

  1. Missing kernel launch + quantization overhead: At small shapes, the fixed kernel-launch cost plus the INT8 quantize/dequantize steps dominate any compute savings. The model only considers peak TOPS and memory bandwidth.

  2. Peak utilization never reached: Even at large shapes (4096³), measured INT8 speedup is 2.75x vs the theoretical 4x peak ratio. The model assumes near-peak utilization.

  3. The code acknowledges this: gpu_hlo_cost_analysis.cc contains the comment: "this is technically incorrect if the element type of this gemm is an integer type, because in that case no floating point operations are involved at all!"
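For context, predictions of the shape seen in the table fall out of a plain roofline estimate. A minimal sketch (function names and peak numbers are illustrative, not XLA's actual code or exact RTX 3080 specs; the key assumption is the model's 4x INT8:FP32 peak-throughput ratio) reproduces the ~2x small-shape and ~4x large-shape predictions, and shows why a roofline alone can never predict INT8 being slower than FP32:

```python
def roofline_time(ops, bytes_moved, peak_ops, peak_bw):
    # Roofline: runtime is whichever is slower, compute or memory traffic.
    return max(ops / peak_ops, bytes_moved / peak_bw)

def predicted_int8_speedup(n, fp32_peak=30e12, bw=760e9):
    # Illustrative peaks, not exact specs; the cost model assumes the
    # INT8 tensor-core peak is 4x the FP32 peak.
    int8_peak = 4 * fp32_peak
    ops = 2 * n**3                           # FLOPs for an n x n x n GEMM
    fp32_bytes = 3 * n * n * 4               # float32 A, B in; float32 C out
    int8_bytes = 2 * n * n * 1 + n * n * 4   # int8 inputs; int32 output
    t_fp32 = roofline_time(ops, fp32_bytes, fp32_peak, bw)
    t_int8 = roofline_time(ops, int8_bytes, int8_peak, bw)
    return t_fp32 / t_int8

print(predicted_int8_speedup(64))    # 2.0: memory-bound, bytes halve
print(predicted_int8_speedup(4096))  # 4.0: compute-bound, peak ratio
```

Because both regimes yield a ratio of at least 2x, measured speedups below 1x are structurally out of reach for this model.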

Impact

This misprediction affects any XLA pass that uses the cost model to decide whether to lower an op to INT8 (or to evaluate INT8 graph alternatives). For mixed-precision optimization, the model leads to the wrong quantization decision for shapes below ~10B FLOPs.

Suggested Improvement

A fitted sigmoid model captures the shape-dependent behavior well (R²=0.73 on 1330 measurements):

INT8_speedup(FLOPs) = 1.87 × sigmoid(2.58 × (log10(FLOPs) - 9.6)) + 0.88

This could be implemented as a single shape-dependent correction in CalculatePeakMatrixOpsPerNs, or as a post-hoc adjustment in EstimateRunTimeForInstruction for integer-typed GEMMs.
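The fitted correction is cheap to evaluate. A sketch (Python for illustration; the actual change would live in the C++ model, and the function name is hypothetical):

```python
import math

def fitted_int8_speedup(flops):
    # Sigmoid fit to the 1330 RTX 3080 measurements (R^2 = 0.73).
    x = 2.58 * (math.log10(flops) - 9.6)
    return 1.87 / (1.0 + math.exp(-x)) + 0.88

# At the fit's midpoint (log10(FLOPs) = 9.6): 1.87 * 0.5 + 0.88 ≈ 1.815
print(fitted_int8_speedup(10**9.6))
# A 4096^3 GEMM (2 * 4096^3 ≈ 1.4e11 FLOPs) predicts ≈ 2.7x,
# close to the measured 2.75x.
print(fitted_int8_speedup(2 * 4096**3))
```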

Reproduction

Benchmark script (PyTorch, any CUDA GPU):

```python
import torch, time

def bench(fn, warmup=30, repeat=100):
    # Warm up to exclude autotuning/caching effects, then time each run
    # with explicit synchronization and return the median latency in ms.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(repeat):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        times.append((time.perf_counter() - t0) * 1000)
    times.sort()
    return times[len(times) // 2]

for N in [64, 128, 256, 512, 1024, 2048, 4096]:
    A = torch.randn(N, N, device='cuda')
    B = torch.randn(N, N, device='cuda')
    lat_fp32 = bench(lambda: torch.mm(A, B))
    # Symmetric per-tensor quantization to int8.
    A8 = (A / A.abs().max() * 127).round().clamp(-128, 127).to(torch.int8)
    B8 = (B / B.abs().max() * 127).round().clamp(-128, 127).to(torch.int8)
    lat_int8 = bench(lambda: torch._int_mm(A8, B8))
    print(f"{N}x{N}: FP32={lat_fp32:.4f}ms INT8={lat_int8:.4f}ms "
          f"speedup={lat_fp32/lat_int8:.2f}x")
```

Hardware: RTX 3080 (SM 8.6), CUDA 13.1, PyTorch 2.5.1
