Commit 541d5da
[CUDA] Speed up flash attention build (#26924)
## Summary

This pull request significantly reduces the build time for Flash Attention by removing support for the less common head dimensions 160 and 224. It also adds a quick-build option, `--cmake_extra_defines onnxruntime_QUICK_BUILD=ON`, which builds the flash attention kernels only for float16 and head dimension 128 to speed up development iteration.

## Key Changes

### 1. Flash Attention Build Optimization
- **Removed Head Dimensions:** Deleted source files and kernel instantiations for head dimensions **160** and **224** (both FP16 and BF16). These dimensions are rarely used, and removing them reduces the number of kernels to compile, thereby speeding up the build.
- **Updated Dispatch Logic:** Modified `static_switch.h` and `flash_api.h` to remove the dispatch cases for `kHeadDim = 160` and `kHeadDim = 224`.

### 2. Test Enhancements
- **GQA Tests:** Updated `onnxruntime/test/python/transformers/test_gqa.py` to detect whether the package is a quick build. If it is, only the supported data type (float16) and head dimension (128) are tested for flash attention. The tests also use `has_flash_attention(bf16=True)` when checking for Flash Attention availability in BF16 tests, so they are skipped appropriately if the BF16 kernels are not compiled or available.

## Impact
- **Build Time:** Faster compilation of the CUDA provider due to fewer Flash Attention kernels.
- **Functionality:** Head dimensions 160 and 224 are no longer supported for Flash Attention. Models using these head dimensions fall back to the next supported head dimension, such as 192 or 256.

## Verification
- Validated that the build completes successfully with the reduced kernel set.
- `test_gqa.py` passes or skips correctly based on hardware support.
- Built the onnxruntime-gpu package with the `--cmake_extra_defines onnxruntime_QUICK_BUILD=ON` option and confirmed that the build info contains "quick-build=1", for example with the following Python script:

```python
import onnxruntime

print(onnxruntime.get_build_info())
```

The output looks like:

```
ORT Build Info: git-branch=main, git-commit-id=ecf164a945, quick-build=1, build type=Release
```
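As a hedged sketch of the detection described in the test changes above (the helper names here are hypothetical and not the actual `test_gqa.py` code), a test can recognize a quick-build package by looking for the `quick-build=1` marker in the build info and restrict the flash attention configurations it exercises:

```python
# Hypothetical sketch, not the actual test_gqa.py code: detect a quick-build
# package from the build info string and decide which flash attention
# configurations to exercise.
import onnxruntime


def is_quick_build() -> bool:
    # The CMake change appends "quick-build=1, " to ORT_BUILD_INFO, so the
    # marker shows up in get_build_info() for quick-build packages.
    return "quick-build=1" in onnxruntime.get_build_info()


def flash_config_supported(dtype: str, head_dim: int) -> bool:
    # In quick build mode, only fp16 kernels with head dimension 128 are compiled.
    if is_quick_build():
        return dtype == "float16" and head_dim == 128
    return True


if __name__ == "__main__":
    print("quick build:", is_quick_build())
    print("bfloat16 / head_dim=64 supported:", flash_config_supported("bfloat16", 64))
```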
1 parent 751af64 commit 541d5da

19 files changed, +89 −170 lines changed

cmake/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -102,6 +102,7 @@ cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention ke
 option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
 cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
+option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)
 
 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
 option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -789,6 +790,11 @@ if (onnxruntime_USE_CUDA)
     message( STATUS "Enable FpA IntB Gemm for CUDA EP")
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_FPA_INTB_GEMM=1)
   endif()
+
+  if (onnxruntime_QUICK_BUILD)
+    message( STATUS "Quick build mode: Flash attention limited to fp16 only")
+    list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
+  endif()
 endif()
 
 if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))
@@ -1442,6 +1448,9 @@ if (Git_FOUND)
                   OUTPUT_VARIABLE ORT_GIT_BRANCH)
   string(STRIP "${ORT_GIT_BRANCH}" ORT_GIT_BRANCH)
   string(APPEND ORT_BUILD_INFO "git-branch=${ORT_GIT_BRANCH}, git-commit-id=${ORT_GIT_COMMIT}, ")
+  if (onnxruntime_QUICK_BUILD)
+    string(APPEND ORT_BUILD_INFO "quick-build=1, ")
+  endif()
 endif()
 string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
 configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,16 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh"
 )
 
+# Quick build mode: Filter out non-hdim128 flash attention kernels for faster development iteration
+if(onnxruntime_QUICK_BUILD)
+  message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
+  # Filter non-hdim128 kernels
+  list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|160|192|224|256)")
+  # Filter all bfloat16 kernels (only keep fp16)
+  list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
+endif()
+
+
 
 file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"

onnxruntime/contrib_ops/cuda/bert/attention.cc

Lines changed: 4 additions & 4 deletions
@@ -116,10 +116,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
       nullptr == present &&
       parameters.hidden_size == parameters.v_hidden_size &&
       nullptr == mask_index &&
-      onnxruntime::flash::is_supported(device_prop,
-                                       parameters.head_size,
-                                       parameters.num_heads,
-                                       parameters.num_heads);
+      onnxruntime::flash::is_supported<T>(device_prop,
+                                          parameters.head_size,
+                                          parameters.num_heads,
+                                          parameters.num_heads);
   // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
   if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
     use_flash_attention = false;

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h

Lines changed: 22 additions & 2 deletions
@@ -32,6 +32,9 @@
 
 #include "core/providers/cuda/cuda_common.h"
 #include <tuple>
+#include <type_traits>
+#include <cutlass/numeric_types.h>
+#include <cuda_bf16.h>
 
 namespace onnxruntime {
 namespace flash {
@@ -89,8 +92,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
 Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        cudaStream_t stream,
                        void* q,       // batch_size x seqlen_q x num_heads x head_size
-                       void* kcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
-                       void* vcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
+                       void* kcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
+                       void* vcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
                        void* k,       // batch_size x seqlen_k_new x num_heads_k x head_size
                        void* v,       // batch_size x seqlen_k_new x num_heads_k x head_size
                        void* out,     // batch_size x seqlen_q x num_heads x head_size
@@ -131,6 +134,23 @@ std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_
 
 bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k);
 
+// Template version that checks for bf16 type in quick build mode
+template <typename T>
+bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) {
+#ifdef ORT_QUICK_BUILD
+  // In quick build mode, only fp16 flash attention is built
+  constexpr bool is_bf16 = std::is_same<T, onnxruntime::BFloat16>::value;
+  if (is_bf16) {
+    return false;
+  }
+
+  if (head_size != 128) {
+    return false;
+  }
+#endif
+  return is_supported(dprops, head_size, num_heads, num_heads_k);
+}
+
 }  // namespace flash
 }  // namespace onnxruntime
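For Python test authors, here is a rough Python mirror of the quick-build branch in `is_supported<T>` above (illustrative only, not an onnxruntime API): bf16 and any head size other than 128 are rejected before the regular device check runs.

```python
# Illustrative mirror of the ORT_QUICK_BUILD branch in is_supported<T>;
# regular_check stands in for the original is_supported() device/head-size check.
def flash_attention_supported(dtype: str, head_size: int,
                              quick_build: bool, regular_check: bool) -> bool:
    if quick_build:
        if dtype == "bfloat16":
            return False  # only fp16 kernels are compiled in quick build mode
        if head_size != 128:
            return False  # only hdim128 kernels are compiled in quick build mode
    return regular_check


print(flash_attention_supported("bfloat16", 128, quick_build=True, regular_check=True))  # False
print(flash_attention_supported("float16", 128, quick_build=True, regular_check=True))   # True
print(flash_attention_supported("float16", 64, quick_build=False, regular_check=True))   # True
```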

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu

Lines changed: 0 additions & 18 deletions
This file was deleted.

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu

Lines changed: 0 additions & 18 deletions
This file was deleted.

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu

Lines changed: 0 additions & 18 deletions
This file was deleted.

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu

Lines changed: 0 additions & 18 deletions
This file was deleted.

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h

Lines changed: 0 additions & 5 deletions
@@ -70,12 +70,9 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
   // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
   // If Is_local, set Is_causal to false
   auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, false > ;
-  // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
   if (smem_size >= 48 * 1024) {
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_size));
-    // ORT_ENFORCE(cudaFuncSetAttribute(
-    //     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   }
   // int ctas_per_sm;
   // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
@@ -112,8 +109,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
   auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, Is_Local_Const && !Is_causal, Has_alibi,
       IsEvenMNConst && !Append_KV_Const && IsEvenKConst && !Is_Local_Const && Kernel_traits::kHeadDim <= 128,
       IsEvenKConst, Is_softcap, SplitConst, Append_KV_Const >;
-  // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV_Const>;
-  // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
   if (smem_size >= 48 * 1024) {
     cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_size));

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu

Lines changed: 0 additions & 15 deletions
This file was deleted.
