Commit f937055

Benchmark results for sdpa operation on Blackwell (#146)
The benchmarking script in this current directory profiles scaled dot product attention (SDPA) from various backends. Here we benchmark attention layer dimensions inspired by Llama-3.1-405B with sequence lengths ranging from 512 to 131,072.
1 parent 576e0ed commit f937055

File tree

8 files changed (+1146 lines, -3 lines)
Lines changed: 20 additions & 0 deletions
# Base image: NGC PyTorch container (PyTorch 2.8.0, CUDA 12.9)
FROM nvcr.io/nvidia/pytorch:25.05-py3

# Plotting dependency used for the benchmark figures
RUN pip install --upgrade pip && \
    pip install seaborn

# Install cuDNN (9.10.2 at the time of writing) from NVIDIA's CUDA apt repository
RUN apt-get update && \
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
    dpkg -i cuda-keyring_1.1-1_all.deb && \
    apt-get update && \
    apt-get -y install cudnn

# Remove the pip-installed cuDNN so the apt-installed library above is the one loaded
RUN pip uninstall -y cudnn

COPY benchmark_sdpa.py .
COPY benchmark_single_sdpa.py .

# Ensure the apt-installed cuDNN in /usr/lib/x86_64-linux-gnu is found first
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH

WORKDIR /workspace
Lines changed: 127 additions & 0 deletions
# Scaled Dot Product Attention Benchmark
## Introduction

The benchmarking script in this directory profiles scaled dot product attention (SDPA) from various backends. Here we benchmark attention layer dimensions inspired by [Llama-3.1-405B](https://ai.meta.com/blog/meta-llama-3-1/) with sequence lengths ranging from 512 to 131,072.

The provided benchmark targets training use cases: causal masking is enabled and grouped query attention (GQA) is used. Layer dimensions and causal masking can be altered by modifying the preset parameters in `benchmark_sdpa.py`. Inference-specific attention optimizations such as paged attention are not benchmarked at this time.
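For reference, the sketch below shows what one of these attention layers looks like as a single PyTorch call, with the GQA head counts and causal masking described above. The tensor shapes and the short sequence length are illustrative only; this is not the exact code in `benchmark_sdpa.py`.

```
import torch
import torch.nn.functional as F

# Illustrative dimensions (Llama-3.1-405B-style heads; short sequence for brevity)
batch, seqlen, num_q_heads, num_kv_heads, head_dim = 1, 512, 128, 8, 128

q = torch.randn(batch, num_q_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, num_kv_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, num_kv_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)

# Grouped query attention (8 KV heads shared by 128 query heads) with causal masking;
# enable_gqa requires a recent PyTorch, otherwise repeat k/v up to 128 heads first.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
```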
## Contents

- `Dockerfile` to create a Docker container with the dependencies and run the benchmark.
- `benchmark_sdpa.py`, which benchmarks the cuDNN, PyTorch, and other backends at sequence lengths up to 128k.
- Benchmark results on B200 in the `artifacts` directory.
- Useful Python scripts for running single attention layers:
  - `benchmark_single_sdpa.py` for benchmarking a single flash attention instance from various backends.
  - See below for a usage example.

## Software versions

This benchmark code should run on any reasonably modern Python environment with a CUDA-enabled GPU. The results in `artifacts` were collected using the PyTorch docker image [from the NVIDIA GPU CLOUD (NGC) catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), `nvcr.io/nvidia/pytorch:25.05-py3`, where cuDNN 9.10.2 was used. We provide a `Dockerfile` to reproduce the environment with the following library versions:

| Software       | Version |
|----------------|---------|
| Python         | 3.12.9  |
| CUDA           | 12.9.0  |
| cuDNN          | 9.10.2  |
| PyTorch        | 2.8.0   |
| FlashAttention | 2.7.3   |

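To confirm that the environment inside the container matches the table above, a quick check such as the following can be run (it mirrors the `[INFO]` lines printed by the benchmark scripts):

```
import torch
import flash_attn

# Expected in the provided container: PyTorch 2.8.0, CUDA 12.9, cuDNN 91002 (9.10.2), FlashAttention 2.7.3
print(f"{torch.__version__ = }")
print(f"{torch.version.cuda = }")
print(f"{torch.backends.cudnn.version() = }")
print(f"{flash_attn.__version__ = }")
```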
## Steps to run

### 0. *Optional*: Lock Clocks

Although the benchmarking code inserts dynamically sized delays to avoid GPU throttling, the most reproducible results are obtained when the clocks are locked. For example, use `nvidia-smi -q -d SUPPORTED_CLOCKS` to list the supported clocks, then:

```
sudo nvidia-smi -pm 1
nvidia-smi -lgc <min_clock>,<max_clock>
```

### 1. Build docker container

Launch the Docker build and run the container. We provide a simple `Dockerfile` to help run the benchmark:

```
docker build -t cudnn_attention_benchmark .
docker run -it --gpus all --rm -v $(pwd):/workspace cudnn_attention_benchmark
```

### 2. Run Benchmark script

The `benchmark_sdpa.py` script executes a predefined set of attention layers of various sequence lengths, where the transformer dimensions are inspired by [Llama-3.1-405B](https://ai.meta.com/blog/meta-llama-3-1/) (`num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`).

The following scaled dot product attention backends are benchmarked (a sketch of how each is invoked follows below):
- [PyTorch's SDPA backends](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html):
  - cuDNN (`CUDNN_ATTENTION`)
  - Standard Attention (`MATH`)
  - FlashAttention-2 (`FLASH_ATTENTION`; PyTorch FAv2)
- [FlashAttention-2](https://github.com/Dao-AILab/flash-attention)'s original implementation (native FAv2)

Please note that FlashAttention-3 is currently not supported on NVIDIA's Blackwell generation GPUs.
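The backend names that appear in the results (`pyt_math`, `pyt_cudnn`, `pyt_flash_attention`, `flash_attention`) correspond to the entries above. The sketch below shows one way each backend can be invoked; it is an illustration with small, assumed tensor shapes, not the benchmark code itself, and it assumes the `flash_attn` package is installed.

```
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from flash_attn import flash_attn_func

# Small illustrative tensors in (batch, heads, seqlen, head_dim) layout
q = torch.randn(1, 128, 512, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 512, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 512, 128, device="cuda", dtype=torch.bfloat16)

# pyt_math / pyt_cudnn / pyt_flash_attention: restrict PyTorch SDPA to one backend at a time
for backend in (SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION):
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# flash_attention: the native FlashAttention-2 kernel, which takes
# (batch, seqlen, heads, head_dim) layout and handles the GQA head-count mismatch itself
out_fa2 = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True)
```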

Sample output:
```
$ python3 benchmark_sdpa.py
[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.05'
[INFO] torch.version.cuda = '12.9'
[INFO] torch.cuda.is_available() = True
[INFO] torch.cuda.device_count() = 1
[INFO] torch.cuda.current_device() = 0
[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
[INFO] torch.backends.cudnn.version() = 91002
[INFO] torch.backends.cudnn.enabled = True
[INFO] flash_attn.__version__ = '2.7.3'
[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128), (2, 131072, 131072, 128, 8, 128)]
[INFO] Running layer (1, 512, 512, 128, 8, 128)
...
[INFO] Saving results to ./artifacts/sdpa_benchmark_results_NVIDIA_B200.csv
[INFO] Saving plot to ./artifacts/sdpa_benchmark_results_NVIDIA_B200.png
```

Benchmarked performance numbers are stored in the [artifacts](artifacts) directory as CSV and PNG files.
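Because the per-layer numbers are also written to the CSV, the results can be re-plotted or analyzed outside the benchmark script. Below is a minimal sketch, assuming pandas, seaborn, and matplotlib are available in the container (the `Dockerfile` installs seaborn) and that the CSV filename matches the B200 run shown above:

```
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("./artifacts/sdpa_benchmark_results_NVIDIA_B200.csv")
df["config"] = df["batch_size"].astype(str) + " x " + df["q_seqlen"].astype(str)

# Forward throughput per backend vs. (batch size x sequence length)
sns.barplot(data=df, x="config", y="fwd_tflops_per_sec", hue="backend")
plt.xticks(rotation=45)
plt.ylabel("Forward TFLOPS")
plt.tight_layout()
plt.savefig("sdpa_fwd_tflops_replot.png")
```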

## Results

Below are the results of the benchmark run on a single B200 GPU.

For these runs, the following software versions were used:

- CUDA: 12.9 (from NGC container)
- PyTorch: 2.8.0 (from NGC container)
- cuDNN: 9.10.2 (installed via `apt-get`; see `Dockerfile`)

#### B200
![Comparison of pytorch and cudnn](artifacts/sdpa_benchmark_results_NVIDIA_B200.png)
- The following SDPA parameters were used: `num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`.
- Batch size and sequence lengths are shown on the x-axis.
- Results were obtained on an NVIDIA B200 GPU with unlocked (free-running) clocks.

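The TFLOPS figures reported above and in the CSV are consistent with the standard attention FLOP count of two matmuls per head (QK^T and PV), halved for causal masking, with the backward pass counted as 2.5x the forward. The snippet below reproduces one `pyt_cudnn` row of the CSV under that accounting; this convention is an assumption, but it matches the stored numbers.

```
# Assumed FLOP accounting (reproduces the CSV's TFLOPS columns):
#   forward  = 4 * batch * num_q_heads * q_seqlen * kv_seqlen * head_dim
#   causal masking computes only ~half the score matrix  -> divide by 2
#   backward = 2.5 * forward
batch, q_len, kv_len, heads, dim = 1, 32768, 32768, 128, 128
fwd_flops = 4 * batch * heads * q_len * kv_len * dim / 2
bwd_flops = 2.5 * fwd_flops

fwd_ms, bwd_ms = 24.518, 77.677  # pyt_cudnn forward/backward times from the CSV (milliseconds)
print(fwd_flops / (fwd_ms * 1e-3) / 1e12)  # ~1435 TFLOPS
print(bwd_flops / (bwd_ms * 1e-3) / 1e12)  # ~1132 TFLOPS
```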

## PyTorch adoption

As can be seen from the results, cuDNN v9 can achieve over 2x the performance of the comparable PyTorch eager implementation. Refer to the documentation for [PyTorch's scaled_dot_product_attention()](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and the [sdpa_kernel](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) context manager for enabling the cuDNN backend for scaled dot product attention.
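A minimal sketch of opting in from training code follows; the tensors are illustrative placeholders for the query, key, and value produced inside your own model.

```
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 128, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Allow only the cuDNN backend for every SDPA call made inside the context
# (other SDPBackend members can be appended to the list as fallbacks).
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    out.sum().backward()  # both forward and backward then run through cuDNN
```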

## `benchmark_single_sdpa.py`

`benchmark_single_sdpa.py` is provided to conveniently run a single SDPA operation. Try running `python benchmark_single_sdpa.py --help` to see available flags.

Example commands and outputs:
```
## For running various PyTorch backends (FlashAttention, cuDNN, ...) or FlashAttention-2:
$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend pyt_cudnn --fwd_bwd
pyt_cudnn:: Median (fwd, bwd) Execution Times: 24.602 ms (1430 TFLOPS), 78.140 ms (1126 TFLOPS) (max difference vs. pyt_reference: 0.007812 from 10 iterations)

## For directly running cuDNN via cuDNN Frontend
$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --fwd_bwd
cudnn_fe:: Median (fwd, bwd) Execution Times: 24.480 ms (1437 TFLOPS), 73.519 ms (1196 TFLOPS) (max difference vs. pyt_reference: 0.007812 from 10 iterations)
```
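The reported times are medians over `--num_iterations` runs. Below is a minimal sketch of CUDA-event-based median timing under that assumption; the script's actual warm-up, synchronization, and throttling-avoidance delays are not reproduced here.

```
import statistics
import torch

def median_time_ms(fn, num_iterations=10):
    """Time a CUDA callable with events; return the median over iterations in milliseconds."""
    times = []
    for _ in range(num_iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return statistics.median(times)
```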

The cuDNN version used in the benchmark can be replaced by setting the `LD_LIBRARY_PATH` environment variable.
```
$ export LD_LIBRARY_PATH=<my_path_to_cuDNN_9.10.1>
$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 16384 --kv_seqlen 16384 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --fwd_bwd --verbose
[INFO] cuDNN Backend Version: cudnn.backend_version() = 91001
[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.11.0'
[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.05'
[INFO] torch.version.cuda = '12.9'
[INFO] torch.cuda.is_available() = True
[INFO] torch.cuda.device_count() = 1
[INFO] torch.cuda.current_device() = 0
[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
cudnn_fe:: Median (fwd, bwd) Execution Times: 6.421 ms (1370 TFLOPS), 19.367 ms (1135 TFLOPS) (max difference vs. pyt_reference: 0.007812 from 10 iterations)
```
Lines changed: 63 additions & 0 deletions

[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.05'
[INFO] torch.version.cuda = '12.9'
[INFO] torch.cuda.is_available() = True
[INFO] torch.cuda.device_count() = 1
[INFO] torch.cuda.current_device() = 0
[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
[INFO] torch.backends.cudnn.version() = 91002
[INFO] torch.backends.cudnn.enabled = True
[INFO] flash_attn.__version__ = '2.7.3'
[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128), (2, 131072, 131072, 128, 8, 128)]
[INFO] Running layer (1, 512, 512, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 1024, 1024, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 2048, 2048, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 4096, 4096, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 8192, 8192, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 16384, 16384, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 32768, 32768, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 65536, 65536, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (1, 131072, 131072, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Running layer (2, 131072, 131072, 128, 8, 128)
[INFO] Benchmarking backend pyt_math
[INFO] Benchmarking backend pyt_cudnn
[INFO] Benchmarking backend pyt_flash_attention
[INFO] Benchmarking backend flash_attention
[INFO] Saving results to ./artifacts/sdpa_benchmark_results_NVIDIA_B200.csv
[INFO] Saving plot to ./artifacts/sdpa_benchmark_results_NVIDIA_B200.png
Lines changed: 41 additions & 0 deletions

batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,is_causal,backend,forward_time,backward_time,fwd_tflops_per_sec,bwd_tflops_per_sec
1,512,512,128,8,128,True,pyt_math,0.694,0.739,12.372,29.046
1,512,512,128,8,128,True,pyt_cudnn,0.130,0.494,66.280,43.512
1,512,512,128,8,128,True,pyt_flash_attention,0.142,0.514,60.602,41.808
1,512,512,128,8,128,True,flash_attention,0.147,0.629,58.451,34.139
1,1024,1024,128,8,128,True,pyt_math,1.823,1.356,18.844,63.354
1,1024,1024,128,8,128,True,pyt_cudnn,0.159,0.482,216.110,178.132
1,1024,1024,128,8,128,True,pyt_flash_attention,0.233,0.768,147.330,111.864
1,1024,1024,128,8,128,True,flash_attention,0.235,0.754,146.187,113.956
1,2048,2048,128,8,128,True,pyt_math,6.829,4.000,20.125,85.893
1,2048,2048,128,8,128,True,pyt_cudnn,0.262,0.895,525.475,383.705
1,2048,2048,128,8,128,True,pyt_flash_attention,0.548,1.765,250.977,194.686
1,2048,2048,128,8,128,True,flash_attention,0.530,1.635,259.139,210.091
1,4096,4096,128,8,128,True,pyt_math,27.183,14.169,20.225,96.998
1,4096,4096,128,8,128,True,pyt_cudnn,0.613,1.938,897.309,709.138
1,4096,4096,128,8,128,True,pyt_flash_attention,1.678,5.159,327.554,266.388
1,4096,4096,128,8,128,True,flash_attention,1.581,4.722,347.732,291.053
1,8192,8192,128,8,128,True,pyt_math,115.349,52.913,19.064,103.897
1,8192,8192,128,8,128,True,pyt_cudnn,1.816,5.815,1211.140,945.440
1,8192,8192,128,8,128,True,pyt_flash_attention,5.848,17.781,376.059,309.174
1,8192,8192,128,8,128,True,flash_attention,5.453,16.392,403.305,335.386
1,16384,16384,128,8,128,True,pyt_math,inf,inf,0.000,0.000
1,16384,16384,128,8,128,True,pyt_cudnn,6.475,20.466,1358.566,1074.499
1,16384,16384,128,8,128,True,pyt_flash_attention,22.158,66.431,396.971,331.025
1,16384,16384,128,8,128,True,flash_attention,20.522,62.001,428.611,354.678
1,32768,32768,128,8,128,True,pyt_math,inf,inf,0.000,0.000
1,32768,32768,128,8,128,True,pyt_cudnn,24.518,77.677,1435.072,1132.387
1,32768,32768,128,8,128,True,pyt_flash_attention,86.581,257.176,406.373,342.027
1,32768,32768,128,8,128,True,flash_attention,80.170,241.841,438.873,363.713
1,65536,65536,128,8,128,True,pyt_math,inf,inf,0.000,0.000
1,65536,65536,128,8,128,True,pyt_cudnn,98.489,327.462,1428.973,1074.456
1,65536,65536,128,8,128,True,pyt_flash_attention,342.696,1015.375,410.677,346.516
1,65536,65536,128,8,128,True,flash_attention,317.255,958.828,443.610,366.952
1,131072,131072,128,8,128,True,pyt_math,inf,inf,0.000,0.000
1,131072,131072,128,8,128,True,pyt_cudnn,417.410,1355.987,1348.674,1037.897
1,131072,131072,128,8,128,True,pyt_flash_attention,1366.320,4051.869,412.019,347.340
1,131072,131072,128,8,128,True,flash_attention,1264.455,3845.172,445.212,366.011
2,131072,131072,128,8,128,True,pyt_math,inf,inf,0.000,0.000
2,131072,131072,128,8,128,True,pyt_cudnn,854.100,2736.916,1318.230,1028.438
2,131072,131072,128,8,128,True,pyt_flash_attention,2731.190,8108.904,412.238,347.118
2,131072,131072,128,8,128,True,flash_attention,2527.642,7692.665,445.435,365.900
120 KB binary file (not rendered)
