Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
cf8280f
Update
howardzhang-cv Feb 11, 2026
9acfc52
Update (base update)
howardzhang-cv Feb 11, 2026
88dff89
Update
howardzhang-cv Feb 11, 2026
11e7cad
Update
howardzhang-cv Feb 12, 2026
95cccd5
Update (base update)
howardzhang-cv Feb 12, 2026
fdf88ac
Update
howardzhang-cv Feb 12, 2026
ad075ac
Update (base update)
howardzhang-cv Feb 12, 2026
878b464
Update
howardzhang-cv Feb 13, 2026
3be7bbb
Update
howardzhang-cv Feb 13, 2026
3eea34a
Update (base update)
howardzhang-cv Feb 13, 2026
333e08c
Update
howardzhang-cv Feb 13, 2026
8e227d0
Update
howardzhang-cv Feb 13, 2026
d85dcc2
Update
howardzhang-cv Feb 21, 2026
56ba611
Update (base update)
howardzhang-cv Feb 21, 2026
aac4e70
Update
howardzhang-cv Feb 24, 2026
9756826
Update (base update)
howardzhang-cv Feb 24, 2026
32858e9
Update
howardzhang-cv Feb 25, 2026
548d7ef
Update
howardzhang-cv Feb 25, 2026
e3c6014
Update (base update)
howardzhang-cv Feb 25, 2026
97eafd5
Update
howardzhang-cv Feb 27, 2026
0a042ad
Update (base update)
howardzhang-cv Feb 27, 2026
b6e59d0
Update
howardzhang-cv Feb 28, 2026
44a7429
Update (base update)
howardzhang-cv Feb 28, 2026
a64a978
Update
howardzhang-cv Feb 28, 2026
411886b
Update (base update)
howardzhang-cv Feb 28, 2026
264d2bd
Update
howardzhang-cv Feb 28, 2026
74f3cfd
Update (base update)
howardzhang-cv Feb 28, 2026
708547f
Update
howardzhang-cv Mar 2, 2026
d60829a
Update (base update)
howardzhang-cv Mar 2, 2026
1d26fd8
Update
howardzhang-cv Mar 3, 2026
68efede
Update (base update)
howardzhang-cv Mar 3, 2026
e5a8c5a
Update
howardzhang-cv Mar 3, 2026
fec81e6
Update (base update)
howardzhang-cv Mar 5, 2026
669829e
Update
howardzhang-cv Mar 5, 2026
edb1f38
Update (base update)
howardzhang-cv Mar 6, 2026
7db5ce9
Update
howardzhang-cv Mar 6, 2026
58b0e6a
Update (base update)
howardzhang-cv Mar 6, 2026
d18f997
Update
howardzhang-cv Mar 6, 2026
100382a
Update (base update)
howardzhang-cv Mar 6, 2026
58c838f
Update
howardzhang-cv Mar 6, 2026
c348a9f
Update (base update)
howardzhang-cv Mar 7, 2026
a719b90
Update
howardzhang-cv Mar 7, 2026
f140854
Update (base update)
howardzhang-cv Mar 7, 2026
ed23fd0
Update
howardzhang-cv Mar 7, 2026
94d9200
Merge branch 'main' into gh/howardzhang-cv/16/head
howardzhang-cv Mar 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 124 additions & 1 deletion test/prototype/attention/test_fp8_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,29 @@
AttentionBackend,
apply_low_precision_attention,
)
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)


def _rope_cos_sin(S, D, device):
freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
cos_half = torch.cos(angles)
sin_half = torch.sin(angles)
cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
return cos, sin


def _apply_rope(x, cos, sin):
"""NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
D_HALF = x.shape[-1] // 2
rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
return (
x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
).to(x.dtype)


class SimpleAttentionModel(nn.Module):
Expand All @@ -52,6 +74,30 @@ def forward(self, x):
return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


class SimpleRoPEAttentionModel(nn.Module):
    """Toy attention block that rotates Q/K with RoPE right before SDPA.

    Matches "Pattern A": RoPE is applied in [B, S, H, D] layout, then the
    tensors are transposed to [B, H, S, D] for scaled_dot_product_attention.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Bias-free projections; creation order matches the reference layout
        # so RNG-initialized state_dicts stay interchangeable.
        for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, name, nn.Linear(embed_dim, embed_dim, bias=False))

    def forward(self, x, cos, sin):
        batch, seq, _ = x.shape

        def split_heads(t):
            return t.view(batch, seq, self.num_heads, self.head_dim)

        q = _apply_rope(split_heads(self.q_proj(x)), cos, sin).transpose(1, 2)
        k = _apply_rope(split_heads(self.k_proj(x)), cos, sin).transpose(1, 2)
        v = split_heads(self.v_proj(x)).transpose(1, 2)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.out_proj(merged)


@common_utils.instantiate_parametrized_tests
class TestFP8FA3Attention(TestCase):
@unittest.skipUnless(
Expand Down Expand Up @@ -83,6 +129,41 @@ def test_sdpa_accuracy(self, shape, dtype):
f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
)

@unittest.skipUnless(
torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
"Requires PyTorch >= 2.11, Hopper GPU, and FA3",
)
@common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_rope_sdpa_accuracy(self, shape, dtype):
B, S, H, D = shape
q = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
k = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
v = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
cos, sin = _rope_cos_sin(S, D, "cuda")

with torch.no_grad():
out_ref = F.scaled_dot_product_attention(
_apply_rope(q, cos, sin).transpose(1, 2),
_apply_rope(k, cos, sin).transpose(1, 2),
v.transpose(1, 2),
is_causal=False,
)

