Skip to content

Commit 20f2055

Browse files
Add FA4 monkey-patch path for low-precision attention
Adds FlashAttention 4 backend support for the monkey-patch SDPA path. FA4 supports both Hopper (SM 9.x) and Blackwell (SM 10.x) hardware with the flash_attn.cute.interface package. Key additions: - FP8_FA4 enum value in AttentionBackend - _is_blackwell(), _is_fa4_available() hardware/library checks - FA4 dispatch in apply_low_precision_attention - fp8_fa4/ directory with fp8_fa4_sdpa entry point - FA4 backend config in test suite with eager probe guard - RoPE fusion placeholder (fuse_rope=True raises NotImplementedError) ghstack-source-id: fa10283 Pull-Request: pytorch#3960
1 parent 8ddcd98 commit 20f2055

File tree

7 files changed

+201
-6
lines changed

7 files changed

+201
-6
lines changed

test/prototype/attention/test_fp8_attention.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
"""
8-
Tests for FP8 low-precision attention (FA3 backend).
8+
Tests for FP8 low-precision attention (FA3 and FA4 backends).
99
10-
Tests are gated on Hopper (SM 9.x) with flash-attn installed.
11-
When the backend is not available on the current hardware, tests are
12-
automatically skipped.
10+
Tests are parametrized over available backends. On Hopper (SM 9.x) with
11+
flash-attn installed, FA3 tests run. On Hopper or Blackwell (SM 10.x)
12+
with flash_attn.cute.interface installed, FA4 tests run. Backends that
13+
are not available on the current hardware are automatically skipped.
1314
"""
1415

1516
import unittest
@@ -45,7 +46,9 @@
4546
apply_low_precision_attention,
4647
)
4748
from torchao.prototype.attention.utils import (
49+
_is_blackwell,
4850
_is_fa3_available,
51+
_is_fa4_available,
4952
_is_hopper,
5053
)
5154

@@ -58,10 +61,10 @@ class BackendConfig:
5861
"""Configuration for a single backend under test."""
5962

6063
name: str
61-
flash_impl: str # "FA3"
64+
flash_impl: str # "FA3" or "FA4"
6265
attention_backend: AttentionBackend
6366
sdpa_fn: Callable # fp8_fa3_sdpa
64-
rope_sdpa_fn: Callable # fp8_fa3_rope_sdpa
67+
rope_sdpa_fn: Callable # fp8_fa3_rope_sdpa, or None if not yet available
6568
available_eager: bool # Can run direct sdpa calls
6669
available_compiled: bool # Can run via apply_low_precision_attention
6770
skip_msg: str
@@ -122,6 +125,38 @@ def _build_backend_configs() -> List[BackendConfig]:
122125
)
123126
)
124127

128+
# FA4: Hopper or Blackwell
129+
fa4_available = (
130+
_has_flash_activation_api
131+
and (_is_hopper() or _is_blackwell())
132+
and _is_fa4_available()
133+
)
134+
if fa4_available:
135+
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
136+
137+
sdpa_fn = fp8_fa4_sdpa
138+
eager_ok = _probe_eager_quantized_sdpa(sdpa_fn, "FA4")
139+
else:
140+
sdpa_fn = None
141+
eager_ok = False
142+
143+
configs.append(
144+
BackendConfig(
145+
name="FA4",
146+
flash_impl="FA4",
147+
attention_backend=AttentionBackend.FP8_FA4,
148+
sdpa_fn=sdpa_fn,
149+
rope_sdpa_fn=None, # FA4 rope not yet available
150+
available_eager=eager_ok,
151+
available_compiled=eager_ok,
152+
skip_msg=(
153+
"FP8 FA4 requires Hopper (SM 9.x) or Blackwell (SM 10.x), "
154+
"flash-attn with FA4 support installed, "
155+
"and PyTorch with flash activation APIs"
156+
),
157+
)
158+
)
159+
125160
return configs
126161

127162

