Skip to content

Commit 28add5e

Browse files
Add FA4 RoPE fusion path for low-precision attention

Adds the compile path (fuse_rope=True) for the FA4 backend, mirroring the FA3 fusion pass structure via the shared custom op and fusion pass factories.

Key additions:
- fp8_fa4/fusion_pass.py: FA4-specific custom ops and compile helper
- fp8_fa4_rope_sdpa entry point in attention.py
- Replace placeholder compile_fn with real fusion pass in setup.py
- Wire up FA4 rope_sdpa_fn in test backend config

ghstack-source-id: 2e69d43
Pull-Request: #3947
1 parent 7c180fe commit 28add5e

File tree

9 files changed

+191
-17
lines changed

9 files changed

+191
-17
lines changed

benchmarks/prototype/attention/benchmark_sdpa.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
fa2 - BF16 SDPA with FlashAttention 2 (PyTorch default)
1515
fa3 - BF16 SDPA with FlashAttention 3
1616
fa3_fp8 - FP8 SDPA with FlashAttention 3 (includes quantization kernels)
17+
fa4 - BF16 SDPA with FlashAttention 4
18+
fa4_fp8 - FP8 SDPA with FlashAttention 4 (includes quantization kernels)
1719
1820
Usage:
1921
# Default: FA2 vs FA3+FP8
@@ -22,8 +24,11 @@
2224
# FA3 bf16 vs FA3 fp8
2325
python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa3 --test fa3_fp8
2426
27+
# FA2 vs FA4
28+
python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa2 --test fa4
29+
2530
# With causal masking
26-
python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa3 --test fa3_fp8 --causal
31+
python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa3 --test fa4 --causal
2732
"""
2833

2934
import argparse
@@ -40,13 +45,16 @@
4045
)
4146

4247
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
48+
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
4349

44-
BACKENDS = ["fa2", "fa3", "fa3_fp8"]
50+
BACKENDS = ["fa2", "fa3", "fa3_fp8", "fa4", "fa4_fp8"]
4551

4652
BACKEND_LABELS = {
4753
"fa2": "FA2 BF16",
4854
"fa3": "FA3 BF16",
4955
"fa3_fp8": "FA3 FP8",
56+
"fa4": "FA4 BF16",
57+
"fa4_fp8": "FA4 FP8",
5058
}
5159

5260

@@ -55,20 +63,24 @@ def _activate_backend(backend: str):
5563
"""Context manager that activates the appropriate flash attention impl."""
5664
if backend in ("fa3", "fa3_fp8"):
5765
activate_flash_attention_impl("FA3")
66+
elif backend in ("fa4", "fa4_fp8"):
67+
activate_flash_attention_impl("FA4")
5868
else:
5969
# fa2 is the default, no activation needed
6070
pass
6171
try:
6272
yield
6373
finally:
64-
if backend in ("fa3", "fa3_fp8"):
74+
if backend in ("fa3", "fa3_fp8", "fa4", "fa4_fp8"):
6575
restore_flash_attention_impl()
6676

6777

6878
def _run_attention(backend: str, q, k, v, is_causal: bool):
6979
"""Run a single attention call for the given backend."""
7080
if backend == "fa3_fp8":
7181
return fp8_fa3_sdpa(q, k, v, is_causal=is_causal)
82+
elif backend == "fa4_fp8":
83+
return fp8_fa4_sdpa(q, k, v, is_causal=is_causal)
7284
else:
7385
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
7486
return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

benchmarks/prototype/attention/eval_flux_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
fa2 - Flash Attention 2 (default SDPA)
1515
fa3 - Flash Attention 3
1616
fa3_fp8 - Flash Attention 3 with FP8 quantization (fused RoPE + FP8 SDPA)
17+
fa4 - Flash Attention 4
18+
fa4_fp8 - Flash Attention 4 with FP8 quantization (fused RoPE + FP8 SDPA)
1719
1820
Usage:
1921
# Compare FA3 vs FA3 FP8 (default)
@@ -22,6 +24,9 @@
2224
# Compare FA2 vs FA3
2325
python eval_flux_model.py --baseline fa2 --test fa3
2426
27+
# Compare FA3 vs FA4
28+
python eval_flux_model.py --baseline fa3 --test fa4
29+
2530
# Full benchmark with 200 prompts
2631
python eval_flux_model.py --num_prompts 200
2732
@@ -64,6 +69,12 @@
6469
"fp8": True,
6570
"fp8_backend": AttentionBackend.FP8_FA3,
6671
},
72+
"fa4": {"flash_impl": "FA4", "fp8": False},
73+
"fa4_fp8": {
74+
"flash_impl": "FA4",
75+
"fp8": True,
76+
"fp8_backend": AttentionBackend.FP8_FA4,
77+
},
6778
}
6879

6980
IMAGE_SIZE = (512, 512) # (width, height) - resize for consistent LPIPS

benchmarks/prototype/attention/eval_llama3_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
fa2 - Flash Attention 2 (default SDPA)
1818
fa3 - Flash Attention 3
1919
fa3_fp8 - Flash Attention 3 with FP8 quantization (fused RoPE + FP8 SDPA)
20+
fa4 - Flash Attention 4
21+
fa4_fp8 - Flash Attention 4 with FP8 quantization (fused RoPE + FP8 SDPA)
2022
2123
Usage:
2224
# Default: FA3 vs FA3 FP8
@@ -25,6 +27,9 @@
2527
# FA2 vs FA3
2628
python eval_llama3_model.py --baseline fa2 --test fa3
2729
30+
# FA3 vs FA4
31+
python eval_llama3_model.py --baseline fa3 --test fa4
32+
2833
# With torch.compile (applies to non-FP8 backends)
2934
python eval_llama3_model.py --compile
3035
"""
@@ -77,6 +82,17 @@
7782
"fp8_backend": AttentionBackend.FP8_FA3,
7883
"label": "FA3 FP8",
7984
},
85+
"fa4": {
86+
"flash_impl": "FA4",
87+
"fp8": False,
88+
"label": "FA4 BF16",
89+
},
90+
"fa4_fp8": {
91+
"flash_impl": "FA4",
92+
"fp8": True,
93+
"fp8_backend": AttentionBackend.FP8_FA4,
94+
"label": "FA4 FP8",
95+
},
8096
}
8197