activate_flash_attention_impl("FA3")
try:
with torch.no_grad():
out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
finally:
restore_flash_attention_impl()

sqnr = compute_error(out_ref, out_fp8)
self.assertGreater(
sqnr.item(),
25.0,
f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
)

@unittest.skipUnless(
torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
"Requires PyTorch >= 2.11, Hopper GPU, and FA3",
Expand Down Expand Up @@ -122,6 +203,48 @@ def test_monkey_patch_model(self, dtype):
f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
)

@unittest.skipUnless(
torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
"Requires PyTorch >= 2.11, Hopper GPU, and FA3",
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_rope_fusion_model(self, dtype):
embed_dim, num_heads = 512, 8
model = (
SimpleRoPEAttentionModel(embed_dim, num_heads)
.to(device="cuda", dtype=dtype)
.eval()
)
S = 128
x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")

with torch.no_grad():
out_ref = model(x, cos, sin)

fp8_model = (
SimpleRoPEAttentionModel(embed_dim, num_heads)
.to(device="cuda", dtype=dtype)
.eval()
)
fp8_model.load_state_dict(model.state_dict())
fp8_model = apply_low_precision_attention(
fp8_model,
backend=AttentionBackend.FP8_FA3,
fuse_rope_using_torch_compile=True,
)
fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)

with torch.no_grad():
out_fp8 = fp8_model(x, cos, sin)

sqnr = compute_error(out_ref, out_fp8)
self.assertGreater(
sqnr.item(),
20.0,
f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
)


if __name__ == "__main__":
run_tests()
7 changes: 7 additions & 0 deletions torchao/prototype/attention/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def apply_low_precision_attention(
Must be called before ``torch.compile``. KV caching should be
disabled before calling (e.g., ``config.use_cache = False`` for
HuggingFace models).

When ``fuse_rope_using_torch_compile=True``, the returned wrapper
exposes a ``compile_backend`` attribute. You must compile with it to get
the RoPE fusion::

model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
model = torch.compile(model, backend=model.compile_backend)
"""
if isinstance(model, _LowPrecisionAttentionWrapper):
raise RuntimeError(
Expand Down
10 changes: 8 additions & 2 deletions torchao/prototype/attention/fp8_fa3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 attention using FA3 backend."""
"""
FP8 attention using FA3 backend.
"""

from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

__all__ = [
"fp8_fa3_sdpa",
"fp8_fa3_rope_sdpa",
"_fp8_sdpa_quantize",
]
19 changes: 16 additions & 3 deletions torchao/prototype/attention/fp8_fa3/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,32 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 SDPA using FA3 backend.
"""
FP8 SDPA using FA3 backend.

When using these functions directly (not through apply_low_precision_attention),
you must activate FA3 yourself::

Thin wrapper around ``shared_utils/attention.py``. When using directly,
activate the FA3 flash attention implementation before calling.
activate_flash_attention_impl("FA3")
try:
out = fp8_fa3_sdpa(q, k, v, is_causal=True)
finally:
restore_flash_attention_impl()
"""

from functools import partial

from torchao.prototype.attention.shared_utils.attention import (
_fp8_rope_sdpa,
_fp8_sdpa,
)

def _fa3_partial(fn, public_name):
    """Bind ``backend_name="FA3"`` onto *fn* and give the resulting partial
    a real identity (doc/name/qualname) so it introspects like a function."""
    bound = partial(fn, backend_name="FA3")
    bound.__doc__ = fn.__doc__
    bound.__name__ = public_name
    bound.__qualname__ = public_name
    return bound


fp8_fa3_sdpa = _fa3_partial(_fp8_sdpa, "fp8_fa3_sdpa")
fp8_fa3_rope_sdpa = _fa3_partial(_fp8_rope_sdpa, "fp8_fa3_rope_sdpa")
22 changes: 22 additions & 0 deletions torchao/prototype/attention/fp8_fa3/fusion_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)
from torchao.prototype.attention.shared_utils.custom_ops import (
make_backend_fn,
register_fp8_attention_ops,
)

# Register the FA3-backed FP8 SDPA and fused RoPE+SDPA callables as custom
# ops under the "fa3" backend name.
_ops = register_fp8_attention_ops(
    backend_name="fa3",
    rope_sdpa_fn=fp8_fa3_rope_sdpa,
    sdpa_fn=fp8_fa3_sdpa,
)

# Backend factory built from the registered ops.
# NOTE(review): the exact semantics of flash_impl_name="FA3" (presumably
# selecting the FlashAttention-3 kernels at activation time) are defined in
# shared_utils/custom_ops — confirm there.
make_fp8_backend = make_backend_fn(_ops, backend_name="FA3", flash_impl_name="FA3")
10 changes: 6 additions & 4 deletions torchao/prototype/attention/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Shared FP8 quantization kernels for low-precision attention."""

from torchao.prototype.attention.quantization.quantization import (
_fp8_sdpa_quantize,
from torchao.prototype.attention.quantization.triton_qkv_quantization import (
triton_fp8_sdpa_quantize as _fp8_sdpa_quantize,
)
from torchao.prototype.attention.quantization.triton_rope_qkv_quantization import (
triton_fp8_rope_sdpa_quantize as _fp8_rope_sdpa_quantize,
)

__all__ = [
"_fp8_sdpa_quantize",
"_fp8_rope_sdpa_quantize",
]
Loading
Loading