Skip to content

Commit 484148c

Browse files
Add FA4 RoPE fusion path for low-precision attention
ghstack-source-id: 56eda45 Pull-Request: #3947
1 parent d2cbad4 commit 484148c

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

benchmarks/prototype/attention/benchmark_sdpa.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Benchmark two attention backends against each other for a single layer,
-sweeping sequence lengths and measuring runtime and SQNR.
+Benchmark two attention backends against each other for a single layer.
 
-Usage: python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa2 --test fa3_fp8
+Sweeps over sequence lengths from 1K to 128K, measuring runtime and SQNR.
+
+Usage:
+    python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa3 --test fa3_fp8
 """
 
 import argparse
@@ -25,14 +27,17 @@
 )
 
 from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
+from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
 from torchao.quantization.utils import compute_error as compute_sqnr
 
-BACKENDS = ["fa2", "fa3", "fa3_fp8"]
+BACKENDS = ["fa2", "fa3", "fa3_fp8", "fa4", "fa4_fp8"]
 
 BACKEND_LABELS = {
     "fa2": "FA2 BF16",
     "fa3": "FA3 BF16",
     "fa3_fp8": "FA3 FP8",
+    "fa4": "FA4 BF16",
+    "fa4_fp8": "FA4 FP8",
 }
 
 
@@ -41,20 +46,24 @@ def _activate_backend(backend: str):
     """Context manager that activates the appropriate flash attention impl."""
     if backend in ("fa3", "fa3_fp8"):
         activate_flash_attention_impl("FA3")
+    elif backend in ("fa4", "fa4_fp8"):
+        activate_flash_attention_impl("FA4")
     else:
         # fa2 is the default, no activation needed
         pass
     try:
         yield
     finally:
-        if backend in ("fa3", "fa3_fp8"):
+        if backend in ("fa3", "fa3_fp8", "fa4", "fa4_fp8"):
             restore_flash_attention_impl()
 
 
 def _run_attention(backend: str, q, k, v, is_causal: bool):
     """Run a single attention call for the given backend."""
     if backend == "fa3_fp8":
         return fp8_fa3_sdpa(q, k, v, is_causal=is_causal)
+    elif backend == "fa4_fp8":
+        return fp8_fa4_sdpa(q, k, v, is_causal=is_causal)
     else:
         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
             return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

0 commit comments

Comments (0)