8298
RANDOM_SEED = 42
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD 3-Clause license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# Run all low-precision attention benchmarks (FA4 baseline vs FA4 FP8 test).
9+
# Usage: bash benchmarks/prototype/attention/run_all_benchmarks_fa4.sh
10+
11+
set -euo pipefail
12+
13+
BENCH_DIR="benchmarks/prototype/attention"
14+
BASELINE="fa4"
15+
TEST="fa4_fp8"
16+
17+
echo "================================================================"
18+
echo " Low-Precision Attention Benchmarks ($BASELINE vs $TEST)"
19+
echo "================================================================"
20+
21+
# --------------------------------------------------------------------------
22+
# 1. Single attention layer benchmark
23+
# --------------------------------------------------------------------------
24+
echo ""
25+
echo "================================================================"
26+
echo " [1/9] benchmark_sdpa.py — Single Attention Layer"
27+
echo "================================================================"
28+
python "$BENCH_DIR/benchmark_sdpa.py" --baseline "$BASELINE" --test "$TEST"
29+
30+
# --------------------------------------------------------------------------
31+
# 2. LLaMA 3 model benchmarks (4 configurations)
32+
# --------------------------------------------------------------------------
33+
echo ""
34+
echo "================================================================"
35+
echo " [2/9] eval_llama3_model.py — No compile, no fuse_rope_using_torch_compile"
36+
echo "================================================================"
37+
python "$BENCH_DIR/eval_llama3_model.py" --baseline "$BASELINE" --test "$TEST"
38+
39+
echo ""
40+
echo "================================================================"
41+
echo " [3/9] eval_llama3_model.py — Compile, no fuse_rope_using_torch_compile"
42+
echo "================================================================"
43+
python "$BENCH_DIR/eval_llama3_model.py" --baseline "$BASELINE" --test "$TEST" --compile
44+
45+
echo ""
46+
echo "================================================================"
47+
echo " [4/9] eval_llama3_model.py — No compile, fuse_rope_using_torch_compile"
48+
echo "================================================================"
49+
python "$BENCH_DIR/eval_llama3_model.py" --baseline "$BASELINE" --test "$TEST" --fuse_rope_using_torch_compile
50+
51+
echo ""
52+
echo "================================================================"
53+
echo " [5/9] eval_llama3_model.py — Compile, fuse_rope_using_torch_compile"
54+
echo "================================================================"
55+
python "$BENCH_DIR/eval_llama3_model.py" --baseline "$BASELINE" --test "$TEST" --compile --fuse_rope_using_torch_compile
56+
57+
# --------------------------------------------------------------------------
58+
# 3. FLUX model benchmarks (4 configurations)
59+
# --------------------------------------------------------------------------
60+
echo ""
61+
echo "================================================================"
62+
echo " [6/9] eval_flux_model.py — No compile, no fuse_rope_using_torch_compile"
63+
echo "================================================================"
64+
python "$BENCH_DIR/eval_flux_model.py" --baseline "$BASELINE" --test "$TEST"
65+
66+
echo ""
67+
echo "================================================================"
68+
echo " [7/9] eval_flux_model.py — Compile, no fuse_rope_using_torch_compile"
69+
echo "================================================================"
70+
python "$BENCH_DIR/eval_flux_model.py" --baseline "$BASELINE" --test "$TEST" --compile
71+
72+
echo ""
73+
echo "================================================================"
74+
echo " [8/9] eval_flux_model.py — No compile, fuse_rope_using_torch_compile"
75+
echo "================================================================"
76+
python "$BENCH_DIR/eval_flux_model.py" --baseline "$BASELINE" --test "$TEST" --fuse_rope_using_torch_compile
77+
78+
echo ""
79+
echo "================================================================"
80+
echo " [9/9] eval_flux_model.py — Compile, fuse_rope_using_torch_compile"
81+
echo "================================================================"
82+
python "$BENCH_DIR/eval_flux_model.py" --baseline "$BASELINE" --test "$TEST" --compile --fuse_rope_using_torch_compile
83+
84+
echo ""
85+
echo "================================================================"
86+
echo " All benchmarks complete."
87+
echo "================================================================"

