hbarclay/gemm-model

A simple model that estimates the performance of warp-specialized persistent GEMMs on GPUs. The model combines analytical components with empirical factors that can be tuned for different hardware or implementations by measuring real performance on silicon.

Also included is a script that collects sample benchmark data from CUTLASS CuTeDSL kernels on B200.

Usage

Setup / Quick Start

The project has only been tested on a B200 machine with CUDA Toolkit 13.0 and CUDA driver major version 580.

git clone <repo>
cd <repo-name>
uv sync

Testlists

One testlist is included. It consists of DeepSeek-V3 GEMM sizes, generated by Grok, for both the forward and backward passes, with the sequence dimension set to 4096.

Collect Benchmark Data

benchmark_gemm.py takes a testlist as input and, for each problem in it, sweeps across a list of MMA tile shapes and cluster shapes. For each combination of problem shape, cluster shape, and MMA shape, it benchmarks a warp-specialized persistent GEMM kernel from the CUTLASS CuTeDSL examples.

Some simplifying assumptions were made in this script:

  • Only 2-CTA MMA instructions are covered
  • Only a subset of the valid tile sizes is tested
  • e5m2 is not covered
  • M, N, and K are all padded to 16-byte-aligned sizes
  • All kernels are run with K-major input and N-major output layouts
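Because 16 bytes holds a different number of elements for each input type, the padding assumption above works out differently per dtype. A minimal sketch, assuming the padding rounds each dimension up to a whole number of 16-byte chunks of the input element type; `pad_to_16_bytes`, `pad_problem`, and the dtype keys are hypothetical names, not taken from benchmark_gemm.py:

```python
import math

# Bits per input element (hypothetical table; keys are illustrative, not the script's exact --dtype values).
DTYPE_BITS = {"nvfp4": 4, "fp8": 8, "fp16": 16}


def pad_to_16_bytes(dim: int, dtype: str) -> int:
    """Round dim (in elements) up so that dim * element_size is a multiple of 16 bytes."""
    elems_per_16_bytes = 16 * 8 // DTYPE_BITS[dtype]  # 32 for nvfp4, 16 for fp8, 8 for fp16
    return math.ceil(dim / elems_per_16_bytes) * elems_per_16_bytes


def pad_problem(m: int, n: int, k: int, dtype: str) -> tuple:
    """Pad all three GEMM dimensions for the given input dtype."""
    return tuple(pad_to_16_bytes(d, dtype) for d in (m, n, k))


print(pad_problem(4096, 7168, 1000, "fp8"))    # (4096, 7168, 1008)
print(pad_problem(4096, 7168, 1000, "nvfp4"))  # (4096, 7168, 1024)
```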

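Don't forget to lock clocks before benchmarking. The first two commands below enable persistence mode and lock the GPU clocks to 1300 MHz; the third runs the benchmark.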

nvidia-smi -pm 1
nvidia-smi -lgc 1300,1300
uv run benchmark_gemm.py --input_csv=testlists/dsv3.csv --dtype=nvfp4 --output_csv=results/dsv3_nvfp4.csv

Model Configuration

Model configuration defaults are in the config/ directory.

Get Model Predictions

uv run predict_gemm.py +input_csv=results/dsv3_nvfp4.csv +output_dir=results/nvfp4/ model=WSPersistentGEMMModel

This produces an output CSV and a bar plot.

Implemented GEMM Models

SOL Model

The SOL model represents "speed of light" performance, based on the hardware's peak math and DRAM throughputs. The runtime is computed as:

$$\mathrm{runtime} = \max(\mathrm{math}, \mathrm{DRAM})$$

where $\mathrm{math}$ is the SOL time to do $2 \cdot M \cdot N \cdot K$ ops, and $\mathrm{DRAM}$ is the SOL time to both read A and B and write C.
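A minimal sketch of the SOL formula above. The peak math throughput and DRAM bandwidth are parameters you would supply for the target GPU; the values in the usage line are placeholders, not B200 specifications. `sol_runtime_us` is a hypothetical name, and scale factors for block-scaled types are ignored for simplicity.

```python
def sol_runtime_us(m: int, n: int, k: int, a_bits: int, b_bits: int, c_bits: int,
                   peak_tflops: float, dram_tb_per_s: float) -> float:
    """Speed-of-light runtime in microseconds: max(math time, DRAM time)."""
    flops = 2 * m * n * k
    # Read A (m x k) and B (n x k) once, and write C (m x n) once.
    dram_bytes = (m * k * a_bits + n * k * b_bits + m * n * c_bits) / 8
    math_us = flops / (peak_tflops * 1e12) * 1e6
    dram_us = dram_bytes / (dram_tb_per_s * 1e12) * 1e6
    return max(math_us, dram_us)


# Placeholder peak numbers, not real hardware specifications.
print(sol_runtime_us(4096, 4096, 16384, 4, 4, 16, peak_tflops=1000.0, dram_tb_per_s=8.0))
```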

WSPersistentGEMM Model

The WSPersistentGEMM (Warp-Specialized Persistent GEMM) model represents the kernel as a sequence of phases, where each phase consists of overlapping workloads, each corresponding to the task a particular warp is specialized for. At steady state, the "mainloop" consists of overlapping DMA (memory read), Math (MMA), and Epilogue (SIMT and memory write) workloads.

Because the workload for each output tile is the same, all mainloop iterations have the same limiter. The prologue consists of any launch latency, some setup time, and the time to complete the load of the first input tiles. Note that on the last wave the epilogue is always exposed, since no subsequent tile's DMA or math can overlap it.

Figure: Blackwell persistent warp-specialized GEMM design, viewed from a single CTA over 4 waves

This model does not consider pipelining of the K-loop, that is, the pipelining of loads and MMAs along the inner (K) dimension of the GEMM. It simply treats the entire K dimension of each output tile as one unit of work.

This model computes the runtime as:

$$\mathrm{runtime} = \mathrm{setup\_overhead} + \mathrm{first\_DMA} + \mathrm{mainloop} + \mathrm{last\_wave\_epilogue}$$

where $\mathrm{mainloop}$ is the sum over all waves of the per-wave $\max(\mathrm{DMA}, \mathrm{Math}, \mathrm{epilogue})$. The last wave can be a partial wave in which not all SMs are in use; in that case the math time per tile stays the same, but the total read bandwidth utilization may be lower than in full waves. This effect is typically referred to as wave quantization and is modeled implicitly here.
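The runtime formula above can be sketched as follows, assuming the mainloop is made up of some number of full waves plus one (possibly partial) last wave, and that the caller supplies separate per-wave times for that last wave. `WaveTimes` and `ws_persistent_runtime_us` are hypothetical names for illustration; this is not the repository's implementation.

```python
from dataclasses import dataclass


@dataclass
class WaveTimes:
    """Per-wave times (microseconds) of the overlapped DMA, math, and epilogue workloads."""
    dma: float
    math: float
    epilogue: float

    def limiter(self) -> float:
        # The three workloads overlap, so the slowest one sets the wave time.
        return max(self.dma, self.math, self.epilogue)


def ws_persistent_runtime_us(setup_overhead: float, first_dma: float,
                             num_tiles: int, num_sms: int,
                             full_wave: WaveTimes, last_wave: WaveTimes) -> float:
    """runtime = setup_overhead + first_DMA + mainloop + last_wave_epilogue."""
    assert num_tiles > 0 and num_sms > 0
    full_waves, remainder = divmod(num_tiles, num_sms)
    if remainder == 0:
        # No partial wave: the final full wave is treated as the last wave.
        full_waves -= 1
    mainloop = full_waves * full_wave.limiter() + last_wave.limiter()
    # Nothing follows the last wave, so its epilogue cannot be hidden.
    return setup_overhead + first_dma + mainloop + last_wave.epilogue
```

Keeping the last wave's times separate lets a partial wave carry a smaller DMA or epilogue time (fewer SMs sharing bandwidth) while its math time per tile stays unchanged.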

The epilogue is modeled as a constant overhead plus the time to write the output tile.
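A minimal sketch of this epilogue model, assuming every active SM writes one output tile per wave and that those writes share DRAM bandwidth. The function name, the constant, and the bandwidth parameter are hypothetical placeholders; the repository's config/ holds its actual defaults.

```python
def epilogue_time_us(tile_m: int, tile_n: int, out_bits: int, active_sms: int,
                     constant_us: float, dram_bytes_per_us: float) -> float:
    """Epilogue time = fixed overhead + time to write this wave's output tiles to DRAM."""
    tile_bytes = tile_m * tile_n * out_bits / 8
    # All active SMs write one tile each and share the same DRAM bandwidth.
    write_us = tile_bytes * active_sms / dram_bytes_per_us
    return constant_us + write_us
```

Under this assumption, a partial last wave (fewer active SMs) has a shorter epilogue than a full wave, while the constant part is unchanged.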

The DMA workload for each SM is modeled as the total bytes loaded for the A and B input tiles (including scale factors for block-scaled inputs), with each tensor's bytes divided by the cluster size along its multicast dimension (n for A, and m for B). For simplicity, input/output layouts and related kernel functionality are not considered.
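The per-SM load bytes described above can be sketched as follows. The function and parameter names are hypothetical, and the scale-factor handling assumes one scale factor of `sf_bits` bits per `sf_vec_size` elements along K (for NVFP4, one 8-bit E4M3 scale per 16 elements).

```python
import math


def dma_bytes_per_sm(tile_m: int, tile_n: int, k: int, in_bits: int,
                     cluster_m: int, cluster_n: int,
                     sf_vec_size=None, sf_bits: int = 8) -> float:
    """Bytes of A and B loaded by one SM for one output tile.

    A tiles are multicast across the cluster's N extent and B tiles across its
    M extent, so each SM's share of a tensor is divided by that cluster size.
    """
    def tensor_bytes(rows: int) -> float:
        data = rows * k * in_bits / 8
        # Block-scaled inputs also load one scale factor per sf_vec_size elements.
        scales = 0 if sf_vec_size is None else rows * math.ceil(k / sf_vec_size) * sf_bits / 8
        return data + scales

    a_bytes = tensor_bytes(tile_m) / cluster_n
    b_bytes = tensor_bytes(tile_n) / cluster_m
    return a_bytes + b_bytes


# NVFP4 tile 128x64, K=16384, cluster (2, 1), 16-element scale blocks.
print(dma_bytes_per_sm(128, 64, 16384, 4, 2, 1, sf_vec_size=16))
```

Under the same sharing assumption as the epilogue sketch, dividing these bytes by each active SM's share of DRAM bandwidth would give the per-wave DMA time.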

Currently, the model does not consider the assignment of CTAs to output tiles ("rasterization") or the corresponding reuse of input tiles in the L2 cache. As a result, it overestimates total DRAM traffic for GEMMs with DMA-bound mainloops, which leads to overly pessimistic runtime predictions and a large gap versus measured performance.
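To illustrate why ignoring rasterization inflates DRAM traffic, the sketch below takes one wave of output tiles assigned in plain row-major order and counts how many distinct A tile-rows and B tile-columns it touches. It assumes an idealized L2 that serves every repeated read within a wave and ignores multicast. This is a hypothetical upper bound on reuse, not how the repository models L2, and `wave_read_bytes` is an illustrative name.

```python
def wave_read_bytes(tiles_m: int, tiles_n: int, wave_start: int, wave_size: int,
                    a_tile_bytes: float, b_tile_bytes: float):
    """Return (naive, ideal_l2) DRAM read bytes for one wave of row-major-ordered tiles."""
    tile_ids = range(wave_start, min(wave_start + wave_size, tiles_m * tiles_n))
    rows = {t // tiles_n for t in tile_ids}  # distinct A tile-rows (M index)
    cols = {t % tiles_n for t in tile_ids}   # distinct B tile-columns (N index)
    # Naive: every tile reads its own A and B from DRAM.
    naive = len(tile_ids) * (a_tile_bytes + b_tile_bytes)
    # Ideal L2: each distinct A row and B column is read from DRAM only once per wave.
    ideal = len(rows) * a_tile_bytes + len(cols) * b_tile_bytes
    return naive, ideal
```

The gap between the two numbers grows when a wave covers few tile-rows but many tile-columns (or vice versa), which is exactly the reuse a raster-order-aware model would capture.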

Results

Here, the measured performance of CuTeDSL kernels is compared with perf model predictions for NVFP4, FP8, and FP16 on B200. The clocks are locked to 1.3 GHz, which should be low enough to avoid thermal throttling and keep the clock constant for the duration of each kernel.

First, let's look at kernel performance vs. SOL for nvfp4 (e2m1), fp8, and fp16. The title of each chart displays types as {in_dtype}_{out_dtype}.

NVFP4 Kernel Performance vs. SOL

FP8 Kernel Performance vs. SOL

FP16 Kernel Performance vs. SOL

As you can see, most of these kernels don't come close to SOL, so the SOL model is not very useful for predicting kernel runtime. Math on Blackwell is very fast, so a math-bound shape needs a lot of math per wave to approach SOL. For memory-bound shapes, latencies and other fixed overheads dominate and prevent 100% DRAM utilization, especially in these tests, which ran with relatively large or poorly chosen tile shapes.

Now, let's look at some results from the WSPersistentGEMM Model. Again, results are displayed as performance vs. model prediction for nvfp4, fp8, and fp16.

NVFP4 Kernel Performance vs. WSPersistentGEMM Model

FP8 Kernel Performance vs. WSPersistentGEMM Model

FP16 Kernel Performance vs. WSPersistentGEMM Model

Better than just SOL! Now, the perf ratios are centered at 1.0x, and real kernel performance is at most 2x faster or 2x slower than the model prediction. Let's take a look at a case with one of the most pessimistic (highest-ratio, rightmost) predictions for nvfp4:

MNK 4096 4096 16384
CTA (128 64)
CLUSTER (2 1)
RATIO: 1.6697582628788883
    predicted: 376.1631394230768
    actual: 225.27999877929688
PROLOGUE
    Overhead: 6.153846153846154
    DMA: 0.1040625
MAINLOOP
    Limiter: DMA
        DMA: 26.64
        MATH: 6.3015384615384615
        EPILOG: 1.3612307692307695
    Last Wave
        DMA: 22.32
        MATH: 6.3015384615384615
        EPILOG: 1.2652307692307694
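The "Last Wave" values in the breakdown above reflect wave quantization. The sketch below checks the wave structure, assuming B200's 148 SMs, that CTA (128 64) is the per-CTA output tile, and that each SM processes one tile per wave. `wave_breakdown` and `NUM_SMS` are hypothetical names.

```python
import math

NUM_SMS = 148  # B200 SM count; assumes one CTA output tile per SM per wave


def wave_breakdown(m: int, n: int, tile_m: int, tile_n: int, num_sms: int = NUM_SMS):
    """Return (num_tiles, full_waves, tiles_in_last_wave) for a persistent GEMM."""
    num_tiles = math.ceil(m / tile_m) * math.ceil(n / tile_n)
    full_waves, last = divmod(num_tiles, num_sms)
    if last == 0:
        # An exact multiple: the final full wave counts as the last wave.
        full_waves, last = full_waves - 1, num_sms
    return num_tiles, full_waves, last


tiles, full, last = wave_breakdown(4096, 4096, 128, 64)
print(tiles, full, last)  # 2048 13 124

# If DRAM bandwidth is shared by all active SMs, a wave's DMA time scales with
# the number of active SMs, so the partial last wave is proportionally shorter.
full_wave_dma_us = 26.64
print(full_wave_dma_us * last / NUM_SMS)  # ≈ 22.32, the "Last Wave" DMA above
```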

This is a DMA-bound kernel, where the model overestimates the global load bytes, and all the other most pessimistic cases are similar. The issue arises because the model does not account for the assignment of CTAs to output tiles, or the resulting reuse of input tiles in the L2 cache, when counting global loads. Due to time constraints, rather than modeling raster order and the L2 cache directly, a knob was added to represent the global L2 hit rate across all reads. This change is almost certainly too heavy-handed, and a slightly more sophisticated L2 model could probably offer considerably better accuracy. Here are the results with the L2 hit rate fixed at 0.4:

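NVFP4 Kernel Performance vs. WSPersistentGEMM Model (L2 hit rate 0.4)

FP8 Kernel Performance vs. WSPersistentGEMM Model (L2 hit rate 0.4)

FP16 Kernel Performance vs. WSPersistentGEMM Model (L2 hit rate 0.4)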

The change indeed reduced the gap for the most pessimistic cases, and it does not affect predictions for math- or epilogue-bound cases. The largest ratio is now 1.2x for fp16 and fp8.
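Because the hit-rate knob only scales DRAM read traffic, it shortens DMA time and leaves the math and epilogue terms untouched, which is why math- and epilogue-bound predictions do not move. A minimal sketch, assuming the rate applies uniformly to all reads and that L2 hits cost nothing; `dma_time_with_l2_hit_rate` and its parameters are hypothetical names.

```python
def dma_time_with_l2_hit_rate(load_bytes_per_sm: float, active_sms: int,
                              dram_bytes_per_us: float, l2_hit_rate: float) -> float:
    """Per-wave DMA time when a fixed fraction of all reads hits in L2.

    Only the (1 - l2_hit_rate) fraction of reads is charged to DRAM bandwidth;
    treating L2 hits as free is a simplifying assumption.
    """
    if not 0.0 <= l2_hit_rate <= 1.0:
        raise ValueError("l2_hit_rate must be in [0, 1]")
    dram_bytes = load_bytes_per_sm * active_sms * (1.0 - l2_hit_rate)
    return dram_bytes / dram_bytes_per_us


print(dma_time_with_l2_hit_rate(1000, 10, 1000, 0.0))  # 10.0
print(dma_time_with_l2_hit_rate(1000, 10, 1000, 0.4))  # about 6.0
```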

Now, let's take a look at the lowest-ratio (leftmost) case from fp8:

MNK 4096 7168 257
CTA (64 256)
CLUSTER (2 1)
RATIO: 0.5794853639180193
    predicted: 20.65007692307692
    actual: 35.63520014286041
PROLOGUE
    Overhead: 6.153846153846154
    DMA: 0.111
MAINLOOP
    Limiter: EPILOG
        DMA: 0.89146875
        MATH: 0.3953846153846154
        EPILOG: 1.0652307692307692
    Last Wave
        DMA: 0.096375
        MATH: 0.3953846153846154
        EPILOG: 0.8012307692307692

The K dimension is very small, so this kernel is epilogue-bound. Indeed, checking more of the low-ratio cases indicates that the model poorly predicts epilogue-bound performance. Since the epilogue model used here is very simple (a constant overhead plus the output write time), low accuracy for epilogue-bound shapes is expected. One option would be to tune the constant factor in the epilogue latency using empirical data. Another would be to model the epilogue in more detail, more closely reflecting the real kernel implementation.
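One way to tune the epilogue constant empirically is an ordinary least-squares fit of epilogue time against output bytes. A minimal sketch, assuming you have already extracted per-wave epilogue times from measurements of epilogue-bound kernels (how to extract them is left open); `fit_epilogue` is a hypothetical name, and the data below are synthetic, not measured results.

```python
def fit_epilogue(out_bytes, epilogue_us):
    """Least-squares fit of epilogue_us ≈ constant_us + out_bytes / bytes_per_us.

    Returns (constant_us, bytes_per_us).
    """
    n = len(out_bytes)
    mean_x = sum(out_bytes) / n
    mean_y = sum(epilogue_us) / n
    sxx = sum((x - mean_x) ** 2 for x in out_bytes)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(out_bytes, epilogue_us))
    slope = sxy / sxx  # microseconds per output byte
    intercept = mean_y - slope * mean_x
    return intercept, 1.0 / slope


# Synthetic example: 0.5 us constant plus 2e6 bytes/us of write bandwidth.
print(fit_epilogue([1e6, 2e6, 4e6], [1.0, 1.5, 2.5]))
```

A two-parameter fit like this also tests whether the write-time term in the epilogue model has the right slope, not just the constant.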

Conclusion

The model presented here is a very simple representation of the performance-relevant structure of warp-specialized GEMM kernels, tested on B200 with several data types. It is a significant improvement over a naive SOL model for predicting the performance of GEMM kernels of arbitrary shapes, improving from roughly 40-50% average ratio for the naive SOL model to around 80% for the custom model, with a much better worst-case ratio. The model as written should be easily portable to different hardware (e.g. Hopper or RTX Blackwell) and extensible to other algorithms (e.g. implicit_gemm or GEMM+X). With additional details about the hardware and software, the model accuracy can be improved. Based on the results, the following additions would improve accuracy for the worst-predicted cases:

  • L2 cache
  • CTA rasterization
  • More detailed epilogue model
  • SM-level limiter, with empirical data about the throughput of different HW units (e.g. can a specific tile size be bound on L2 or SMEM?)

About

A simple GPU gemm kernel model
