
Commit d3111be

Added prototype low precision attention API to the docs

ghstack-source-id: f8e9811
Pull-Request: #4056

1 parent 36745b0 commit d3111be

File tree: 6 files changed, +85 −3 lines changed

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+.. _api_attention:
+
+=======================================
+torchao.prototype.attention (prototype)
+=======================================
+
+.. currentmodule:: torchao.prototype.attention
+
+High-Level API
+--------------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    apply_low_precision_attention
+    AttentionBackend
+
+.. currentmodule:: torchao.prototype.attention.fp8_fa3.attention
+
+Direct Usage (FA3)
+------------------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    fp8_fa3_sdpa
+    fp8_fa3_rope_sdpa
docs/source/api_reference/index.rst

Lines changed: 1 addition & 0 deletions

@@ -12,3 +12,4 @@ Comprehensive API documentation for torchao.
     api_ref_float8
     api_ref_utils
     api_ref_prototype_quant_logger
+    api_ref_prototype_attention
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchao.prototype.attention import AttentionBackend, apply_low_precision_attention
+
+
+# Simple model with attention
+class MyModel(nn.Module):
+    def __init__(self, embed_dim=512, num_heads=8):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+
+    def forward(self, x):
+        B, S, _ = x.shape
+        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))
+
+
+model = MyModel().to(device="cuda", dtype=torch.bfloat16).eval()
+
+# Auto-detect best backend
+model = apply_low_precision_attention(model)
+
+# Or specify a backend explicitly
+# model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)
+
+# Optional: torch.compile for RoPE fusion
+model = torch.compile(model)
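The `forward` method in the example folds `(B, S, E)` activations into per-head `(B, H, S, D)` tensors before calling SDPA. To make that reshape arithmetic explicit, here is a small pure-Python sketch of the shape bookkeeping; the helper name `attention_shapes` is illustrative only and not part of the torchao API:

```python
def attention_shapes(batch, seq_len, embed_dim, num_heads):
    """Shape bookkeeping for the reshape in MyModel.forward:

    (B, S, E) --view--> (B, S, H, D) --transpose(1, 2)--> (B, H, S, D),
    where D = E // H must divide evenly.
    """
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    after_view = (batch, seq_len, num_heads, head_dim)
    after_transpose = (batch, num_heads, seq_len, head_dim)
    return after_view, after_transpose


# With the example's defaults (embed_dim=512, num_heads=8), head_dim is 64.
print(attention_shapes(2, 16, 512, 8))
```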
docs/source/workflows/index.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ This page provides an overview of the various workflows available in torchao.
 * QAT: the [QAT documentation](qat.md) for details on how to use quantization-aware training to improve model accuracy after quantization.
 * Inference: See the [inference quantization documentation](inference.md) for an overview of quantization for inference workflows.
 
+
 ## Workflows status by dtype + hardware
 
 🟢 = stable, 🟡 = prototype, 🟠 = planned, ⚪ = not supported

docs/source/workflows/inference.md

Lines changed: 12 additions & 0 deletions

@@ -202,3 +202,15 @@ The benchmarks below were run on a single NVIDIA-A6000 GPU.
 | | codebook-4-64 | 10.095 | 1.73 | 8.63 | 23.11 | 4.98 |
 
 You can try out these APIs with the `quantize_` API as above alongside the `CodebookWeightOnlyConfig` config; an example can be found in `torchao/_models/llama/generate.py`.
+
+### Low-Precision FP8 Attention (Prototype)
+
+FP8 low-precision attention for inference, built on Flash Attention backends. Currently supports FA3 on Hopper (SM90) and FA4 on Blackwell (SM100).
+
+**Requirements:** PyTorch >= 2.11, Hopper or Blackwell GPU, Flash Attention 3 (`pip install flash-attn-3 --index-url=https://download.pytorch.org/whl/{cuda_version}`).
+
+```{literalinclude} ../examples/prototype/low_precision_attention.py
+:language: python
+```
+
+`apply_low_precision_attention` replaces all `F.scaled_dot_product_attention` calls with FP8 attention for eager execution. When combined with `torch.compile`, RoPE patterns are automatically detected and fused into a single kernel. For best results with `torch.compile`, disable KV caching before calling it. See the {ref}`API reference <api_attention>` for details.
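Since backend support is tied to compute capability (FA3 on SM90, FA4 on SM100), a program can check the GPU before opting in. The sketch below is a hedged illustration of that mapping, not torchao's actual selection logic; only `torch.cuda.get_device_capability()` (a real PyTorch API) is assumed on the commented line:

```python
def pick_fp8_backend(capability):
    """Map a CUDA compute capability (major, minor) to the FP8 attention
    backend described above. Illustrative only: the real auto-detection
    happens inside apply_low_precision_attention.
    """
    if capability == (9, 0):   # Hopper (SM90)
        return "FA3"
    if capability == (10, 0):  # Blackwell (SM100)
        return "FA4"
    return None  # no FP8 attention backend for this GPU

# On a CUDA machine you would pass the live capability, e.g.:
# pick_fp8_backend(torch.cuda.get_device_capability())
print(pick_fp8_backend((9, 0)))
```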

torchao/prototype/attention/api.py

Lines changed: 5 additions & 3 deletions

@@ -69,10 +69,12 @@ def apply_low_precision_attention(
 
     This replaces ``F.scaled_dot_product_attention`` with an FP8 SDPA
     for eager execution and sets a global pre-grad pass so that
-    ``torch.compile`` will automatically fuse RoPE where detected::
+    ``torch.compile`` will automatically fuse RoPE where detected.
 
-        model = apply_low_precision_attention(model)
-        model = torch.compile(model)  # RoPE fusion happens automatically
+    Example:
+
+    .. literalinclude:: ../../examples/prototype/low_precision_attention.py
+       :language: python
     """
     if not _TORCH_VERSION_AT_LEAST_2_11:
         raise RuntimeError("Low-precision attention requires PyTorch 2.11+.")
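The guard above gates the API on a flag like `_TORCH_VERSION_AT_LEAST_2_11`. As a rough sketch of what such a check involves, here is a minimal version-string comparison; the function name and parsing rules are assumptions for illustration, and torchao's actual helper may differ (e.g. by using `packaging.version`):

```python
def version_at_least(version_str, required):
    """Compare a torch-style version string against a (major, minor) tuple.

    Tolerates local build suffixes ("2.11.0+cu121") and simple pre-release
    tags ("2.11.0a0"). Hypothetical helper, not torchao's implementation.
    """
    base = version_str.split("+")[0]            # drop local suffix, e.g. "+cu121"
    base = base.replace("a", ".").replace("rc", ".")  # split off pre-release tags
    parts = base.split(".")
    major, minor = int(parts[0]), int(parts[1])
    return (major, minor) >= required

# A caller would feed in torch.__version__, e.g.:
# version_at_least(torch.__version__, (2, 11))
print(version_at_least("2.11.0+cu121", (2, 11)))
```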
