Summary
The GPU performance model (gpu_performance_model_base.cc) significantly overestimates the INT8 GEMM speedup over FP32, especially for small and medium tensor shapes. Measurements on an RTX 3080 show up to 576% prediction error at medium shapes, and INT8 running slower than FP32 at small shapes, a regime the model cannot predict at all.
Measured Data (RTX 3080, PyTorch torch._int_mm, 1330 shapes)
| Shape (M=K=N) | XLA Roofline Prediction | Measured Speedup | Prediction Error |
|---|---|---|---|
| 64 | 2.01x | 0.98x (INT8 slower!) | 105% |
| 128 | 2.03x | 0.87x | 133% |
| 512 | 3.91x | 0.58x (INT8 much slower!) | 576% |
| 1024 | 3.95x | 2.09x | 89% |
| 2048 | 3.97x | 2.30x | 73% |
| 4096 | 3.99x | 2.75x | 45% |
Mean absolute error across 1330 shapes: 147%
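The error column is consistent with relative error against the measured speedup; a small check (the metric definition is inferred from the table, not stated explicitly in the report):

```python
# Prediction error as relative error versus the measured speedup.
# This reproduces the table's error column (inferred metric, an assumption).
def prediction_error_pct(predicted, measured):
    return (predicted - measured) / measured * 100

rows = [(2.01, 0.98), (2.03, 0.87), (3.95, 2.09), (3.97, 2.30), (3.99, 2.75)]
for pred, meas in rows:
    print(f"{pred:.2f}x predicted vs {meas:.2f}x measured -> "
          f"{prediction_error_pct(pred, meas):.0f}% error")
```

The mean of these absolute errors over all 1330 shapes gives the 147% figure above.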
Aggregated by FLOPs magnitude
| FLOPs range | # shapes | Avg measured speedup | XLA would predict |
|---|---|---|---|
| <10M | 149 | 0.93x (slower) | ~2.0x |
| 10M-1B | 708 | 0.93x (breakeven) | ~2-4x |
| 1B-10B | 356 | 1.66x | ~4x |
| 10B-100B | 113 | 2.48x | ~4x |
| >100B | 4 | 2.76x | ~4x |
Root Causes
- Missing kernel launch + quantization overhead: at small shapes, the cost of kernel launch and INT8 quantize/dequantize dominates the compute savings. The model only considers peak TOPS and memory bandwidth.
- Peak utilization never reached: even at large shapes (4096³), measured INT8 speedup is 2.75x versus the theoretical 4x peak ratio. The model assumes near-peak utilization.
- The code acknowledges this: gpu_hlo_cost_analysis.cc contains the comment: "this is technically incorrect if the element type of this gemm is an integer type, because in that case no floating point operations are involved at all!"
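The prediction pattern in the tables (roughly 2x for tiny, memory-bound shapes, roughly 4x once compute-bound) falls out of a plain roofline calculation. A minimal sketch reproduces it, using assumed peak and bandwidth numbers rather than XLA's actual device constants, and ignoring exactly the launch and quantization overheads listed above:

```python
# Toy roofline in the spirit of the XLA model: runtime is the max of compute
# time and memory time. Peak/bandwidth values below are illustrative
# assumptions, not the RTX 3080's exact specs.
PEAK_FP32_OPS = 30e12   # assumed FP32 peak, ops/s
PEAK_INT8_OPS = 120e12  # assumed 4x the FP32 peak
BANDWIDTH = 760e9       # assumed bytes/s

def roofline_time(flops, bytes_moved, peak_ops):
    return max(flops / peak_ops, bytes_moved / BANDWIDTH)

def predicted_speedup(n):
    flops = 2 * n**3                         # square GEMM FLOP count
    fp32_bytes = 3 * n * n * 4               # A, B, C as 4-byte floats
    int8_bytes = 2 * n * n * 1 + n * n * 4   # int8 inputs, int32 output
    t_fp32 = roofline_time(flops, fp32_bytes, PEAK_FP32_OPS)
    t_int8 = roofline_time(flops, int8_bytes, PEAK_INT8_OPS)
    return t_fp32 / t_int8

for n in (64, 512, 4096):
    print(f"N={n}: roofline predicts {predicted_speedup(n):.2f}x")
```

Under these assumptions the sketch predicts ~2x at N=64 (memory-bound, byte-count ratio) and ~4x at N=4096 (compute-bound, peak-TOPS ratio), matching the model's behavior while never predicting a speedup below 1x.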
Impact
This misprediction affects any XLA pass that uses the cost model to decide whether to lower an op to INT8 (or to evaluate INT8 graph alternatives). For mixed-precision optimization, the model predicts the wrong quantization decision for shapes below ~10B FLOPs.
Suggested Improvement
A fitted sigmoid model captures the shape-dependent behavior well (R²=0.73 on 1330 measurements):
INT8_speedup(FLOPs) = 1.87 × sigmoid(2.58 × (log10(FLOPs) - 9.6)) + 0.88
This could be implemented as a single shape-dependent correction in CalculatePeakMatrixOpsPerNs, or as a post-hoc adjustment in EstimateRunTimeForInstruction, applied only to integer-typed GEMMs.
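The fitted formula translates directly into a small helper; a sketch of the post-hoc correction, with constants taken from the fit above (the function name is illustrative):

```python
import math

# Fitted INT8-over-FP32 speedup as a function of GEMM FLOPs
# (constants from the sigmoid fit reported above, R^2 = 0.73).
def int8_speedup(flops):
    s = 1.0 / (1.0 + math.exp(-2.58 * (math.log10(flops) - 9.6)))
    return 1.87 * s + 0.88
```

Note the floor of 0.88 lets the model predict INT8 being slower than FP32 at small shapes, and the asymptote of 1.87 + 0.88 = 2.75 caps large-shape predictions near the measured maximum rather than the theoretical 4x.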
Reproduction
Benchmark script (PyTorch, any CUDA GPU):
```python
import torch, time

def bench(fn, warmup=30, repeat=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(repeat):
        torch.cuda.synchronize()
        t0 = time.perf_counter(); fn(); torch.cuda.synchronize()
        times.append((time.perf_counter() - t0) * 1000)
    times.sort()
    return times[len(times) // 2]  # median latency in ms

for N in [64, 128, 256, 512, 1024, 2048, 4096]:
    A = torch.randn(N, N, device='cuda')
    B = torch.randn(N, N, device='cuda')
    lat_fp32 = bench(lambda: torch.mm(A, B))
    # Symmetric per-tensor quantization to int8.
    A8 = (A / A.abs().max() * 127).round().clamp(-128, 127).to(torch.int8)
    B8 = (B / B.abs().max() * 127).round().clamp(-128, 127).to(torch.int8)
    lat_int8 = bench(lambda: torch._int_mm(A8, B8))
    print(f"{N}x{N}: FP32={lat_fp32:.4f}ms INT8={lat_int8:.4f}ms "
          f"speedup={lat_fp32/lat_int8:.2f}x")
```
Hardware: RTX 3080 (SM 8.6), CUDA 13.1, PyTorch 2.5.1