# Low-Precision Attention (Prototype)

FP8 low-precision attention for inference, built on Flash Attention backends. Currently supports FA3 on Hopper (SM90) architectures, with FA4 for Blackwell coming soon.

> **Requirements:** PyTorch >= 2.11, Hopper GPU (H100/H200), `flash-attn` with FA3 support.

> **Note:** Only the forward pass is supported — backward is not supported by the underlying backends.

## High-Level API

The simplest way to enable low-precision attention is `apply_low_precision_attention`, which wraps your model to replace all `F.scaled_dot_product_attention` calls with FP8 attention:

```python
from torchao.prototype.attention import apply_low_precision_attention

model = ...  # your model
model = apply_low_precision_attention(model)
```

This must be called **before** `torch.compile`.

### Disabling KV Caching

KV caching should be disabled before calling `apply_low_precision_attention` (e.g., `config.use_cache = False` for HuggingFace models).

With KV caching enabled, HuggingFace models materialize an explicit attention mask for causal layers, which blocks Flash Attention from running. The monkey-patch path detects and strips these causal masks automatically, so **FP8 attention still works in eager mode with KV caching enabled**.

However, **RoPE fusion under `torch.compile` requires KV caching to be disabled.** KV caching inserts a `torch.cat` operation between RoPE and SDPA (to concatenate cached and new keys/values), which breaks the fusion pass's pattern matching: the pass expects RoPE to feed directly into SDPA. With KV caching enabled, you still get FP8 attention, but without the RoPE fusion optimization.

### RoPE Fusion with `torch.compile`

If you then `torch.compile` the wrapped model, the compiler will automatically detect RoPE patterns preceding `F.scaled_dot_product_attention` and fuse them into a single `fp8_fa3_rope_sdpa` kernel:

```python
model = apply_low_precision_attention(model)
model = torch.compile(model)  # RoPE fusion happens automatically
```

The fusion pass supports two RoPE patterns:
- **NeoX/LLaMA style** (half-split): `x * cos + rotate_half(x) * sin`
- **Interleaved style** (FLUX): complex rotation via reshape + unbind + stack

> **Warning:** The RoPE fusion pass sets `torch._inductor.config.pre_grad_custom_pass`. This will overwrite any existing custom pass you may have registered.
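
The two patterns differ only in how dimensions are paired for rotation. A plain-PyTorch sketch of the reference math (not the fused kernel; the function names here are illustrative, not library APIs):

```python
import torch

def rope_half_split(x, cos, sin):
    # NeoX/LLaMA style: dim i is paired with dim i + D/2; cos/sin are [S, D]
    # with the D/2 frequency table duplicated across both halves.
    x1, x2 = x.chunk(2, dim=-1)
    rotate_half = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotate_half * sin

def rope_interleaved(x, cos, sin):
    # FLUX style: adjacent dims (2i, 2i+1) form a pair, rotated via
    # reshape + unbind + stack; cos/sin are [S, D/2] here.
    x1, x2 = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(-2)

# Apply both conventions to the same input.
S, D = 8, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))  # [D/2]
angles = torch.arange(S).float()[:, None] * inv_freq[None, :]      # [S, D/2]
x = torch.randn(S, D)
emb = torch.cat((angles, angles), dim=-1)                          # [S, D]
a = rope_half_split(x, emb.cos(), emb.sin())
b = rope_interleaved(x, angles.cos(), angles.sin())
```

Both are pure rotations: each position's vector keeps its norm, and position 0 (angle zero) passes through unchanged.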

### Selecting a Backend

By default, `apply_low_precision_attention` auto-detects the best available backend. You can also specify one explicitly:

```python
from torchao.prototype.attention import apply_low_precision_attention, AttentionBackend

model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)
```

| Backend | Architecture | Status |
|-----------|-------------------|-------------|
| `FP8_FA3` | Hopper (SM90) | Available |
| `FP8_FA4` | Blackwell (SM100) | Coming soon |
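
Auto-detection presumably keys off the GPU's compute capability. A hypothetical sketch of the selection logic (`pick_backend` is illustrative only and not part of the library):

```python
import torch

def pick_backend(capability: tuple[int, int]) -> str:
    """Map a CUDA compute capability to a backend name (sketch)."""
    if capability[0] == 9:   # Hopper (SM90): H100/H200
        return "FP8_FA3"
    if capability[0] == 10:  # Blackwell (SM100): backend not yet available
        raise NotImplementedError("FP8_FA4 is coming soon")
    raise RuntimeError(f"no FP8 attention backend for SM{capability[0]}{capability[1]}")

if torch.cuda.is_available():
    backend = pick_backend(torch.cuda.get_device_capability())
```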

## Direct Usage

For finer-grained control, you can use `fp8_fa3_sdpa` and `fp8_fa3_rope_sdpa` directly as drop-in replacements for `F.scaled_dot_product_attention`.

### `fp8_fa3_sdpa` — Drop-in SDPA Replacement

Replaces `F.scaled_dot_product_attention`. Input/output layout is `[B, H, S, D]`.

```python
from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa

activate_flash_attention_impl("FA3")
try:
    out = fp8_fa3_sdpa(query, key, value, is_causal=True)
finally:
    restore_flash_attention_impl()
```

**Parameters:**
- `query`, `key`, `value` — `[B, H, S, D]` tensors
- `is_causal` (`bool`, default `False`) — Whether to apply causal masking
- `scale` (`float | None`, default `None`) — Attention scale factor (defaults to `1/sqrt(D)`)
- `enable_gqa` (`bool`, default `False`) — Enable grouped-query attention

> **Note:** `attn_mask` and `dropout_p` are accepted for signature compatibility but must be `None` and `0.0` respectively.
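
Because the calling convention matches `F.scaled_dot_product_attention` one-for-one, the stock SDPA gives a high-precision reference to compare FP8 outputs against. A sketch of such a reference (runs on CPU; the FP8 call in the comment assumes a Hopper GPU):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# High-precision reference with the same [B, H, S, D] layout and arguments;
# fp8_fa3_sdpa(q, k, v, is_causal=True) should approximate this output.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert ref.shape == (B, H, S, D)
```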

### `fp8_fa3_rope_sdpa` — Fused RoPE + SDPA

Fuses RoPE application with FP8 attention in a single kernel. Input layout is `[B, S, H, D]` (pre-transpose); output layout is `[B, H, S, D]`.

```python
from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_rope_sdpa

activate_flash_attention_impl("FA3")
try:
    out = fp8_fa3_rope_sdpa(
        query, key, value, cos, sin,
        is_causal=True,
        rope_interleaved=False,  # NeoX/LLaMA style
    )
finally:
    restore_flash_attention_impl()
```

**Additional parameters:**
- `cos`, `sin` — RoPE frequency tensors, shape `[S, D]`
- `rope_interleaved` (`bool`, default `False`) — `False` for NeoX/LLaMA half-split rotation, `True` for FLUX-style interleaved rotation
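
One common way to build `[S, D]` cos/sin tables is the HuggingFace-style convention of concatenating the `D/2` frequency table with itself, which lines up with the half-split rotation; a sketch under that assumption (`base = 10000.0` is the usual default, adjust per model):

```python
import torch

def rope_tables(S: int, D: int, base: float = 10000.0):
    # One frequency per rotation pair, then duplicated across both halves
    # so cos/sin come out as [S, D] tensors.
    inv_freq = 1.0 / (base ** (torch.arange(0, D, 2).float() / D))  # [D/2]
    angles = torch.arange(S).float()[:, None] * inv_freq[None, :]   # [S, D/2]
    emb = torch.cat((angles, angles), dim=-1)                       # [S, D]
    return emb.cos(), emb.sin()

cos, sin = rope_tables(S=128, D=64)
```

For the interleaved convention you would repeat each frequency pairwise (e.g. `repeat_interleave(2, dim=-1)`) rather than concatenating halves; check which layout your model's RoPE implementation produces.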