@@ -289,6 +324,8 @@ def test_rope_sdpa_accuracy(self, shape, dtype):
289324
)
290325

291326
for backend in _EAGER_BACKENDS:
327+
if backend.rope_sdpa_fn is None:
328+
continue # Backend doesn't support fused RoPE yet
292329
self._activate(backend)
293330
with torch.no_grad():
294331
out_fp8 = backend.rope_sdpa_fn(q, k, v, cos, sin, is_causal=False)

torchao/prototype/attention/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,9 @@ def apply_low_precision_attention(
108108

109109
return setup_fp8_fa3(model, config)
110110

111+
if backend == AttentionBackend.FP8_FA4:
112+
from torchao.prototype.attention.fp8_fa4.setup import setup_fp8_fa4
113+
114+
return setup_fp8_fa4(model, config)
115+
111116
raise ValueError(f"Unknown backend: {backend}")

torchao/prototype/attention/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class AttentionBackend(str, Enum):
2323
FP8_FA3 = "fa3"
2424
"""FlashAttention 3 via PyTorch core. Requires SM90+ (Hopper)."""
2525

26+
FP8_FA4 = "fa4"
27+
"""FlashAttention 4 via PyTorch core. Requires SM90+ (Hopper) or SM100+ (Blackwell)."""
28+
2629

2730
@dataclass
2831
class LowPrecisionAttentionConfig:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
FP8 attention implementation using FA4 backend.
9+
10+
Use apply_low_precision_attention() from torchao.prototype.attention as the public API.
11+
For lower-level access, use fp8_fa4_sdpa() directly.
12+
"""
13+
14+
from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa
15+
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize
16+
17+
# Public surface of the FA4 package; `_fp8_sdpa_quantize` is re-exported for
# parity with the lower-level quantization helpers.
__all__ = ["fp8_fa4_sdpa", "_fp8_sdpa_quantize"]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
FP8 scaled dot-product attention using FA4 backend.
9+
10+
This is a thin wrapper around the shared implementation in
11+
``shared_utils/attention.py``. It exists so that the FA4 backend has
12+
a named entry point (``fp8_fa4_sdpa``) and backend-specific error messages.
13+
14+
.. important::
15+
16+
When using this function directly (not through
17+
``apply_low_precision_attention``), you **must** activate the FA4
18+
flash attention implementation yourself::
19+
20+
from torch.nn.attention import (
21+
activate_flash_attention_impl,
22+
restore_flash_attention_impl,
23+
)
24+
25+
activate_flash_attention_impl("FA4")
26+
try:
27+
out = fp8_fa4_sdpa(q, k, v, is_causal=True)
28+
finally:
29+
restore_flash_attention_impl()
30+
31+
The high-level ``apply_low_precision_attention`` API handles this
32+
automatically.
33+
"""
34+
35+
from functools import partial
36+
37+
from torchao.prototype.attention.shared_utils.attention import (
38+
_fp8_sdpa,
39+
)
40+
41+
# Bind the shared FP8 SDPA implementation to the FA4 backend name.
fp8_fa4_sdpa = partial(_fp8_sdpa, backend_name="FA4")
# A partial carries no function metadata of its own; copy the shared docstring
# and present the object under the backend-specific name so that help(),
# generated docs, and error messages read as a first-class function.
fp8_fa4_sdpa.__doc__ = _fp8_sdpa.__doc__
fp8_fa4_sdpa.__qualname__ = fp8_fa4_sdpa.__name__ = "fp8_fa4_sdpa"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
FP8 FA4 backend setup.
9+
10+
Thin wrapper around the shared ``setup_fp8_backend``, binding the FA4
11+
attention function.
12+
"""
13+
14+
import torch.nn as nn
15+
16+
from torchao.prototype.attention.config import LowPrecisionAttentionConfig
17+
from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
18+
19+
20+
def _compile_not_available(model, config):
21+
raise NotImplementedError(
22+
"FA4 RoPE fusion (fuse_rope=True) is not yet available. "
23+
"Use fuse_rope=False (default) for the monkey-patch path."
24+
)
25+
26+
27+
def setup_fp8_fa4(
    model: nn.Module,
    config: LowPrecisionAttentionConfig,
) -> nn.Module:
    """Set up FP8 FA4 attention on *model* and wrap it.

    Delegates to the shared ``setup_fp8_backend`` with the FA4 entry point
    bound; RoPE fusion is routed to a placeholder that raises.
    """
    # Imported lazily so importing this module does not require the FA4 stack
    # to be installed.
    from torchao.prototype.attention.fp8_fa4.attention import fp8_fa4_sdpa

    return setup_fp8_backend(
        model,
        config,
        flash_impl_name="FA4",
        sdpa_fn=fp8_fa4_sdpa,
        # fuse_rope=True is not supported yet for FA4.
        compile_fn=_compile_not_available,
    )

torchao/prototype/attention/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ def _is_hopper() -> bool:
2525
return major == 9
2626

2727

28+
def _is_blackwell() -> bool:
29+
"""
30+
Check if the current CUDA device is Blackwell (SM 10.x).
31+
"""
32+
if not torch.cuda.is_available():
33+
return False
34+
major, _ = torch.cuda.get_device_capability()
35+
return major == 10
36+
37+
2838
def _is_fa3_available() -> bool:
2939
"""
3040
Check if the flash attention 3 library (flash_attn_interface) is installed.
@@ -36,6 +46,17 @@ def _is_fa3_available() -> bool:
3646
return False
3747

3848

49+
def _is_fa4_available() -> bool:
50+
"""
51+
Check if the flash attention 4 library (flash_attn.cute.interface) is installed.
52+
"""
53+
try:
54+
importlib.import_module("flash_attn.cute.interface")
55+
return True
56+
except ModuleNotFoundError:
57+
return False
58+
59+
3960
def _get_available_backend() -> AttentionBackend:
4061
"""
4162
Get the best available backend for current hardware.
@@ -51,10 +72,18 @@ def _get_available_backend() -> AttentionBackend:
5172

5273
capability = torch.cuda.get_device_capability()
5374

75+
# FA4 on Blackwell (SM 10.x) with flash_attn.cute.interface
76+
if _is_blackwell() and _is_fa4_available():
77+
return AttentionBackend.FP8_FA4
78+
5479
# FA3 requires exactly Hopper (SM 9.x) and flash_attn_interface
5580
if _is_hopper() and _is_fa3_available():
5681
return AttentionBackend.FP8_FA3
5782

83+
# FA4 also supports Hopper (SM 9.x) with flash_attn.cute.interface
84+
if _is_hopper() and _is_fa4_available():
85+
return AttentionBackend.FP8_FA4
86+
5887
raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.")
5988

6089

@@ -84,5 +113,22 @@ def _check_backend_available(backend: AttentionBackend) -> None:
84113
"FP8_FA3 backend requires the flash-attn package with FA3 support. "
85114
)
86115

116+
elif backend == AttentionBackend.FP8_FA4:
117+
if not torch.cuda.is_available():
118+
raise RuntimeError("FP8_FA4 backend requires CUDA.")
119+
120+
if not (_is_hopper() or _is_blackwell()):
121+
capability = torch.cuda.get_device_capability()
122+
raise RuntimeError(
123+
f"FP8_FA4 backend requires Hopper (SM 9.x) or Blackwell (SM 10.x). "
124+
f"Current device: SM{capability[0]}{capability[1]}. "
125+
)
126+
127+
if not _is_fa4_available():
128+
raise RuntimeError(
129+
"FP8_FA4 backend requires the flash-attn package with FA4 support "
130+
"(flash_attn.cute.interface). "
131+
)
132+
87133
else:
88134
raise ValueError(f"Unknown backend: {backend}")

0 commit comments

Comments
 (0)