Add FP8 FA3 low-precision attention with monkey-patch SDPA path #3959
howardzhang-cv merged 17 commits into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3959
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 699c526 with merge base 42bcdc4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Introduces a new `torchao.prototype.attention` module providing FP8 attention via FlashAttention 3. The default path monkey-patches `F.scaled_dot_product_attention` at call time — no `torch.compile` needed. Flash attention activation is managed internally by the wrapper.

Key components:
- `apply_low_precision_attention()`: user-facing API
- `shared_utils/`: backend-agnostic wrapper, setup, and SDPA logic
- `fp8_fa3/`: FA3-specific thin wrappers
- `quantization/`: fused Triton FP8 QKV quantization kernels
- Causal mask detection for HuggingFace model compatibility

ghstack-source-id: 4f8d3fb
Pull-Request: pytorch#3959
torchao/prototype/attention/api.py
Outdated
> """
> Apply low-precision attention to a model.
>
> Depending on the configuration, the model is either:
This is a bit confusing; can we make the distinction clearer? Either with a boolean in this API, or just have two APIs, etc.
torchao/prototype/attention/api.py
Outdated
> - **Monkey-patch path** (``fuse_rope=False``, default): wraps the model
>   so that ``F.scaled_dot_product_attention`` is replaced with the FP8
>   backend at call time. No ``torch.compile`` is needed.
> - **Compile path** (``fuse_rope=True``): internally calls
If `fuse_rope` also changes compile settings, let's call it something like `fuse_rope_using_torch_compile`?

Yeah, that makes sense; I changed the name.
> Args:
>     backend: Attention backend to use. If None (default), automatically
>         selected based on hardware capabilities.
>     use_hadamard: Apply Hadamard transform. Options:
This looks like a boolean but is actually an enum; maybe the name can be clearer?

Changed to `hadamard_mode`.
> - None: No Hadamard transform (default)
> - "v": Apply Hadamard to V only
> - "qkv": Apply Hadamard to Q, K, and V
> fuse_rope: If True, the model is compiled with a custom Inductor
This seems to be about fusing RoPE as well as implementation details of compile usage; can the name reflect this?

Changed to `fuse_rope_using_torch_compile`.
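The `hadamard_mode` options quoted above (None / "v" / "qkv") amount to a simple dispatch. Here is a minimal CPU sketch of the semantics only, not the fused Triton kernels; `hadamard_matrix` and `apply_hadamard` are hypothetical names:

```python
import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction; n must be a power of two.
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat(
            [torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0
        )
    return H / (n ** 0.5)  # orthonormal scaling: H @ H.T == I

def apply_hadamard(q, k, v, hadamard_mode=None):
    # Dispatch on the enum values from the docstring: None / "v" / "qkv".
    if hadamard_mode is None:
        return q, k, v
    H = hadamard_matrix(v.shape[-1]).to(v.dtype)
    if hadamard_mode == "v":
        return q, k, v @ H
    if hadamard_mode == "qkv":
        return q @ H, k @ H, v @ H
    raise ValueError(f"unknown hadamard_mode: {hadamard_mode!r}")
```

The orthonormal scaling keeps the transform invertible, which is why it can be applied before quantization without changing the attention output in exact arithmetic.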
torchao/prototype/attention/api.py
Outdated
> Apply low-precision attention to a model.
>
> Depending on the configuration, the model is either:
> - **Monkey-patch path** (``fuse_rope=False``, default): wraps the model
If the user starts with the monkey-patch path and calls torch.compile on it, do they get the compile path? If yes, why not just delete the setting for the compile path from this API and have it automatically work when compile is on?
No: if the user starts with the monkey-patch path and calls torch.compile, it will still just be a torch.compile'd model, but F.scaled_dot_product_attention is replaced with the Triton quant kernel + low-precision SDPA. There will be no RoPE fusion like the compile path. I originally had it set so that the compile path is always on, falling back to the simpler path if no RoPE is detected, but I figured that wasn't good since the compile path uses the pre-grad IR pass, and the default should be a simpler path that doesn't use that.
Where does this stand post-convo?
We'll keep the higher-level API with the `fuse_rope` option. If the user enables `fuse_rope`, we send a warning stating that they must explicitly run torch.compile to enable the RoPE fusion and low-precision attention pathway. No more magic behind-the-scenes compile.
> FA3 uses _scaled_dot_product_attention_quantized internally,
> which requires FA3 activation. This probe catches mismatches.
> """
> try:
Can you run a deslop pass on some of this? I kind of doubt you need code like this.
> try:
>     q = torch.randn(1, 1, 4, 64, device="cuda", dtype=torch.bfloat16)
>     with torch.no_grad():
>         sdpa_fn(q, q, q, is_causal=False)
Also, we should just know, right?
> _scaled_dot_product_attention_quantized,
> )
>
> _MIN_VERSION_ERROR = (
I feel like these global error strings are weird Claude coding; we could just inline them.
> # In the monkey-patch path (no torch.compile), accelerate's hooks move
> # tensors to the correct device but don't call torch.cuda.set_device().
> # Triton dispatches based on current_device(), not tensor device, so
> # without this guard the kernel launches on the wrong GPU's stream.
this is kinda weird to me .. hmmm
Oh yeah, I was actually planning on asking you this later in the 1:1. This was weird to me too, but without this guard, Triton does actually dispatch to the wrong device, leading to NaNs in multi-GPU setups.
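The guard being discussed can be sketched as a small device context around the kernel launch. `run_on_tensor_device` is a hypothetical helper name; the key point is that Triton dispatches on `torch.cuda.current_device()`, not the tensor's device:

```python
import torch

def run_on_tensor_device(kernel, q):
    """Make the current CUDA device match q's device before launching,
    so a Triton kernel runs on q's GPU rather than whatever GPU happens
    to be the process-wide current device."""
    if q.is_cuda:
        # torch.cuda.device temporarily sets the current device and
        # restores it on exit.
        with torch.cuda.device(q.device):
            return kernel(q)
    return kernel(q)  # CPU tensors need no guard
```

Without this, multi-GPU setups where accelerate hooks move tensors (but never call `torch.cuda.set_device()`) can launch the kernel on the wrong GPU's stream.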
> return torch.equal(mask.broadcast_to(mask.shape), ref.expand_as(mask))
>
>
> def detect_causal_mask(
is this really the best way?
Yeah, it's surprisingly difficult to have a general way to check whether the attention mask is just a causal mask. Claude suggested a bunch of ideas that IMO either leaned too heavily on user interaction or weren't robust enough. We can't do it in the pre-grad IR pass because we don't have access to the attention mask values to see if it's causal, so I came up with this method where we do a fake forward pass to check whether the masks are causal and strip them. I could definitely just be doing something dumb though; there likely is a better way.
> def _make_causal_aware_sdpa(fp8_sdpa_fn: Callable, strip_causal_mask: bool) -> Callable:
>     """Wrap an FP8 SDPA function to strip materialized causal masks.
>
>     HuggingFace models (e.g. LLaMA) pass a materialized lower-triangular
I'm surprised this is true... are you sure?
Yeah, I had to verify that it was true as well. In "sdpa" attention mode, HuggingFace materializes an attention mask instead of just using is_causal. Since FA3/FA4 aren't compatible with attention masks and only work with causal, I had to do the whole sketchy "detect mask" thing. On Blackwell, though, it doesn't do this, only on Hopper.
> - None: No Hadamard transform (default)
> - "v": Apply Hadamard to V only
> - "qkv": Apply Hadamard to Q, K, and V
> fuse_rope_using_torch_compile: If True, fuse RoPE + quantization + SDPA into optimized
If the user is using torch.compile, do we still need this? Shouldn't we just pick the best available fusion in torch.compile land?
Unfortunately we'll still need this flag, since we need to use the pre-grad IR pass, which is a global setting, and there is no way to selectively use the pre-grad IR pass for just the selected module (as far as I know, but I could totally be wrong). It's one of the reasons why I first chose to have the behind-the-scenes compile: so we could set the global flag, compile that one module, and reset it to default.

What I have now is that we return a compile backend, and in the user warning we tell them they must run torch.compile(module, backend=returned_backend) to compile that module with the specific backend they need (which differs between FA3 and FA4). We still need the flag, and the user can't just call torch.compile on the whole model with us automatically picking the best available fusion for that module, but I think it's still better than a magic compile the user doesn't know about.
torchao/prototype/attention/api.py
Outdated
> def apply_low_precision_attention(
>     model: nn.Module,
>     config: Optional[LowPrecisionAttentionConfig] = None,
Thoughts on removing the config and just adding arguments here? It will be more discoverable.
Yup, that makes sense to me. I originally had it as a config because I wanted it to stay usable as we add more and more parameters. But for what it is now, I think it makes more sense to have it as arguments. I've made the change.
Stack from ghstack (oldest at bottom):
Summary
Folder Breakdown
Test Plan
python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage
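The monkey-patch path from the summary can be illustrated with a minimal wrapper that swaps `F.scaled_dot_product_attention` only for the duration of a forward call. `SDPAPatchingWrapper` is a hypothetical name showing the mechanics, not the torchao implementation (whose entry point is `apply_low_precision_attention(model)`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAPatchingWrapper(nn.Module):
    """Swap F.scaled_dot_product_attention for a replacement only while
    the wrapped model's forward runs, then restore the original."""
    def __init__(self, model: nn.Module, sdpa_replacement):
        super().__init__()
        self.model = model
        self.sdpa_replacement = sdpa_replacement

    def forward(self, *args, **kwargs):
        original = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = self.sdpa_replacement
        try:
            return self.model(*args, **kwargs)
        finally:
            # Restore even if the forward raises, so the patch never leaks.
            F.scaled_dot_product_attention = original
```

Because the patch is applied at call time and undone immediately after, no `torch.compile` or graph capture is involved, matching the default path described in the summary.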
Results
Single-Layer Results
Results directly comparing FA3 SDPA versus FA3 fp8 SDPA (including quantization time):

Llama3 Model Results
Results comparing Llama3 model with FA3 SDPA versus Llama3 using the FA3 fp8 wrapper. Does not use RoPE fusion.

Perplexity: 6.19 -> 6.25