A simple model that estimates the performance of warp-specialized persistent GEMMs on GPUs. The model includes some analytical components as well as some empirical factors that could be tuned to different hardware or implementations by measuring real performance on silicon.
Also included is a script to collect some sample data from Cutlass CuTeDSL kernels on B200.
The project has only been tested on a B200 machine with CUDA Toolkit 13.0 and CUDA driver major version 580.
```shell
git clone <repo>
cd <repo-name>
uv sync
```
One testlist is included. It consists of DeepSeek-V3 GEMM sizes generated by Grok for both fwd and bwd, with the sequence dim set to 4096.
`benchmark_gemm.py` takes a testlist as input and sweeps across a list of mma shapes and cluster sizes for each input problem. For each problem shape, cluster shape, and mma shape, a warp-specialized persistent GEMM kernel from the Cutlass CuTeDSL examples is benchmarked.
Some simplifying assumptions were made in this script:
- Only 2CTA instructions are covered
- Only a subset of the valid tile sizes are tested
- e5m2 is not covered
- M, N, and K are all padded to 16-byte-aligned sizes
- All kernels are run with k-major input, n-major output layouts
Don't forget to lock clocks before benchmarking.
```shell
nvidia-smi -pm 1
nvidia-smi -lgc 1300,1300
```
```shell
uv run benchmark_gemm.py --input_csv=testlists/dsv3.csv --dtype=nvfp4 --output_csv=results/dsv3_nvfp4.csv
```
Model configuration defaults can be found in `config/`.
```shell
uv run predict_gemm.py +input_csv=results/dsv3_nvfp4.csv +output_dir=results/nvfp4/ model=WSPersistentGEMMModel
```
This will produce an output csv and a bar plot.
The SOL represents the "speed of light" performance based on the hardware peak math and DRAM throughputs. The runtime is computed as:

$$T_{SOL} = \max(T_{math}, T_{dram})$$

where $T_{math}$ is the total FLOPs ($2MNK$) divided by the peak math throughput, and $T_{dram}$ is the total DRAM bytes (inputs plus output) divided by the peak DRAM bandwidth.
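As a rough sketch, the SOL bound can be computed as below. The function and peak-throughput arguments are illustrative placeholders, not measured B200 figures:

```python
# Minimal sketch of the SOL bound: runtime is the max of the ideal math
# time and the ideal DRAM time. Peak numbers are passed in by the caller.
def sol_runtime_us(m, n, k, bytes_per_elem_in, bytes_per_elem_out,
                   peak_tflops, peak_dram_gbps):
    flops = 2 * m * n * k                           # one FMA = 2 flops
    math_us = flops / (peak_tflops * 1e6)           # TFLOP/s -> flops per us
    dram_bytes = (m * k + k * n) * bytes_per_elem_in + m * n * bytes_per_elem_out
    dram_us = dram_bytes / (peak_dram_gbps * 1e3)   # GB/s -> bytes per us
    return max(math_us, dram_us)
```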
The WSPersistentGEMM (Warp Specialized Persistent GEMM) model represents the kernel as a sequence of phases, where each phase consists of overlapping workloads each corresponding to the 'specialized' task of each warp. At steady-state, the "mainloop" consists of overlapping DMA (memory read), Math (MMA), and Epilogue (SIMT and memory write) workloads.
Because the workload for each output tile is the same, all mainloop iterations will have the same limiter. The prologue consists of any launch latency, some setup time, and the time to complete the load of the first input tiles. Note that on the last wave, the epilogue will always be exposed.
This model does not consider the pipelining of the K-loop; that is, the overlapping of loads and MMAs along the inner (K) dimension of the GEMM. It simply treats the entire K dimension as a single grouped workload.
This model computes the runtime as:

$$T = T_{prologue} + (N_{waves} - 1)\cdot\max(T_{DMA}, T_{math}, T_{epi}) + \max(T_{DMA}, T_{math}) + T_{epi}$$

where $T_{prologue}$ is the launch and setup overhead plus the time to load the first input tiles, $T_{DMA}$, $T_{math}$, and $T_{epi}$ are the per-wave DMA, MMA, and epilogue times, and $N_{waves}$ is the number of waves of output tiles. On the last wave the epilogue is exposed rather than overlapped.
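The phase decomposition above can be sketched in a few lines. Function and argument names here are illustrative, not taken from the actual model code:

```python
# Sketch of the phase model: prologue, (waves - 1) steady-state mainloop
# iterations gated by the slowest overlapped workload, then a last wave
# where the epilogue is fully exposed.
def model_runtime(prologue, dma, math, epilog, waves):
    steady = max(dma, math, epilog)       # per-wave limiter at steady state
    last_wave = max(dma, math) + epilog   # epilogue exposed on the last wave
    return prologue + (waves - 1) * steady + last_wave
```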
The epilogue is modeled as a constant overhead combined with the time to write the output.
The DMA workload for each SM is modeled as the total load bytes (including both the input tile and scale factors for block-scaled inputs) divided by the cluster size in the multicast dimension for each input tensor (n for A, and m for B). For simplicity, in/out layouts and related kernel functionality are not considered.
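A sketch of the per-SM DMA accounting described above, with hypothetical names; the scale-factor byte counts default to zero for non-block-scaled formats:

```python
# Per-SM DMA bytes: each input tensor's tile bytes are divided by the
# cluster extent along its multicast dimension (n for A, m for B).
def dma_bytes_per_sm(tile_m, tile_n, k, bytes_per_elem,
                     cluster_m, cluster_n, sf_bytes_a=0, sf_bytes_b=0):
    a_bytes = tile_m * k * bytes_per_elem + sf_bytes_a  # A tile + scale factors
    b_bytes = tile_n * k * bytes_per_elem + sf_bytes_b  # B tile + scale factors
    return a_bytes / cluster_n + b_bytes / cluster_m
```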
Currently, the model does not consider the assignment of CTAs to output tiles ("rasterization") or the corresponding reuse of input tiles in L2 cache. This means that GEMMs with DMA-bound mainloops show a large gap, with the model over-estimating the total DRAM traffic, resulting in an overly pessimistic runtime prediction.
Here, the measured performance of CuTeDSL kernels is compared with perf model predictions for NVFP4, FP8, and FP16 on B200. The clocks are locked to 1.3 GHz, which should be low enough to avoid thermal throttling and keep the clock constant for the duration of each kernel.
First, let's take a look at kernel performance vs. SOL for nvfp4 (e2m1), fp8, and fp16. The title of each chart displays the data types as `{in_dtype}_{out_dtype}`.
As you can see, these kernels mostly don't come close to SOL, and the charts suggest SOL alone is not very useful for predicting kernel runtime. Math on Blackwell is very fast, so a shape needs a lot of math per wave to approach SOL when math-bound. For memory-bound shapes, latencies and other fixed overheads dominate and prevent 100% DRAM utilization, especially in these tests, which ran with relatively large or poorly-chosen tile shapes.
Now, let's look at some results from the WSPersistentGEMM Model. Again, results are displayed as performance vs. model prediction for nvfp4, fp8, and fp16.
Better than just SOL! Now, the perf ratios are centered at 1.0x, and real kernel performance is at most 2x faster or 2x slower than the model prediction. Let's take a look at a case with one of the most pessimistic (highest-ratio, rightmost) predictions for nvfp4:
```
MNK 4096 4096 16384
CTA (128 64)
CLUSTER (2 1)
RATIO: 1.6697582628788883
predicted: 376.1631394230768
actual: 225.27999877929688
PROLOGUE
  Overhead: 6.153846153846154
  DMA: 0.1040625
MAINLOOP
  Limiter: DMA
  DMA: 26.64
  MATH: 6.3015384615384615
  EPILOG: 1.3612307692307695
Last Wave
  DMA: 22.32
  MATH: 6.3015384615384615
  EPILOG: 1.2652307692307694
```
A DMA-bound kernel, where the model over-estimates the global load bytes; all the other most pessimistic cases are similar. This happens because the model does not account for the assignment of CTAs to output tiles and the corresponding reuse of input tiles in the L2 cache when counting global loads. Due to time constraints, rather than modeling raster order and the L2 cache directly, a knob was added to represent a global L2 hit rate across all reads. This change is almost certainly too heavy-handed, and a slightly more sophisticated L2 model would likely offer much better accuracy. Here are the results with the L2 hit rate fixed at 0.4:
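For reference, the hit-rate knob amounts to a single global scaling factor on DRAM traffic, roughly:

```python
# The heavy-handed L2 knob: a hit serviced from L2 is assumed free, so
# effective DRAM bytes shrink by (1 - hit_rate) across all reads.
def effective_dram_bytes(load_bytes, l2_hit_rate=0.4):
    return load_bytes * (1.0 - l2_hit_rate)
```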
The change indeed reduced the gap for the most pessimistic cases. The change will not affect the predictions for math- or epilogue-bound cases. Now, the largest ratio is 1.2x for fp16 and fp8.
Now, let's take a look at the lowest-ratio (leftmost) case from fp8:
```
MNK 4096 7168 257
CTA (64 256)
CLUSTER (2 1)
RATIO: 0.5794853639180193
predicted: 20.65007692307692
actual: 35.63520014286041
PROLOGUE
  Overhead: 6.153846153846154
  DMA: 0.111
MAINLOOP
  Limiter: EPILOG
  DMA: 0.89146875
  MATH: 0.3953846153846154
  EPILOG: 1.0652307692307692
Last Wave
  DMA: 0.096375
  MATH: 0.3953846153846154
  EPILOG: 0.8012307692307692
```
The K dimension is very small (K = 257), so this kernel is epilogue-bound. Indeed, checking more of the low-ratio cases indicates the model poorly predicts epilogue-bound performance. Since the epilogue model used here is very simple (a constant overhead plus the output write time), low accuracy is expected for epilogue-bound shapes. One option would be to tune the estimated constant factor in the epilogue latency against empirical data. Another would be to model the epilogue in more detail, better reflecting the real kernel implementation.
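The first option could be as simple as fitting the mean residual between measured epilogue-bound runtimes and the modeled output write time. The function name and sample values below are illustrative, not taken from the benchmark results:

```python
# Fit the constant epilogue overhead c in: measured ~= write_time + c.
# With a single constant, the least-squares estimate is the mean residual.
def fit_epilogue_constant(write_times_us, measured_us):
    residuals = [m - w for w, m in zip(write_times_us, measured_us)]
    return sum(residuals) / len(residuals)
```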
The model presented here is a very simple representation of the perf-relevant structure of warp-specialized GEMM kernels, tested for B200 on several data types. It is a significant improvement over a naive SOL model for predicting the performance of GEMM kernels of arbitrary shapes, improving from roughly 40-50% average ratio for the naive SOL model to around 80% for the custom model, with a much better worst-case ratio. The model as written should be easily portable to different hardware (e.g. Hopper or RTX Blackwell) and extensible to other algorithms (e.g. implicit_gemm or GEMM+X). With additional details about the hardware and software, the model accuracy can be improved. Based on the results, the following additions would improve the accuracy for the worst-predicted cases:
- L2 cache
- CTA rasterization
- More detailed epilogue model
- SM-level limiter, with empirical data about the throughput of different HW units (e.g. can a specific tile size be bound on L2 or SMEM?)