Example77 blackwell FMHA decode attention optimization #2816
Conversation
@alihassanijr, could you please take a look? CC += @richardmcai
```cpp
#if defined(FP16)
using ElementType = cutlass::half_t;
#elif defined(FP8)
using ElementType = cutlass::float_e5m2_t;
#else
using ElementType = cutlass::float_e5m2_t;
#endif
```
Seems like the else is extra?
Suggested change:
```diff
 #if defined(FP16)
 using ElementType = cutlass::half_t;
-#elif defined(FP8)
-using ElementType = cutlass::float_e5m2_t;
 #else
 using ElementType = cutlass::float_e5m2_t;
 #endif
```
Sure. Updated in force-push ending with ab64886.
```cpp
/*
 * The following utility type_traits allow mapping a constexpr value to a type
 * at compile time.
 * The default return type defined for each map is returned if the queried key
 * does not exist in the map.
 */
```
I think cutlass might already have a solution for these types of mappings, but I can't remember where it was.
But either way, I think the ideal way to handle the stage counts is through the Collective API, so that it's consistent with the rest of CUTLASS. What do you think, @hwu36?
I thought the same, that cutlass must have something similar. I eventually failed to find anything, and gave up searching after seeing the sm100 traits code use nested `if constexpr` (example1: op_repeater, example2: the epilogue collective builder).
I am more than happy to switch to some cutlass internal utils that I overlooked, since I dislike reinventing the wheel.
You're right; for a second I thought example 77 also had a collective builder. I must have confused it with 88.
Tagging more folks to see what they think about this approach: @thakkarV @richardmcai @IonThruster
Force-pushed from 31ea13a to ab64886.
Compile-time static/const mapping utilities for:
1. constexpr value -> constexpr value
2. constexpr value -> type

Useful when developing template-heavy cutlass code.
The fmha_gen example:
1. Previously ignored precision macros; now supports FP16 and FP8 like the other examples.
2. Previously failed to run FP16 because smem exceeded capacity; now works. Also added a new static_assert so that smem issues are reported at compile time rather than at runtime.
3. Previously used "ThreadShape=1x1x1" but didn't pass it into the kernel instantiation, so it was discarded. Now updated to reflect what is actually being used (the default 1x2x1).

Tested with
```
make test_examples_77_blackwell_fmha_gen_fp16
make test_examples_77_blackwell_fmha_gen_fp8
```
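As a rough illustration of the compile-time check described in item 2 (a hedged sketch; `SharedStorage` and the capacity constant are assumptions, not the example's actual code):

```cpp
// Hypothetical stand-in for the kernel's real shared memory struct.
struct SharedStorage { /* Q/K/V tiles, softmax scratch, mbarriers, ... */ };

// Assumed per-CTA smem opt-in limit; the commonly quoted figure for
// SM90/SM100-class GPUs is about 227 KB, but the real kernel should use
// the correct value for its target architecture.
constexpr unsigned kSmemCapacityBytes = 227u * 1024u;

// Fail at compile time instead of at kernel-launch time.
static_assert(sizeof(SharedStorage) <= kSmemCapacityBytes,
              "Shared memory footprint exceeds device capacity; "
              "reduce tile sizes or stage counts");
```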
Summary:
This commit introduces a smaller TileSize along the M (i.e., query) dimension, reducing the existing MTile=128 to MTile=64, which reduces wasted compute when the number of valid queries is much smaller than MTile.
The most important change is how each row of the Q/K/V matrix is split among threads, which in turn changes the MMA Atom and the TMEM op to use. The existing MTile=128 path allocates a single thread to process a single row using 32dp TMEM load/store instructions, while MTile=64 allocates two threads to collaborate on a single row using 16dp TMEM load/store instructions (with additional synchronizations/reductions).
This commit also clarifies the code a bit with more annotations.
Tested correctness with:
```
make test_examples_77_blackwell_fmha_gen_fp8
make test_examples_77_blackwell_fmha_gen_fp16
```
Performance numbers collected on B200:
```
for prec in fp16 fp8; do
for k in 512 1024 2048 4096 8192 16384; do
./77_blackwell_fmha_gen_${prec} --b=$((148*4)) --h=8 --h_k=1 --k=${k} --d=128 --iterations=2000 --verify;
done
done
```
Results (BW in TB/s; the last six columns are `--k` values):

| precision | MTile | NTile | KernSchedule | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| fp16 | 64 | 64 | UMMA_I | 3.17476 | 4.49264 | 5.19867 | 5.28804 | 4.85761 | 4.8774 |
| fp16 | 64 | 64 | UMMA_P | 3.72582 | 4.61441 | 4.62711 | 4.74302 | 4.70932 | 4.64647 |
| fp16 | 64 | 128 | UMMA_I | 3.31873 | 4.53742 | 5.1661 | 5.5619 | 5.10862 | 4.97582 |
| fp16 | 64 | 128 | UMMA_P | 3.79117 | 4.80919 | 5.28695 | 5.41161 | 5.11698 | 5.01131 |
| fp16 | 64 | 256 | UMMA_I | 3.0013 | 4.31719 | 4.80408 | 5.26672 | 5.16098 | 4.88833 |
| fp16 | 64 | 256 | UMMA_P | 3.30899 | 4.43762 | 5.06723 | 5.40213 | 5.16288 | 5.0174 |
| fp16 | 128 | 64 | UMMA_I | 2.53065 | 3.41101 | 3.83049 | 4.17756 | 4.25571 | 4.2076 |
| fp16 | 128 | 64 | UMMA_P | 3.25735 | 3.94733 | 4.0979 | 4.27806 | 4.19192 | 4.17491 |
| fp16 | 128 | 128 | UMMA_I | 2.79024 | 3.86101 | 4.48453 | 4.73089 | 4.42713 | 4.36846 |
| fp16 | 128 | 128 | UMMA_P | 3.46795 | 4.42207 | 4.74267 | 4.5842 | 4.56293 | 4.47918 |
| fp16 | 128 | 256 | UMMA_I | 2.54354 | 3.68322 | 4.47963 | 4.62111 | 4.62095 | 4.51716 |
| fp16 | 128 | 256 | UMMA_P | 2.92147 | 4.09284 | 4.47127 | 4.66789 | 4.55911 | 4.51602 |
Geomean for fp16, MTile=64 vs. MTile=128: 1.15x
BW in TB/s; the last six columns are `--k` values:

| precision | MTile | NTile | KernSchedule | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| fp8 | 64 | 64 | UMMA_I | 1.94386 | 2.51525 | 3.24701 | 3.89887 | 4.01632 | 4.08602 |
| fp8 | 64 | 64 | UMMA_P | 2.50369 | 3.11726 | 3.6986 | 4.13179 | 4.11397 | 4.18423 |
| fp8 | 64 | 128 | UMMA_I | 2.34714 | 3.10015 | 4.13021 | 4.96551 | 5.14912 | 5.06843 |
| fp8 | 64 | 128 | UMMA_P | 2.75716 | 3.8988 | 4.97269 | 5.42379 | 5.16682 | 4.95883 |
| fp8 | 64 | 256 | UMMA_I | 2.27186 | 3.23571 | 4.31995 | 5.01347 | 5.44343 | 5.26331 |
| fp8 | 64 | 256 | UMMA_P | 2.65845 | 3.76536 | 4.76387 | 5.24293 | 5.34911 | 5.25732 |
| fp8 | 128 | 64 | UMMA_I | 1.44813 | 1.86377 | 2.33225 | 2.84555 | 3.10332 | 3.23485 |
| fp8 | 128 | 64 | UMMA_P | 1.96491 | 2.43062 | 2.96345 | 3.32285 | 3.51813 | 3.56919 |
| fp8 | 128 | 128 | UMMA_I | 1.73366 | 2.36751 | 3.21492 | 3.74087 | 3.89075 | 3.96716 |
| fp8 | 128 | 128 | UMMA_P | 2.36905 | 3.24425 | 3.88743 | 4.00828 | 4.14043 | 4.0752 |
| fp8 | 128 | 256 | UMMA_I | 1.60114 | 2.44246 | 3.45299 | 4.0291 | 4.25753 | 4.25951 |
| fp8 | 128 | 256 | UMMA_P | 2.00025 | 2.96163 | 3.81701 | 4.19517 | 4.24847 | 4.20399 |
Geomean for fp8, MTile=64 vs. MTile=128: 1.28x
Test:
```
./77_blackwell_fmha_gen_fp16 --b=256 --h=8 --h_k=1 --k=65536 --d=128 --iterations=2 --cache-only --verify
./77_blackwell_fmha_gen_fp8 --b=256 --h=8 --h_k=1 --k=65536 --d=128 --iterations=2 --cache-only --verify
```
Before:
```
Failed to initialize the CUTLASS kernel. Last CUDA error is: no error
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to query occupancy.
Aborted (core dumped)
```
After: runs correctly.
Force-pushed from ab64886 to 68e1d04.
Rebased onto the recent v4.3.3 release. cc @alihassanijr @thakkarV @richardmcai @IonThruster. Still awaiting suggestions on the mapping utility.
@hwu36 We have been using the smaller Q_tile size supported here (pytorch/FBGEMM#5072) for a while now, and it shows good performance. So it would be great to have this PR merged in the next version to simplify updating from upstream. Please let us know if there are any changes needed from our side. Thanks.
…4, 256, 128> (pytorch#5232) Summary: X-link: facebookresearch/FBGEMM#2227 Port some new changes in NVIDIA/cutlass#2816 back to FBGEMM. Update softmax warp TMEM LOAD/STORE OP selection logic. Reviewed By: Aya-ZIbra Differential Revision: D89217227
Hi,
This PR mainly aims to upstream some fixes and optimizations we developed for the blackwell FMHA decode attention in Meta's FBGEMM.
Change 1: support FP16 in the `fmha_gen` kernel. Previously, `fmha_gen` used hardcoded fp8 and ignored the precision compiler macros.

Change 2: improve the `fmha_gen` kernel performance by reducing the MTile from 128 to 64.

More details are available in the commit messages.
cc previous example_77 contributors who also reviewed this PR: @v0i0 @Aya-ZIbra
Overview of the perf improvements (measured on B200):
- fp16, MTile=64 vs. MTile=128: overall geomean speedup is 1.15x
- fp8, MTile=64 vs. MTile=128: overall geomean speedup is 1.28x