Commit 541d5da
authored
[CUDA] Speed up flash attention build (#26924)
## Summary
This pull request aims to significantly reduce the build time for Flash
Attention by removing support for less common head dimensions (160 and
224).
It also includes a build option for quick build `--cmake_extra_defines
onnxruntime_QUICK_BUILD=ON`, which will only build flash attention
kernel for float16 and head dimension=128. That could speed up
development.
## Key Changes
### 1. Flash Attention Build Optimization
- **Removed Head Dimensions:** Deleted source files and kernel
instantiations for head dimensions **160** and **224** (both FP16 and
BF16). These dimensions are less frequently used, and removing them
reduces the number of kernels to be compiled, thereby speeding up the
build process.
- **Updated Dispatch Logic:** Modified `static_switch.h` and
`flash_api.h` to remove the dispatch cases for `kHeadDim = 160` and
`kHeadDim = 224`.
### 2. Test Enhancements
- **GQA Tests:** Updated
`onnxruntime/test/python/transformers/test_gqa.py` to detect whether it
is quick build package. If it is, only test supported data type
(float16) and head dimension (128 only) for flash attention, and use
`has_flash_attention(bf16=True)` when checking for Flash Attention
availability in BF16 tests. This ensures that tests are skipped
appropriately if BF16 kernels are not compiled/available.
## Impact
- **Build Time:** Faster compilation of the CUDA provider due to fewer
Flash Attention kernels.
- **Functionality:** Head dimensions 160 and 224 are no longer supported
for Flash Attention. Models using these specific head dimensions will
fall back to next supported head dimension like 192 or 256.
## Verification
- Validated that the build completes successfully with the reduced
kernel set.
- `test_gqa.py` should pass or skip correctly based on hardware support.
- Build onnxruntime-gpu package with `--cmake_extra_defines
onnxruntime_QUICK_BUILD=ON` option, and the build info has
"quick-build=1", like the following python script:
```python
import onnxruntime
print(onnxruntime.get_build_info())
```
The output is like
```
ORT Build Info: git-branch=main, git-commit-id=ecf164a945, quick-build=1, build type=Release
```1 parent 751af64 commit 541d5da
File tree
19 files changed
+89
-170
lines changed- cmake
- onnxruntime
- contrib_ops/cuda/bert
- flash_attention
- test/python/transformers
19 files changed
+89
-170
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
102 | 102 | | |
103 | 103 | | |
104 | 104 | | |
| 105 | + | |
105 | 106 | | |
106 | 107 | | |
107 | 108 | | |
| |||
789 | 790 | | |
790 | 791 | | |
791 | 792 | | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
792 | 798 | | |
793 | 799 | | |
794 | 800 | | |
| |||
1442 | 1448 | | |
1443 | 1449 | | |
1444 | 1450 | | |
| 1451 | + | |
| 1452 | + | |
| 1453 | + | |
1445 | 1454 | | |
1446 | 1455 | | |
1447 | 1456 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
28 | 38 | | |
29 | 39 | | |
30 | 40 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
116 | 116 | | |
117 | 117 | | |
118 | 118 | | |
119 | | - | |
120 | | - | |
121 | | - | |
122 | | - | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
123 | 123 | | |
124 | 124 | | |
125 | 125 | | |
| |||
Lines changed: 22 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
35 | 38 | | |
36 | 39 | | |
37 | 40 | | |
| |||
89 | 92 | | |
90 | 93 | | |
91 | 94 | | |
92 | | - | |
93 | | - | |
| 95 | + | |
| 96 | + | |
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
| |||
131 | 134 | | |
132 | 135 | | |
133 | 136 | | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
134 | 154 | | |
135 | 155 | | |
136 | 156 | | |
| |||
Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 0 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
70 | 70 | | |
71 | 71 | | |
72 | 72 | | |
73 | | - | |
74 | 73 | | |
75 | 74 | | |
76 | 75 | | |
77 | | - | |
78 | | - | |
79 | 76 | | |
80 | 77 | | |
81 | 78 | | |
| |||
112 | 109 | | |
113 | 110 | | |
114 | 111 | | |
115 | | - | |
116 | | - | |
117 | 112 | | |
118 | 113 | | |
119 | 114 | | |
| |||
Lines changed: 0 additions & 15 deletions
This file was deleted.
0 commit comments