test/prototype/attention/test_fp8_attention.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,15 @@ def _build_backend_configs() -> List[BackendConfig]:
131131
and _is_fa4_available()
132132
)
133133
if fa4_available:
134-
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
134+
from torchao.prototype.attention.fp8_fa4.attention import (
135+
fp8_fa4_rope_sdpa,
136+
fp8_fa4_sdpa,
137+
)
135138

136-
sdpa_fn = fp8_fa4_sdpa
139+
sdpa_fn, rope_sdpa_fn = fp8_fa4_sdpa, fp8_fa4_rope_sdpa
137140
eager_ok = _probe_eager_quantized_sdpa(sdpa_fn, "FA4")
138141
else:
139-
sdpa_fn = None
142+
sdpa_fn = rope_sdpa_fn = None
140143
eager_ok = False
141144

142145
configs.append(
@@ -145,7 +148,7 @@ def _build_backend_configs() -> List[BackendConfig]:
145148
flash_impl="FA4",
146149
attention_backend=AttentionBackend.FP8_FA4,
147150
sdpa_fn=sdpa_fn,
148-
rope_sdpa_fn=None, # FA4 rope not yet available
151+
rope_sdpa_fn=rope_sdpa_fn,
149152
available_eager=eager_ok,
150153
available_compiled=eager_ok,
151154
skip_msg=(

torchao/prototype/attention/fp8_fa4/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
For lower-level access, use fp8_fa4_sdpa() directly.
1212
"""
1313

14-
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
14+
from torchao.prototype.attention.fp8_fa4.attention import (
15+
fp8_fa4_rope_sdpa,
16+
fp8_fa4_sdpa,
17+
)
1518
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize
1619

1720
__all__ = [
1821
"fp8_fa4_sdpa",
22+
"fp8_fa4_rope_sdpa",
1923
"_fp8_sdpa_quantize",
2024
]

torchao/prototype/attention/fp8_fa4/attention.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,16 @@
3535
from functools import partial
3636

3737
from torchao.prototype.attention.shared_utils.attention import (
38+
_fp8_rope_sdpa,
3839
_fp8_sdpa,
3940
)
4041

4142
fp8_fa4_sdpa = partial(_fp8_sdpa, backend_name="FA4")
4243
fp8_fa4_sdpa.__doc__ = _fp8_sdpa.__doc__
4344
fp8_fa4_sdpa.__name__ = "fp8_fa4_sdpa"
4445
fp8_fa4_sdpa.__qualname__ = "fp8_fa4_sdpa"
46+
47+
fp8_fa4_rope_sdpa = partial(_fp8_rope_sdpa, backend_name="FA4")
48+
fp8_fa4_rope_sdpa.__doc__ = _fp8_rope_sdpa.__doc__
49+
fp8_fa4_rope_sdpa.__name__ = "fp8_fa4_rope_sdpa"
50+
fp8_fa4_rope_sdpa.__qualname__ = "fp8_fa4_rope_sdpa"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
FA4-specific FX graph fusion pass and compile helper.
9+
10+
Registers FA4 custom ops (torchao::fp8_fa4_rope_sdpa, torchao::fp8_fa4_sdpa)
11+
via the shared factory, and exposes ``rope_sdpa_fusion_pass`` and
12+
``compile_with_fp8_fusion`` for use by ``fp8_fa4/setup.py``.
13+
14+
Pattern detection, graph surgery, and the main fusion loop are in
15+
torchao.prototype.attention.shared_utils.fusion_utils.
16+
"""
17+
18+
from torchao.prototype.attention.fp8_fa4.attention import (
19+
fp8_fa4_rope_sdpa,
20+
fp8_fa4_sdpa,
21+
)
22+
from torchao.prototype.attention.shared_utils.custom_ops import (
23+
make_compile_fn,
24+
make_fusion_pass,
25+
register_fp8_attention_ops,
26+
)
27+
28+
# Register FA4 custom ops at import time.
29+
_ops = register_fp8_attention_ops(
30+
backend_name="fa4",
31+
rope_sdpa_fn=fp8_fa4_rope_sdpa,
32+
sdpa_fn=fp8_fa4_sdpa,
33+
)
34+
35+
# FA4-specific fusion pass entry point.
36+
rope_sdpa_fusion_pass = make_fusion_pass(_ops, backend_name="FA4", max_head_dim=256)
37+
38+
# FA4-specific compile helper.
39+
compile_with_fp8_fusion = make_compile_fn(rope_sdpa_fusion_pass, flash_impl_name="FA4")

torchao/prototype/attention/fp8_fa4/setup.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
FP8 FA4 backend setup.
99
1010
Thin wrapper around the shared ``setup_fp8_backend``, binding the FA4
11-
attention function.
11+
attention function and FA4 compile helper.
1212
"""
1313

1414
import torch.nn as nn
@@ -17,24 +17,20 @@
1717
from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
1818

1919

20-
def _compile_not_available(model, config):
21-
raise NotImplementedError(
22-
"FA4 RoPE fusion (fuse_rope=True) is not yet available. "
23-
"Use fuse_rope=False (default) for the monkey-patch path."
24-
)
25-
26-
2720
def setup_fp8_fa4(
2821
model: nn.Module,
2922
config: LowPrecisionAttentionConfig,
3023
) -> nn.Module:
3124
"""Set up FP8 FA4 attention on *model* and wrap it."""
3225
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
26+
from torchao.prototype.attention.fp8_fa4.fusion_pass import (
27+
compile_with_fp8_fusion,
28+
)
3329

3430
return setup_fp8_backend(
3531
model,
3632
config,
3733
flash_impl_name="FA4",
3834
sdpa_fn=fp8_fa4_sdpa,
39-
compile_fn=_compile_not_available,
35+
compile_fn=compile_with_fp8_fusion,
4036
)

0 commit comments

Comments (0)