Add FA4 monkey-patch path for low-precision attention #3960

howardzhang-cv wants to merge 16 commits into `gh/howardzhang-cv/24/base`
Conversation
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3960

As of commit 6611347 with merge base 5ebd10d: ❌ 1 new failure, 8 pending.
Adds FlashAttention 4 backend support for the monkey-patch SDPA path. FA4 supports both Hopper (SM 9.x) and Blackwell (SM 10.x) hardware via the `flash_attn.cute.interface` package.

Key additions:
- `FP8_FA4` enum value in `AttentionBackend`
- `_is_blackwell()`, `_is_fa4_available()` hardware/library checks
- FA4 dispatch in `apply_low_precision_attention`
- `fp8_fa4/` directory with `fp8_fa4_sdpa` entry point
- FA4 backend config in test suite with eager probe guard
- RoPE fusion placeholder (`fuse_rope=True` raises `NotImplementedError`)

ghstack-source-id: fa10283
Pull-Request: pytorch#3960
So this FA4 fp8 low-precision backend currently fails intermittently with the Llama 3 model: we get NaNs in the quantized QKV tensors (specifically in `torchao/prototype/attention/shared_utils/attention.py`, in the `_fp8_sdpa` function, the `q_fp8` tensor is the first to become NaN). The issue goes away when we replace the triton kernel with simple PyTorch ops, it also goes away with `--compile`, and it only reproduces on Blackwell. I spent a good amount of time trying to track this down but could not find the root cause.
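One way to localize this kind of bug is to swap the suspect kernel for a plain reference quantizer and probe each tensor for the first non-finite value. The sketch below uses pure Python lists as a stand-in for the torch tensors so it runs anywhere; the real path quantizes to `float8_e4m3fn` inside `_fp8_sdpa`, and `ref_quantize_fp8` here is a hypothetical per-tensor-scaled reference, not the PR's triton kernel.

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def ref_quantize_fp8(xs):
    """Reference per-tensor quantizer: scale so amax maps to E4M3_MAX, then
    clamp. A simple-PyTorch-ops stand-in for the triton kernel under test."""
    amax = max((abs(x) for x in xs), default=0.0)
    scale = E4M3_MAX / max(amax, 1e-12)  # guard against all-zero input
    return [max(-E4M3_MAX, min(E4M3_MAX, x * scale)) for x in xs], scale


def first_nonfinite(named):
    """Return the name of the first tensor containing NaN/Inf, else None."""
    for name, xs in named.items():
        if any(not math.isfinite(x) for x in xs):
            return name
    return None


q_fp8, q_scale = ref_quantize_fp8([0.1, -2.5, 3.0])
print(first_nonfinite({"q_fp8": q_fp8}))  # None: the reference path is clean
```

Running the same probe over `q_fp8`, `k_fp8`, `v_fp8` right after the triton kernel (versus after the reference) would pin down whether the NaNs originate in the kernel or upstream of it.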
Stack from ghstack (oldest at bottom):
Summary
New Files
Modified Files
Test Plan
python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage
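A hedged sketch of the monkey-patch entry point, guarded so it degrades cleanly when torchao or FA4 is absent. `AttentionBackend.FP8_FA4` and `apply_low_precision_attention` are named in this PR; the exact import path and call signature here are assumptions, not a verified API.

```python
# Assumed usage of the FA4 fp8 monkey-patch path; signature is illustrative.
try:
    import torch
    from torchao.prototype.attention import (
        AttentionBackend,
        apply_low_precision_attention,
    )
    HAVE_FA4_PATH = True
except ImportError:
    HAVE_FA4_PATH = False

if HAVE_FA4_PATH and torch.cuda.is_available():
    # Any SDPA-based model works; a single transformer layer as a stand-in.
    model = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4).cuda()
    # Assumed signature: patch the model's SDPA calls to the FA4 fp8 backend.
    apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA4)
    # fuse_rope=True currently raises NotImplementedError (placeholder only).
```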
Results
Single-Layer Results
Results directly comparing FA4 SDPA versus FA4 fp8 SDPA (including quantization time):

Llama3 Model Results
Results comparing the Llama3 model with FA4 SDPA versus Llama3 with the FA4 fp8 wrapper; RoPE fusion is not used.

Perplexity: 6.19 -> 6.25
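Since perplexity is exp(mean cross-entropy loss), the reported shift can be read as a loss delta, which puts the fp8 degradation in context. A quick check on the numbers above:

```python
import math

# Perplexity shift 6.19 -> 6.25 expressed as a cross-entropy loss delta
# (in nats) and as a relative perplexity change.
ppl_bf16, ppl_fp8 = 6.19, 6.25
loss_delta = math.log(ppl_fp8) - math.log(ppl_bf16)   # ~0.01 nats
rel_change = (ppl_fp8 - ppl_bf16) / ppl_bf16          # just under 1%
print(f"loss delta: {loss_delta:.4f} nats, relative change: {rel_change:.2%}")
```

That is, the fp8 path costs roughly 0.01 nats of loss, on the order of a 1% perplexity increase.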