Added new API for low precision fp8 attention using FA3 #3857
howardzhang-cv merged 45 commits into main
Conversation
Assuming this is going to take a bit of back and forth to land, since this is a user-facing change. @drisspg please take a look and let me know if the API looks okay.
```python
original_forward = model.forward

def wrapped_forward(*args, **kwargs):
    with _fp8_fa3_attention_context(config):
```
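For context, the wrapping pattern in the diff can be sketched as a self-contained snippet; `_fp8_attention_context` here is a hypothetical stand-in for the PR's `_fp8_fa3_attention_context`:

```python
import contextlib

@contextlib.contextmanager
def _fp8_attention_context(config):
    # Hypothetical stand-in: the real context manager would swap the
    # flash-attention implementation for the duration of the call and
    # restore it afterwards.
    try:
        yield
    finally:
        pass  # restore original attention impl here

def wrap_forward(model, config):
    original_forward = model.forward

    def wrapped_forward(*args, **kwargs):
        with _fp8_attention_context(config):
            return original_forward(*args, **kwargs)

    model.forward = wrapped_forward
    return model
```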
Not sure if we talked about this before; what are your thoughts on this vs. following other torchao APIs, e.g.

```python
q = Float8Tensor.from_hp(q)
k = Float8Tensor.from_hp(k)
v = Float8Tensor.from_hp(v)
# dispatch to `_fp8_fa3_sdpa` in Float8Tensor implementation
F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

?
Yeah, ideally we follow other torchao APIs, but it's a bit difficult for attention. This first backend/recipe works well because it is a simple replacement of F.scaled_dot_product_attention, but as we add other features such as RoPE fusion or RoPE + Hadamard it becomes much more difficult, since the replacement becomes model specific (i.e. if there is RoPE followed by SDPA, replace both with a fused RoPE + quantization kernel feeding fp8 SDPA). To future-proof other attention backends and recipes, I think it's better to have this as its own separate API. What are your thoughts?
I thought the fused kernel replacement should happen in Inductor? Or is there a reason why we have to hand-replace these in eager mode?
Inductor unfortunately does not fuse RoPE with the quantization kernel.
Oh, I mean: is this something that's not possible to do in Inductor, or just something that does not exist right now?
Not totally sure; I feel like it should be possible in Inductor, but it definitely does not exist right now. We could look into it in the future, but for now I think it would be simpler to move this into prototype and have it be available through that. Will work on moving this over.
@jerryzh168 I've gotten the RoPE fusion to work by using the Inductor path. I've moved everything to a prototype folder. Given how attention works, I think it's better as its own API, since it doesn't fit cleanly with the others for now (especially with RoPE fusion and Hadamard requiring the Inductor path). What do you think?
```python
if q.shape[3] != k.shape[3]:
    raise ValueError(f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}")

if torch.compiler.is_compiling():
```
Is there a test comparing numerics between these two paths?
The numerics are exactly the same between the two paths. I tested the runtime and the Triton implementation was much faster (~100 ms difference on a Llama3 model with 124k sequence length).
```python
from torchao.prototype.attention import apply_low_precision_attention

model = MyTransformer()
```
We should specify what kind of syntax is automatically converted to low precision here. F.scaled_dot_product_attention? Something else?
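Presumably the targeted pattern is a direct call to F.scaled_dot_product_attention inside the module's forward; an illustrative module (not from the PR) would be:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    # The F.scaled_dot_product_attention call below is the kind of call
    # site the API would presumably detect and convert.
    def __init__(self, dim, heads):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.heads = heads

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, s, self.heads, -1).transpose(1, 2)
        k = k.view(b, s, self.heads, -1).transpose(1, 2)
        v = v.view(b, s, self.heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.transpose(1, 2).reshape(b, s, d)
```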
torchao/prototype/attention/api.py (outdated)
```python
def fp8_attention_backend(gm, example_inputs):
    """Custom Inductor backend that applies the RoPE + FP8 fusion pass."""
```
This should specify what kind of user syntax counts as "attention".
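One way to make that concrete is for the backend to only act on nodes it can positively identify as SDPA calls in the Dynamo-captured FX graph. A toy sketch (hypothetical names; it only counts SDPA nodes and performs no rewriting):

```python
import torch
import torch.nn.functional as F

def count_sdpa_backend(gm, example_inputs):
    # Toy stand-in for a custom backend: walk the captured FX graph,
    # record how many nodes are SDPA calls, then run the graph unmodified.
    gm._sdpa_nodes = sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function"
        and "scaled_dot_product_attention" in str(node.target)
    )
    return gm.forward

compiled = torch.compile(
    lambda q, k, v: F.scaled_dot_product_attention(q, k, v),
    backend=count_sdpa_backend,
)
```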
```python
    Different backends have different hardware requirements and capabilities.
    """

FP8_FA3 = "fa3"
```
Make the string match the enum value name.
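i.e., something along these lines (hypothetical enum name):

```python
import enum

class LowPrecisionAttentionVariant(enum.Enum):
    # string value matches the member name, per the review comment
    FP8_FA3 = "FP8_FA3"
```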
```python
restore_flash_attention_impl()


def apply_low_precision_attention(
```
I think this should be explicit that parts of torch.compile are used to do the logic swap.
Added; I additionally added a warning to be even more explicit.
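Such a warning could look roughly like this (hypothetical function name and message; the real signature is in the diff):

```python
import warnings

def apply_low_precision_attention_sketch(model, fuse_rope=True):
    # Hypothetical sketch: surface the torch.compile dependency loudly so
    # users are not surprised by compilation on the first forward pass.
    if fuse_rope:
        warnings.warn(
            "apply_low_precision_attention uses torch.compile with a custom "
            "Inductor backend to perform the attention swap; pass "
            "fuse_rope=False for an eager-only path."
        )
    return model
```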
Summary: Added new folder for low precision attention APIs in torchao/prototype/attention

Test Plan: python test/prototype/attention/test_fp8_fa3.py

ghstack-source-id: 25b6a97
Pull-Request: pytorch#3857
Adds the compile path (fuse_rope=True) which compiles the model with a
custom Inductor backend that fuses RoPE + FP8 quantization + SDPA into
optimized kernels via FX graph pattern matching.
Key additions:
- shared_utils/fusion_utils.py: FX graph RoPE/SDPA pattern detection and
parameterized graph surgery (NeoX + FLUX interleaved RoPE variants)
- shared_utils/custom_ops.py: custom op registration factory with
register_fake for torch.compile traceability
- fp8_fa3/fusion_pass.py: FA3-specific custom ops and compile helper
- quantization/triton_rope_qkv_quantization.py: fused RoPE + FP8
quantization Triton kernels with layout transpose
- _FP8FlashAttentionCompiledWrapper with @dynamo.disable boundary
- _fp8_rope_sdpa shared implementation + fp8_fa3_rope_sdpa entry point
- Tests parametrized over fuse_rope={True, False}
ghstack-source-id: 4e11b16
Pull-Request: pytorch#3857
Stack from ghstack (oldest at bottom):
Summary
nodes without RoPE.
New Files
Modified Files
Test Plan
python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage
Results
Single-Layer Results
Results directly comparing FA3 SDPA versus FA3 fp8 SDPA (including quantization time):

Llama3 Model Results
Results comparing a Llama3 model with FA3 SDPA versus Llama3 using the FA3 fp8 wrapper. Uses RoPE fusion.

Perplexity: 6.19 -> 6.24