# LICENSE file in the root directory of this source tree.

"""
Benchmark two attention backends against each other for a single layer.

Sweeps over sequence lengths from 1K to 128K, measuring runtime and SQNR.

Usage:
    python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa3 --test fa3_fp8
"""
1416import argparse
2527)
2628
2729from torchao .prototype .attention .fp8_fa3 .attention import fp8_fa3_sdpa
30+ from torchao .prototype .attention .fp8_fa4 .attention import fp8_fa4_sdpa
2831from torchao .quantization .utils import compute_error as compute_sqnr
2932
# All attention backends this benchmark can compare against each other.
# The *_fp8 entries dispatch to the dedicated fp8 kernels in _run_attention;
# the rest go through torch SDPA (see _run_attention below).
BACKENDS = ["fa2", "fa3", "fa3_fp8", "fa4", "fa4_fp8"]

# Human-readable label for each backend, keyed by the BACKENDS entries.
BACKEND_LABELS = {
    "fa2": "FA2 BF16",
    "fa3": "FA3 BF16",
    "fa3_fp8": "FA3 FP8",
    "fa4": "FA4 BF16",
    "fa4_fp8": "FA4 FP8",
}
3742
3843
@@ -41,20 +46,24 @@ def _activate_backend(backend: str):
4146 """Context manager that activates the appropriate flash attention impl."""
4247 if backend in ("fa3" , "fa3_fp8" ):
4348 activate_flash_attention_impl ("FA3" )
49+ elif backend in ("fa4" , "fa4_fp8" ):
50+ activate_flash_attention_impl ("FA4" )
4451 else :
4552 # fa2 is the default, no activation needed
4653 pass
4754 try :
4855 yield
4956 finally :
50- if backend in ("fa3" , "fa3_fp8" ):
57+ if backend in ("fa3" , "fa3_fp8" , "fa4" , "fa4_fp8" ):
5158 restore_flash_attention_impl ()
5259
5360
5461def _run_attention (backend : str , q , k , v , is_causal : bool ):
5562 """Run a single attention call for the given backend."""
5663 if backend == "fa3_fp8" :
5764 return fp8_fa3_sdpa (q , k , v , is_causal = is_causal )
65+ elif backend == "fa4_fp8" :
66+ return fp8_fa4_sdpa (q , k , v , is_causal = is_causal )
5867 else :
5968 with sdpa_kernel (SDPBackend .FLASH_ATTENTION ):
6069 return F .scaled_dot_product_attention (q , k , v , is_causal = is_causal )
0 commit comments