
Commit bb37575

cudnn frontend v1.15 is the preferred cudnn frontend version for [cuDNN version 9.13.1](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-13-1) and above.
- Introduced a new `cudnn.Graph` API that enables interoperability between `torch.tensors` and the cudnn frontend API. Sample code for performing a matmul with bias addition (imports and handle creation are added here so the sample runs as-is; `cudnn.create_handle()` is the frontend's handle constructor):

  ```python
  import cudnn
  import torch

  handle = cudnn.create_handle()

  B, M, N, K = 16, 128, 128, 512
  a_gpu = torch.randn(B, M, K, device="cuda", dtype=torch.bfloat16)
  b_gpu = torch.randn(B, K, N, device="cuda", dtype=torch.bfloat16)
  d_gpu = torch.randn(1, M, N, device="cuda", dtype=torch.bfloat16)

  with cudnn.Graph(
      intermediate_data_type=cudnn.data_type.FLOAT,
      compute_data_type=cudnn.data_type.FLOAT,
      inputs=["mm::A", "mm::B", "bias::bias"],
      outputs=["bias::OUT_0"],
  ) as graph:
      AB = graph.matmul(
          name="mm",
          A=a_gpu,
          B=b_gpu,
      )
      C = graph.bias(name="bias", input=AB, bias=d_gpu)
      C.set_output(True)

  # Execute the built graph on the provided tensors.
  c_gpu = graph(a_gpu, b_gpu, d_gpu, handle=handle)
  ```

  All notebooks under [samples/python](samples/python) have been updated to showcase the flexibility of this API.
- cudnn frontend now supports building editable pip wheels in place.
- The cudnn frontend `Graph` now includes a `warmup` method that triggers kernel loading by performing a fake graph capture. This improves startup time for the first kernel launch in the actual run and prevents deadlocks when used with other modules (e.g., NCCL). See the usage sketch after this list.
- Introduced `set_score_max` and `set_score_sum_exp` to allow the kernel to output the max attention score and the sum of exponents.
- Updated support-surface checks. (SDPA bprop does not support the combination of `s_q == 1` and `s_kv == 1`.)
- SDPA bprop now automatically applies a padding mask if the sequence length is not a multiple of the tile size.
- Added support for the `COMPLEX_FP32` and `COMPLEX_FP64` data types. (Requires cuDNN v9.14.0 or later.)
- Updated samples to prioritize `fe::HeurMode_t::A` over `fe::HeurMode_t::FALLBACK`.
- Added support for a new parameter that enables negative scales in the Block Scale DeQuantize operation.
- Improved logging to clearly illustrate the different stages of graph creation.
- The `swish` function now accepts a `swish_beta` parameter.
- Added samples demonstrating how to perform sink attention forward and backward propagation with the C++ API. (Requires cuDNN v9.13.0 or later.)
- Added samples demonstrating "Block Scale Matmul Quantize". (Requires cuDNN v9.14.0 or later.)
- Added a sample demonstrating how ragged (packed) tensors work with cuDNN SDPA ([test_sdpa_with_caching.py](test/python/test_sdpa_with_caching.py)). The sample also demonstrates simple caching and graph-capture techniques that can improve execution time.
- Fixed an issue where the SDPA node was accessing tensor dimensions before they were inferred, leading to a crash.
- Updated benchmark results with cuDNN 9.13.1 for B200 and GB300.
- [https://github.com/NVIDIA/cudnn-frontend/issues/160](https://github.com/NVIDIA/cudnn-frontend/issues/160)
- [https://github.com/NVIDIA/cudnn-frontend/issues/152](https://github.com/NVIDIA/cudnn-frontend/issues/152)
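A minimal usage sketch of the new `warmup` method, continuing the `cudnn.Graph` sample above. The release note names the method but not its signature, so the no-argument call shape below is an assumption:

```python
# Assumption: warmup() takes no required arguments. Per the release note, it
# performs a fake graph capture so kernels are loaded before the first real
# launch (avoiding first-run latency and deadlocks with modules like NCCL).
graph.warmup()

# The first real execution now avoids kernel-loading overhead.
c_gpu = graph(a_gpu, b_gpu, d_gpu, handle=handle)
```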
1 parent 80a8e4a commit bb37575

File tree: 108 files changed, +11846 −666 lines


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.23)
 
-project(cudnn_frontend VERSION 1.14.1)
+project(cudnn_frontend VERSION 1.15.0)
 
 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
```

benchmark/sdpa_benchmark_training/Dockerfile

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:25.06-py3
+FROM nvcr.io/nvidia/pytorch:25.09-py3
 
 RUN pip install --upgrade pip && \
     pip install seaborn
```

benchmark/sdpa_benchmark_training/README.md

Lines changed: 31 additions & 19 deletions
````diff
@@ -10,22 +10,22 @@ The provided benchmark targets training use cases--causal masking is enabled for
 - `Dockerfile` to create a Docker container for the dependencies and run the benchmark.
 - `benchmark_bf16_sdpa.py` which runs cudnn, pytorch, and other backends up to 128k sequence length.
 - `benchmark_fp8_sdpa.py` which runs cudnn on fp8 along with bf16 up to 128k sequence length.
-- Sample benchmark output and results on B200 in the `artifacts` directory.
+- Sample benchmark output and results on B200 and GB300 in the `artifacts` directory.
 - Useful Python scripts for running single attention layers:
   - `benchmark_single_sdpa.py` for benchmarking a single flash attention instance from various backends.
   - See below for usage example.
 
 ## Software versions
 
-This benchmark code should run on any decently modern Python environment with CUDA-enabled GPU. The results in `artifacts` were collected using the PyTorch docker image [from the NVIDIA GPU CLOUD (NGC) catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), `nvcr.io/nvidia/pytorch:25.06-py3`, where cuDNN 9.10.2 was used. We provide a `Dockerfile` to reproduce the environment with the following library versions
+This benchmark code should run on any decently modern Python environment with CUDA-enabled GPU. The results in `artifacts` were collected using the PyTorch docker image [from the NVIDIA GPU CLOUD (NGC) catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), `nvcr.io/nvidia/pytorch:25.09-py3`, where cuDNN 9.13.1 was used. We provide a `Dockerfile` to reproduce the environment with the following library versions
 
 
 | Software       | Version |
 |----------------|---------|
-| Python         | 3.12.9  |
-| CUDA           | 12.9.0  |
-| cuDNN          | 9.10.2  |
-| PyTorch        | 2.8.0   |
+| Python         | 3.12.3  |
+| CUDA           | 13.0.0  |
+| cuDNN          | 9.13.1  |
+| PyTorch        | 2.9.0   |
 | FlashAttention | 2.7.4   |
 
 
@@ -60,13 +60,13 @@ Please note that FlashAttention-3 is currently not supported on NVIDIA's Blackwe
 Sample outputs:
 ```
 $ python3 benchmark_bf16_sdpa.py
-[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06'
-[INFO] torch.version.cuda = '12.9'
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
 [INFO] torch.cuda.is_available() = True
-[INFO] torch.cuda.device_count() = 1
+[INFO] torch.cuda.device_count() = 8
 [INFO] torch.cuda.current_device() = 0
 [INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
-[INFO] torch.backends.cudnn.version() = 91002
+[INFO] torch.backends.cudnn.version() = 91300
 [INFO] torch.backends.cudnn.enabled = True
 [INFO] flash_attn.__version__ = '2.7.4.post1'
 [INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
@@ -79,12 +79,12 @@ $ python3 benchmark_bf16_sdpa.py
 
 ```
 $ python3 benchmark_sdpa_fp8.py
-[INFO] cuDNN Backend Version: cudnn.backend_version() = 91002
-[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.12.0'
-[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06'
-[INFO] torch.version.cuda = '12.9'
+[INFO] cuDNN Backend Version: cudnn.backend_version() = 91301
+[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.14.1'
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
 [INFO] torch.cuda.is_available() = True
-[INFO] torch.cuda.device_count() = 1
+[INFO] torch.cuda.device_count() = 8
 [INFO] torch.cuda.current_device() = 0
 [INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
 [INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
@@ -100,13 +100,13 @@ $ python3 benchmark_sdpa_fp8.py
 Benchmarked performance numbers are stored in the [artifacts](artifacts) directory as csv and png files.
 
 ## Results
-Below are the result of the benchmark running on a single B200 GPU.
+Below are the result of the benchmark running on a single B200 GPU and a single GB300 GPU.
 
 For both runs, the following software versions are used:
 
-- CUDA: 12.9 (from NGC container)
-- PyTorch: 2.8.0 (from NGC container)
-- cuDNN: 9.10.2 (Installed via `apt-get`; see `Dockerfile`)
+- CUDA: 13.0 (from NGC container)
+- PyTorch: 2.9.0 (from NGC container)
+- cuDNN: 9.13.1 (Installed via `apt-get`; see `Dockerfile`)
 
 
 ### B200 - BF16 Performance Comparison between Backends
@@ -121,6 +121,18 @@ For both runs, the following software versions are used:
 - Sequence lengths are shown in the x-axis.
 - Results were obtained on an NVIDIA B200 GPU with free clock.
 
+### GB300 - BF16 Performance Comparison between Backends
+![Comparison of pytorch and cudnn](artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png)
+- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`.
+- Sequence lengths are shown in the x-axis.
+- Results were obtained on an NVIDIA GB300 GPU with free clock.
+
+### GB300 - cuDNN's FP8 Performance Relative to BF16
+![Comparison of pytorch and cudnn](artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png)
+- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`.
+- Sequence lengths are shown in the x-axis.
+- Results were obtained on an NVIDIA GB300 GPU with free clock.
+
 ## Pytorch adoption
 As demonstrated can be seen from the results, cuDNN v9 can achieve over 2x the performance of the comparable PyTorch eager implementation. Refer to [PyTorch's scaled_dot_product_attention()](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and [sdpa_kernel](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) context manager documentations for enabling the cuDNN backend for scaled dot product attention.
````
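As a companion to the README's closing pointer, here is a minimal sketch (not part of the diff) of steering PyTorch's `scaled_dot_product_attention()` to the cuDNN backend with the `sdpa_kernel` context manager; the tensor shapes are chosen to mirror the benchmark's GQA configuration:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Shapes follow the benchmark layers: batch=1, num_q_heads=128,
# num_kv_heads=8, seqlen=4096, head_dim=128, layout (B, H, S, D).
q = torch.randn(1, 128, 4096, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)

# Restrict SDPA dispatch to the cuDNN fused-attention backend;
# enable_gqa lets the 128 query heads share the 8 KV heads.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
```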

benchmark/sdpa_benchmark_training/artifacts/sample_b200_bf16_run.txt

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,10 +1,10 @@
-[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06'
-[INFO] torch.version.cuda = '12.9'
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
 [INFO] torch.cuda.is_available() = True
-[INFO] torch.cuda.device_count() = 1
+[INFO] torch.cuda.device_count() = 8
 [INFO] torch.cuda.current_device() = 0
 [INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
-[INFO] torch.backends.cudnn.version() = 91002
+[INFO] torch.backends.cudnn.version() = 91300
 [INFO] torch.backends.cudnn.enabled = True
 [INFO] flash_attn.__version__ = '2.7.4.post1'
 [INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
```
Lines changed: 23 additions & 23 deletions

```diff
@@ -1,39 +1,39 @@
-[INFO] cuDNN Backend Version: cudnn.backend_version() = 91002
-[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.12.0'
-[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06'
-[INFO] torch.version.cuda = '12.9'
+[INFO] cuDNN Backend Version: cudnn.backend_version() = 91301
+[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.14.1'
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
 [INFO] torch.cuda.is_available() = True
-[INFO] torch.cuda.device_count() = 1
+[INFO] torch.cuda.device_count() = 8
 [INFO] torch.cuda.current_device() = 0
 [INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200'
 [INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
 [INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)]
 [INFO] Running layer (1, 512, 512, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 1024, 1024, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 2048, 2048, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 4096, 4096, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 8192, 8192, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 16384, 16384, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 32768, 32768, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 65536, 65536, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Running layer (1, 131072, 131072, 128, 8, 128)
-[INFO] Benchmarking backend fp8
-[INFO] Benchmarking backend bf16
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
 [INFO] Saving results to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_B200.csv
 [INFO] Saving plot to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_B200.png
```
Lines changed: 58 additions & 0 deletions

```diff
@@ -0,0 +1,58 @@
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
+[INFO] torch.cuda.is_available() = True
+[INFO] torch.cuda.device_count() = 4
+[INFO] torch.cuda.current_device() = 0
+[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB300'
+[INFO] torch.backends.cudnn.version() = 91300
+[INFO] torch.backends.cudnn.enabled = True
+[INFO] flash_attn.__version__ = '2.7.4.post1'
+[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
+[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)]
+[INFO] Running layer (1, 512, 512, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 1024, 1024, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 2048, 2048, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 4096, 4096, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 8192, 8192, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 16384, 16384, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 32768, 32768, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 65536, 65536, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Running layer (1, 131072, 131072, 128, 8, 128)
+[INFO] Benchmarking backend pyt_math
+[INFO] Benchmarking backend pyt_cudnn
+[INFO] Benchmarking backend pyt_flash_attention
+[INFO] Benchmarking backend flash_attention
+[INFO] Saving results to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.csv
+[INFO] Saving plot to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png
```
Lines changed: 39 additions & 0 deletions

```diff
@@ -0,0 +1,39 @@
+[INFO] cuDNN Backend Version: cudnn.backend_version() = 91301
+[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.14.1'
+[INFO] torch.__version__ = '2.9.0a0+50eac811a6.nv25.09'
+[INFO] torch.version.cuda = '13.0'
+[INFO] torch.cuda.is_available() = True
+[INFO] torch.cuda.device_count() = 4
+[INFO] torch.cuda.current_device() = 0
+[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB300'
+[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)
+[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)]
+[INFO] Running layer (1, 512, 512, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 1024, 1024, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 2048, 2048, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 4096, 4096, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 8192, 8192, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 16384, 16384, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 32768, 32768, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 65536, 65536, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Running layer (1, 131072, 131072, 128, 8, 128)
+[INFO] Benchmarking data type fp8
+[INFO] Benchmarking data type bf16
+[INFO] Saving results to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.csv
+[INFO] Saving plot to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png
```
