67 | 67 |     )
68 | 68 |
69 | 69 |     from .test_fmha_utils import make_packed_qkv
   | 70 | +
   | 71 | +    HAS_FLASH_V2 = True
70 | 72 | except (ImportError, IOError, AttributeError):
71 |    | -    pass
   | 73 | +    HAS_FLASH_V2 = False
72 | 74 |
73 | 75 | HAS_CUDA_124 = (
74 | 76 |     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
75 | 77 | )
76 | 78 |
77 | 79 | # [Optional] flash_attn v3
78 |    | -HAS_FLASH_V3 = True
79 | 80 | try:
80 | 81 |     torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
81 | 82 |     with add_ld_library_path(torch_lib_path):
85 | 86 |         from ai_codesign.gen_ai.flash_attention_v2.hopper.flash_attn_interface import (
86 | 87 |             flash_attn_func as flash_attn_v3,
87 | 88 |         )
   | 89 | +
   | 90 | +    HAS_FLASH_V3 = True
88 | 91 | except (ImportError, IOError, AttributeError):
89 | 92 |     HAS_FLASH_V3 = False
90 | 93 |
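This hunk swaps the silent `pass` in the flash_attn v2 import guard for an explicit `HAS_FLASH_V2` flag, and moves `HAS_FLASH_V3 = True` from before the v3 try block to after its imports, so each flag only becomes True once every optional import has actually succeeded. Below is a minimal sketch of that guard pattern, assuming flash_attn's `flash_attn_qkvpacked_func` as the guarded import (the actual import statement sits just above this excerpt) and omitting the rest of the benchmark module:

    # Sketch of the optional-import guard; the exact imports in the real
    # module differ, only the flag placement matters here.
    try:
        from flash_attn.flash_attn_interface import (
            flash_attn_qkvpacked_func as flash_attn_func,
        )

        # Set the flag only after every import above succeeded, so a partial
        # failure still leaves it False instead of raising at module import.
        HAS_FLASH_V2 = True
    except (ImportError, IOError, AttributeError):
        HAS_FLASH_V2 = False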
@@ -244,7 +247,7 @@ def sdpa_flash_attention(q, k, v):
244 | 247 |             v,
245 | 248 |         )
246 | 249 |
247 |     | -    @register_benchmark()
    | 250 | +    @register_benchmark(enabled=HAS_FLASH_V2)
248 | 251 |     def flash_v2(
249 | 252 |         self,
250 | 253 |         q: torch.Tensor,
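`register_benchmark(enabled=HAS_FLASH_V2)` keeps flash_v2 out of the run when the optional import failed, rather than registering a benchmark that would only crash at call time. The sketch below is a hypothetical, stripped-down decorator showing how such an `enabled` flag can gate registration; it is not tritonbench's actual register_benchmark implementation, and the `_BENCHMARKS` dict is made up for illustration:

    # Hypothetical registry + decorator, for illustration only.
    _BENCHMARKS = {}

    def register_benchmark(enabled: bool = True):
        def decorator(func):
            if enabled:
                # Only expose the benchmark when its backend imported cleanly.
                _BENCHMARKS[func.__name__] = func
            return func
        return decorator

    HAS_FLASH_V2 = False  # e.g. flash_attn v2 is not installed

    @register_benchmark(enabled=HAS_FLASH_V2)
    def flash_v2(q, k, v):
        ...

    assert "flash_v2" not in _BENCHMARKS  # gated off instead of failing later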
@@ -533,6 +536,10 @@ def get_ctx_vals():
533 | 536 |         shapes = ctx_vals
534 | 537 |         requires_grad = True
535 | 538 |         for shape in shapes:
    | 539 | +            if torch.version.hip is not None and shape == (4, 32, 1, 128):
    | 540 | +                # AMD ROCm has an issue running triton_tutorial_flash_v2
    | 541 | +                # on shape (4, 32, 1, 128). Skip it for now.
    | 542 | +                continue
536 | 543 |             BATCH, H, N_CTX, D_HEAD = shape
537 | 544 |             q = torch.randn(
538 | 545 |                 (BATCH, H, N_CTX, D_HEAD),
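`torch.version.hip` is None on CUDA builds of PyTorch and a version string on ROCm builds, so the added guard skips the problematic (4, 32, 1, 128) input only on AMD hardware. A standalone sketch of the same filter, with a made-up shape list:

    import torch

    # Hypothetical shape list; (4, 32, 1, 128) is the case skipped in the hunk.
    shapes = [(4, 32, 1, 128), (8, 16, 512, 64)]
    for shape in shapes:
        if torch.version.hip is not None and shape == (4, 32, 1, 128):
            # Skipped only under ROCm; CUDA builds keep every shape.
            continue
        BATCH, H, N_CTX, D_HEAD = shape
        print(BATCH, H, N_CTX, D_HEAD)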