
Commit 2e831a7

clean up quick build flag
1 parent 242ee30

3 files changed: +3, -26 lines

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h

Lines changed: 0 additions & 9 deletions

@@ -145,15 +145,6 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea
 template <typename T>
 bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) {
 #ifdef ORT_QUICK_BUILD
-
-#if ORT_QUICK_BUILD == 1
-  // In quick build mode, only fp16 flash attention is built
-  constexpr bool is_bf16 = std::is_same<T, onnxruntime::BFloat16>::value;
-  if (is_bf16) {
-    return false;
-  }
-#endif
-
   if (head_size != 128) {
     return false;
   }
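
Note: with the ORT_QUICK_BUILD == 1 branch gone, the quick-build gate in is_supported reduces to the head-size check alone; bf16 is no longer rejected at this level. A minimal sketch of the resulting behavior (an illustrative stand-alone function, not the ORT source):

#include <cstddef>

// Illustrative sketch only: after this commit the quick-build gate in
// is_supported checks head size alone; the element type (fp16 vs. bf16)
// is no longer consulted.
bool quick_build_gate(std::size_t head_size) {
  // Quick build compiles only the hdim128 flash-attention kernels.
  if (head_size != 128) {
    return false;
  }
  return true;  // the real function performs further checks, omitted here
}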

onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h

Lines changed: 0 additions & 9 deletions

@@ -66,14 +66,6 @@
 #define LOCAL_SWITCH BOOL_SWITCH
 #endif
 
-#if ORT_QUICK_BUILD == 1
-// Quick build mode: only fp16 kernels are compiled
-#define FP16_SWITCH(COND, ...) \
-  [&] { \
-    using elem_type = cutlass::half_t; \
-    return __VA_ARGS__(); \
-  }()
-#else
 #define FP16_SWITCH(COND, ...) \
   [&] { \
     if (COND) { \
@@ -84,7 +76,6 @@
       return __VA_ARGS__(); \
     } \
   }()
-#endif
 
 #ifdef ORT_QUICK_BUILD
 // Quick build mode: only hdim128 kernels are compiled
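
The FP16_SWITCH that survives is the usual static-switch idiom from the flash-attention sources: an immediately invoked lambda that binds elem_type per branch at compile time and forwards to the callable passed in. Below is a self-contained sketch of that idiom; the plain structs stand in for the CUTLASS element types, and the bfloat16_t else branch is an assumption, since the hunk above does not show it.

#include <cstdio>

// Plain stand-ins for the CUTLASS element types; illustrative only.
struct half_t { static constexpr const char* name = "fp16"; };
struct bfloat16_t { static constexpr const char* name = "bf16"; };

// Same shape as the retained FP16_SWITCH: an immediately invoked lambda
// lets the macro yield a value, and each branch binds elem_type at
// compile time before invoking the callable pasted in via __VA_ARGS__.
#define FP16_SWITCH(COND, ...)      \
  [&] {                             \
    if (COND) {                     \
      using elem_type = half_t;     \
      return __VA_ARGS__();         \
    } else {                        \
      using elem_type = bfloat16_t; \
      return __VA_ARGS__();         \
    }                               \
  }()

int main() {
  bool is_fp16 = false;
  // The lambda body is expanded into both branches, so elem_type resolves
  // to whichever alias the taken branch declared; this prints "bf16".
  FP16_SWITCH(is_fp16, [&] {
    std::printf("dispatching %s kernels\n", elem_type::name);
    return 0;
  });
  return 0;
}

Taking the callable through __VA_ARGS__ rather than a named parameter lets call sites contain commas without extra parentheses, which is why the BOOL_SWITCH family of macros is variadic.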

onnxruntime/test/python/transformers/test_gqa.py

Lines changed: 3 additions & 8 deletions

@@ -52,9 +52,6 @@
 # When quick build is used, flash attention only supports head_size=128
 quick_build = ", quick-build=" in get_build_info()
 
-# When quick build mode is 1, bf16 is excluded
-quick_build_exclude_bf16 = ", quick-build=1, " in get_build_info()
-
 enable_debug_print = quick_build
 
 enable_deterministic_check = True
@@ -2057,9 +2054,7 @@ def has_cuda_device(min_capability: int = 80):
     return major * 10 + minor >= min_capability
 
 
-def has_flash_attention(bf16: bool = False):
-    if bf16 and quick_build_exclude_bf16:
-        return False
+def has_flash_attention():
     return has_cuda_device(80)
 
 
@@ -2151,7 +2146,7 @@ def test_gqa_quantized_prompt(self, name, config):
         )
 
 
-@unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.")
+@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.")
 class TestFlashGQABF16(unittest.TestCase):
     @parameterized.expand(gqa_cuda_prompt_test_cases())
     def test_gqa_prompt_flash_attention_bf16(self, name, config):
@@ -2199,7 +2194,7 @@ def test_gqa_past_flash_attention_bf16(self, name, config):
 
 
 @unittest.skipIf(
-    not has_flash_attention(bf16=True) or not enable_quantized_kv_tests,
+    not has_flash_attention() or not enable_quantized_kv_tests,
     "Flash Attention is not available, skipping tests.",
 )
 class TestFlashGQABF16QuantizedKV(unittest.TestCase):
