Commit 99f237b

Added prototype low precision attention API to the docs
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: de96bfb
Pull-Request: #4056
1 parent dac0fae commit 99f237b

File tree

4 files changed (+140, -0 lines)

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
.. _api_attention:

=======================================
torchao.prototype.attention (prototype)
=======================================

.. currentmodule:: torchao.prototype.attention

High-Level API
--------------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   apply_low_precision_attention
   AttentionBackend

.. currentmodule:: torchao.prototype.attention.fp8_fa3.attention

Direct Usage (FA3)
------------------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   fp8_fa3_sdpa
   fp8_fa3_rope_sdpa

docs/source/api_reference/index.rst

Lines changed: 1 addition & 0 deletions

@@ -12,3 +12,4 @@ Comprehensive API documentation for torchao.
 api_ref_float8
 api_ref_utils
 api_ref_prototype_quant_logger
+api_ref_attention

docs/source/workflows/index.md

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@ This page provides an overview of the various workflows available in torchao.
 [int8 dense](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training)
 * QAT: the [QAT documentation](qat.md) for details on how to use quantization-aware training to improve model accuracy after quantization.
 * Inference: See the [inference quantization documentation](inference.md) for an overview of quantization for inference workflows.
+* Low-Precision Attention: See the [low-precision attention documentation](low_precision_attention.md) for FP8 attention using Flash Attention backends.

 ## Workflows status by dtype + hardware

@@ -61,4 +62,5 @@ This page provides an overview of the various workflows available in torchao.
 training
 qat
 inference
+low_precision_attention
 ```
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Low-Precision Attention (Prototype)

FP8 low-precision attention for inference, built on Flash Attention backends. Currently supports FA3 on Hopper (SM90) architectures, with FA4 for Blackwell coming soon.

> **Requirements:** PyTorch >= 2.11, Hopper GPU (H100/H200), `flash-attn` with FA3 support.

> **Note:** Only the forward pass is supported — backward is not supported by the underlying backends.
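As background on the "FP8" in the name: FP8 kernels typically rescale each tensor by its absolute maximum so values fit the e4m3 range (largest finite value 448). A minimal pure-Python sketch of that common recipe, illustrative only and not torchao's implementation:

```python
# Illustrative amax-based per-tensor FP8 scaling (NOT torchao's actual code).

E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def fp8_scale(values):
    """Per-tensor scale that maps values into the e4m3 range."""
    amax = max(abs(v) for v in values)
    return amax / E4M3_MAX if amax > 0 else 1.0

def quantize(values, scale):
    """Scale down into FP8 range; a real kernel would also round to e4m3."""
    return [v / scale for v in values]

def dequantize(values, scale):
    """Undo the scaling after the low-precision matmul."""
    return [v * scale for v in values]

q = [0.5, -3.0, 896.0]
s = fp8_scale(q)        # 896 / 448 = 2.0
q_fp8 = quantize(q, s)  # all values now within [-448, 448]
assert max(abs(v) for v in q_fp8) <= E4M3_MAX
```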
## High-Level API

The simplest way to enable low-precision attention is `apply_low_precision_attention`, which wraps your model to replace all `F.scaled_dot_product_attention` calls with FP8 attention:

```python
from torchao.prototype.attention import apply_low_precision_attention

model = ...  # your model
model = apply_low_precision_attention(model)
```

This must be called **before** `torch.compile`.

### Disabling KV Caching

KV caching should be disabled before calling `apply_low_precision_attention` (e.g., `config.use_cache = False` for HuggingFace models).

With KV caching enabled, HuggingFace models materialize an explicit attention mask for causal layers, which blocks Flash Attention from running. The monkey-patch path detects and strips these causal masks automatically, so **FP8 attention still works in eager mode with KV caching enabled**.

However, **RoPE fusion under `torch.compile` requires KV caching to be disabled.** KV caching inserts a `torch.cat` operation between RoPE and SDPA (to concatenate cached and new keys/values), which breaks the pattern matching required for the fusion pass — it expects RoPE to feed directly into SDPA. With KV caching enabled, you will still get FP8 attention but without the RoPE fusion optimization.
### RoPE Fusion with `torch.compile`

If you then `torch.compile` the wrapped model, the compiler will automatically detect RoPE patterns preceding `F.scaled_dot_product_attention` and fuse them into a single `fp8_fa3_rope_sdpa` kernel:

```python
model = apply_low_precision_attention(model)
model = torch.compile(model)  # RoPE fusion happens automatically
```

The fusion pass supports two RoPE patterns:

- **NeoX/LLaMA style** (half-split): `x * cos + rotate_half(x) * sin`
- **Interleaved style** (FLUX): complex rotation via reshape + unbind + stack
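For reference, the two rotation patterns can be sketched in pure Python on a single head vector (illustrative math only, not the fused kernel):

```python
import math

# Illustrative reference for the two RoPE patterns the fusion pass matches,
# applied to one head vector. Not torchao's kernel.

def rope_half_split(x, cos, sin):
    """NeoX/LLaMA style: x * cos + rotate_half(x) * sin."""
    d = len(x)
    rot = [-v for v in x[d // 2:]] + x[:d // 2]  # rotate_half
    return [x[i] * cos[i] + rot[i] * sin[i] for i in range(d)]

def rope_interleaved(x, cos, sin):
    """FLUX style: rotate adjacent (even, odd) pairs as complex numbers."""
    out = []
    for i in range(0, len(x), 2):
        re, im = x[i], x[i + 1]
        c, s = cos[i // 2], sin[i // 2]
        out += [re * c - im * s, re * s + im * c]
    return out

theta = math.pi / 2
x = [1.0, 0.0, 0.0, 0.0]
# half-split duplicates cos/sin across both halves of the head dim
hs = rope_half_split(x, [math.cos(theta)] * 4, [math.sin(theta)] * 4)
il = rope_interleaved(x, [math.cos(theta)] * 2, [math.sin(theta)] * 2)
```

Both rotate the same amount per frequency; they differ only in which dimensions are paired, which is why the fusion pass must match each layout separately.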
> **Warning:** The RoPE fusion pass sets `torch._inductor.config.pre_grad_custom_pass`. This will overwrite any existing custom pass you may have registered.

### Selecting a Backend

By default, `apply_low_precision_attention` auto-detects the best available backend. You can also specify one explicitly:

```python
from torchao.prototype.attention import apply_low_precision_attention, AttentionBackend

model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)
```

| Backend   | Architecture      | Status      |
|-----------|-------------------|-------------|
| `FP8_FA3` | Hopper (SM90)     | Available   |
| `FP8_FA4` | Blackwell (SM100) | Coming soon |
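A hypothetical sketch of what compute-capability-based auto-detection could look like; the enum here is a local stand-in for the torchao one, and the library's real selection logic may differ:

```python
from enum import Enum

# Hypothetical backend selection keyed on CUDA compute capability.
# Local stand-in types; not torchao's actual implementation.

class AttentionBackend(Enum):
    FP8_FA3 = "fp8_fa3"
    FP8_FA4 = "fp8_fa4"

def pick_backend(sm_major: int, sm_minor: int) -> AttentionBackend:
    if (sm_major, sm_minor) == (9, 0):  # Hopper (SM90)
        return AttentionBackend.FP8_FA3
    if sm_major >= 10:                  # Blackwell (SM100), not yet available
        raise NotImplementedError("FP8_FA4 is coming soon")
    raise RuntimeError("no FP8 attention backend for this architecture")

assert pick_backend(9, 0) is AttentionBackend.FP8_FA3
```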
## Direct Usage

For finer-grained control, you can use `fp8_fa3_sdpa` and `fp8_fa3_rope_sdpa` directly as drop-in replacements for `F.scaled_dot_product_attention`.

### `fp8_fa3_sdpa` — Drop-in SDPA Replacement

Replaces `F.scaled_dot_product_attention`. Input/output layout is `[B, H, S, D]`.

```python
from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa

activate_flash_attention_impl("FA3")
try:
    out = fp8_fa3_sdpa(query, key, value, is_causal=True)
finally:
    restore_flash_attention_impl()
```

**Parameters:**
- `query`, `key`, `value` — `[B, H, S, D]` tensors
- `is_causal` (`bool`, default `False`) — Whether to apply causal masking
- `scale` (`float | None`, default `None`) — Attention scale factor (defaults to `1/sqrt(D)`)
- `enable_gqa` (`bool`, default `False`) — Enable grouped-query attention

> **Note:** `attn_mask` and `dropout_p` are accepted for signature compatibility but must be `None` and `0.0` respectively.
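As a reference for what the kernel computes numerically, here is plain `softmax(QK^T * scale)V` with causal masking for a single head in pure Python (illustrative semantics only; the FP8 kernel computes the same math with quantized inputs):

```python
import math

# Reference semantics of scaled dot-product attention for one head.
# q, k, v are lists of row vectors (one per sequence position).

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    es = [math.exp(x - m) for x in xs]
    t = sum(es)
    return [e / t for e in es]

def sdpa(q, k, v, is_causal=False, scale=None):
    d = len(q[0])
    scale = scale if scale is not None else 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        if is_causal:  # query i may only attend to keys 0..i
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        w = softmax(scores)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return out

q = k = v = [[1.0, 0.0], [0.0, 1.0]]
out = sdpa(q, k, v, is_causal=True)
# with causal masking, the first query attends only to the first key
assert out[0] == v[0]
```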
### `fp8_fa3_rope_sdpa` — Fused RoPE + SDPA

Fuses RoPE application with FP8 attention in a single kernel. Input layout is `[B, S, H, D]` (pre-transpose); output layout is `[B, H, S, D]`.

```python
from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_rope_sdpa

activate_flash_attention_impl("FA3")
try:
    out = fp8_fa3_rope_sdpa(
        query, key, value, cos, sin,
        is_causal=True,
        rope_interleaved=False,  # NeoX/LLaMA style
    )
finally:
    restore_flash_attention_impl()
```

**Additional parameters:**
- `cos`, `sin` — RoPE frequency tensors, shape `[S, D]`
- `rope_interleaved` (`bool`, default `False`) — `False` for NeoX/LLaMA half-split rotation, `True` for FLUX-style interleaved rotation
