Commit 25855fa

Nick Riasanovsky authored and facebook-github-bot committed
Fix Import Error for Flash Attention V2 + Disable triton_tutorial_flash_v2 memory fault shape
Summary: Fixes an import error that occurs when your environment doesn't have Flash Attention v2 but you try to run all of the operators. Also skips a shape that triggers a memory fault in triton_tutorial_flash_v2 on AMD ROCm.

Reviewed By: dhruvak3

Differential Revision: D71598757

fbshipit-source-id: 850e4ded958414f644ed2a4276484d7735adf818
1 parent: 7441efb

File tree

1 file changed: +10 −3 lines

tritonbench/operators/flash_attention/operator.py (+10 −3)
@@ -67,15 +67,16 @@
     )
 
     from .test_fmha_utils import make_packed_qkv
+
+    HAS_FLASH_V2 = True
 except (ImportError, IOError, AttributeError):
-    pass
+    HAS_FLASH_V2 = False
 
 HAS_CUDA_124 = (
     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
 )
 
 # [Optional] flash_attn v3
-HAS_FLASH_V3 = True
 try:
     torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
     with add_ld_library_path(torch_lib_path):
@@ -85,6 +86,8 @@
         from ai_codesign.gen_ai.flash_attention_v2.hopper.flash_attn_interface import (
             flash_attn_func as flash_attn_v3,
         )
+
+    HAS_FLASH_V3 = True
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V3 = False

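Both guards follow the same optional-dependency pattern: attempt the import, set a module-level flag only after it succeeds, and fall back to False in the except branch rather than letting the error propagate. A minimal, self-contained sketch of that pattern, assuming the upstream flash_attn package's top-level flash_attn_func export (operator.py's exact imports may differ):

import importlib

# Sketch of the optional-import guard this patch applies.
# `flash_attn` / `flash_attn_func` are the assumed upstream names.
HAS_FLASH_V2 = False
try:
    flash_attn_func = importlib.import_module("flash_attn").flash_attn_func

    HAS_FLASH_V2 = True  # only reached if the import succeeded
except (ImportError, IOError, AttributeError):
    # Missing or broken installation: leave the flag False so that
    # benchmarks depending on flash_attn v2 stay disabled.
    pass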
@@ -244,7 +247,7 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
-    @register_benchmark()
+    @register_benchmark(enabled=HAS_FLASH_V2)
     def flash_v2(
         self,
         q: torch.Tensor,
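Passing enabled=HAS_FLASH_V2 means the flash_v2 benchmark is simply skipped when flash_attn v2 is unavailable, instead of failing at import time. A hypothetical sketch of how such a flag can gate a registry decorator (illustrative only, not tritonbench's actual register_benchmark implementation):

BENCHMARKS = {}  # hypothetical registry: name -> callable

def register_benchmark(enabled: bool = True):
    def decorator(fn):
        if enabled:
            # Only benchmarks runnable in this environment are registered.
            BENCHMARKS[fn.__name__] = fn
        return fn  # always return fn so the class definition stays intact
    return decorator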
@@ -533,6 +536,10 @@ def get_ctx_vals():
         shapes = ctx_vals
         requires_grad = True
         for shape in shapes:
+            if torch.version.hip is not None and shape == (4, 32, 1, 128):
+                # AMD ROCm has an issue running triton_tutorial_flash_v2
+                # on shape (4, 32, 1, 128). Skip it for now.
+                continue
             BATCH, H, N_CTX, D_HEAD = shape
             q = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
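The skip keys off torch.version.hip, which is a version string on ROCm builds of PyTorch and None otherwise, so the problematic shape is only filtered out on AMD hardware. A small standalone check illustrating the same test (the skipped tuple is taken from the diff; the second shape is made up just to show the loop continuing):

import torch

IS_HIP = torch.version.hip is not None  # None on CUDA/CPU builds
SKIPPED_SHAPE = (4, 32, 1, 128)  # (BATCH, H, N_CTX, D_HEAD)

for shape in [(4, 32, 1, 128), (4, 32, 128, 128)]:
    if IS_HIP and shape == SKIPPED_SHAPE:
        continue  # known memory fault in triton_tutorial_flash_v2 on ROCm
    print("running shape", shape)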
