[Cute,Fwd,Sm100] fp8 e4m3 and e5m2 support #2109
dcw02 wants to merge 13 commits into Dao-AILab:main from
Conversation
yes, fixed merge conflicts
It seems curious to me that FP8 only gets a 0.95x - 1.15x speedup given that it halves data movement. Is it possible to get some profiling results to see whether it's bottlenecked by the SFU/softmax? Thanks.
e2e softmax is turned off for fp8 since during my testing it hurt performance. Here are numbers with e2e turned off for bf16 as well: mostly looking to just get the initial support in first, then we can optimize more afterwards
on my branch |
I did a brief review and it seems fine. There's more to optimize for fp8 but that's for later.
flash_attn/cute/flash_fwd_sm100.py (outdated)

```python
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
learnable_sink: Optional[cute.Tensor] = None,
mQDescale: Optional[cute.Tensor] = None,
```
can we put the scales into a small struct?
does putting things in a struct make it more difficult to use tvm-ffi?
https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html#working-with-named-tuples
You know the tvm-ffi stuff better than I do
yup named tuple is the way;
put the scales into a named tuple here and updated the wiring
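For readers following along, the named-tuple approach the reviewers settled on can be sketched like this. The `QKVDescale` name and fields are illustrative, not the PR's actual definitions; the point is that a `NamedTuple` keeps the arguments flat and tvm-ffi-friendly while still grouping the scales under one parameter:

```python
from typing import NamedTuple, Optional

# Hypothetical container for the per-tensor FP8 descale factors.
# tvm-ffi can flatten NamedTuple fields into individual arguments,
# which is why a NamedTuple was preferred over an opaque struct.
class QKVDescale(NamedTuple):
    q_descale: Optional[float] = None
    k_descale: Optional[float] = None
    v_descale: Optional[float] = None

# A single argument now carries all three scales.
descale = QKVDescale(q_descale=0.5, k_descale=0.25, v_descale=2.0)
print(descale.q_descale * descale.k_descale)  # 0.125
```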
flash_attn/cute/flash_fwd_sm100.py (outdated)

```python
acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)
scale = scale * v_descale
```
I doubt it has much of an impact, but did we measure that the non-QKV-scaled path doesn't take any performance hit from these changes?
there was no measurable performance impact, but I added compile-time gating
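The guarded normalization in the diff above can be mirrored in plain Python. This is a simplified sketch: an exact reciprocal stands in for `cute.arch.rcp_approx`, and `v_descale` is treated as a plain scalar:

```python
import math

def output_scale(row_sum: float, v_descale: float) -> float:
    """Final O-row scale: 1/row_sum folded with V's FP8 descale.

    If the softmax row sum is zero or NaN (e.g. a fully masked row),
    fall back to 1.0 so the reciprocal stays finite, matching the
    kernel's acc_O_mn_row_is_zero_or_nan guard.
    """
    row_is_zero_or_nan = row_sum == 0.0 or math.isnan(row_sum)
    scale = 1.0 / (row_sum if not row_is_zero_or_nan else 1.0)
    return scale * v_descale

print(output_scale(4.0, 2.0))           # 0.5
print(output_scale(0.0, 2.0))           # 2.0 (guarded, no division by zero)
print(output_scale(float("nan"), 2.0))  # 2.0
```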
fixed lint and rebased, but I think the benchmarking script still works if you downgrade to
Can you do one more rebase please :)
Tested this fp8 low precision attention implementation by adding it as a quantized overload through PyTorch (pytorch/pytorch#175472), and tested it through the low precision attention API in TorchAO, where I saw some decent runtime results. You can see the PRs for more implementation details (pytorch/ao#3960 and pytorch/ao#3947).

Results

Single-Layer Results

Results directly comparing bf16 FA4 SDPA versus fp8 FA4 SDPA (including quantization time):

Llama3 Model Results

Results comparing the Llama3 model with bf16 FA4 SDPA versus Llama3 using the fp8 FA4 SDPA.



Summary
Adds FP8 (E4M3/E5M2) support to Flash Attention 4 on SM100 (Blackwell GPUs). This brings FP8 inference acceleration to the CuTe-DSL implementation with full correctness validation.
What's New
FP8 Forward Pass
Key Implementation Details
Numerical stability fixes:
softmax_scale_eff = softmax_scale * (q_descale * k_descale)

PTX updates:
- MMA kind was hardcoded to f16, now dynamic (f16, f8f6f4, etc.)
- tcgen05.mma inline assembly uses correct FP8 variants

Interface:
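The descale folding above rests on a simple identity: since the descales are per-tensor scalars, (q_descale * Q) @ (k_descale * K)^T = (q_descale * k_descale) * (Q @ K^T), so both factors can be absorbed into the softmax scale instead of dequantizing Q and K separately. A toy check with scalar stand-ins for Q and K entries (the values are illustrative):

```python
softmax_scale = 0.125         # e.g. 1/sqrt(head_dim) for head_dim = 64
q_descale, k_descale = 0.5, 0.25

q_fp8, k_fp8 = 6.0, 8.0       # stand-ins for raw quantized Q/K values

# Naive: dequantize each operand, then apply the softmax scale.
naive = softmax_scale * (q_fp8 * q_descale) * (k_fp8 * k_descale)

# Folded: one effective scale applied to the raw FP8 product.
softmax_scale_eff = softmax_scale * (q_descale * k_descale)
folded = softmax_scale_eff * (q_fp8 * k_fp8)

assert naive == folded
print(folded)  # 0.75
```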
DLPack workaround: PyTorch 2.9.1 doesn't support FP8 via DLPack, so we export as uint8 and override the element type.
Benchmark & Testing
Added FP8 benchmark (modeled after FA3's) that tests:
Correctness checking:
- --check-quantization-only to isolate kernel bugs from quantization error

Usage:
Results