
Conversation

Alkaid-Benetnash commented Nov 25, 2025

Hi,

This PR mainly aims to upstream some fixes and optimizations we developed for the Blackwell FMHA decode attention in Meta's FBGEMM.

Change 1: support FP16 in the fmha_gen kernel. Previously, fmha_gen hardcoded fp8 and ignored the precision compiler macros.
Change 2: improve fmha_gen kernel performance by reducing MTile from 128 to 64.

More details available in the commit messages.
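For concreteness, a minimal sketch of what the two changes amount to in the example's configuration code (the FP16/FP8 macros match the example's build targets quoted later in this review; the tile-shape spelling below is illustrative only, not the exact diff):

```
#include <cutlass/numeric_types.h>
#include <cute/tensor.hpp>

// Change 1: honor the precision compiler macro instead of always using fp8.
#if defined(FP16)
using Element = cutlass::half_t;
#else
using Element = cutlass::float_e5m2_t;
#endif

// Change 2: shrink the M (query) tile from 128 to 64 rows, so decode workloads
// with few valid queries per tile waste less compute (illustrative shape only).
using TileShapeQK = cute::Shape<cute::_64, cute::_128, cute::_128>;  // previously _128 in M
```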

cc previous example_77 contributors who also reviewed this PR: @v0i0 @Aya-ZIbra

Overview of the perf improvements (measured on B200):

| BW (TB/s) |       |       |              | seqlen  |         |         |         |         |         |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| precision | MTile | NTile | KernSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   |
| fp16      | 64    | 64    | UMMA_I       | 3.17476 | 4.49264 | 5.19867 | 5.28804 | 4.85761 | 4.8774  |
| fp16      | 64    | 64    | UMMA_P       | 3.72582 | 4.61441 | 4.62711 | 4.74302 | 4.70932 | 4.64647 |
| fp16      | 64    | 128   | UMMA_I       | 3.31873 | 4.53742 | 5.1661  | 5.5619  | 5.10862 | 4.97582 |
| fp16      | 64    | 128   | UMMA_P       | 3.79117 | 4.80919 | 5.28695 | 5.41161 | 5.11698 | 5.01131 |
| fp16      | 64    | 256   | UMMA_I       | 3.0013  | 4.31719 | 4.80408 | 5.26672 | 5.16098 | 4.88833 |
| fp16      | 64    | 256   | UMMA_P       | 3.30899 | 4.43762 | 5.06723 | 5.40213 | 5.16288 | 5.0174  |
| fp16      | 128   | 64    | UMMA_I       | 2.53065 | 3.41101 | 3.83049 | 4.17756 | 4.25571 | 4.2076  |
| fp16      | 128   | 64    | UMMA_P       | 3.25735 | 3.94733 | 4.0979  | 4.27806 | 4.19192 | 4.17491 |
| fp16      | 128   | 128   | UMMA_I       | 2.79024 | 3.86101 | 4.48453 | 4.73089 | 4.42713 | 4.36846 |
| fp16      | 128   | 128   | UMMA_P       | 3.46795 | 4.42207 | 4.74267 | 4.5842  | 4.56293 | 4.47918 |
| fp16      | 128   | 256   | UMMA_I       | 2.54354 | 3.68322 | 4.47963 | 4.62111 | 4.62095 | 4.51716 |
| fp16      | 128   | 256   | UMMA_P       | 2.92147 | 4.09284 | 4.47127 | 4.66789 | 4.55911 | 4.51602 |

fp16 speedup: MTile=64 / MTile=128

| MTile  | NTile | KernelSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   | geomean(row) |
| ------ | ----- | -------------- | ------- | ------- | ------- | ------- | ------- | ------- | ------------ |
| 64/128 | 64    | UMMA_I         | 125.45% | 131.71% | 135.72% | 126.58% | 114.14% | 115.92% | 124.68%      |
| 64/128 | 64    | UMMA_P         | 114.38% | 116.90% | 112.91% | 110.87% | 112.34% | 111.30% | 113.10%      |
| 64/128 | 128   | UMMA_I         | 118.94% | 117.52% | 115.20% | 117.57% | 115.39% | 113.90% | 116.41%      |
| 64/128 | 128   | UMMA_P         | 109.32% | 108.75% | 111.48% | 118.05% | 112.14% | 111.88% | 111.90%      |
| 64/128 | 256   | UMMA_I         | 118.00% | 117.21% | 107.24% | 113.97% | 111.69% | 108.22% | 112.65%      |
| 64/128 | 256   | UMMA_P         | 113.26% | 108.42% | 113.33% | 115.73% | 113.24% | 111.10% | 112.49%      |
|        |       | geomean(col)   | 116.45% | 116.51% | 115.64% | 117.03% | 113.15% | 112.03% | 115.12%      |

overall geomean(fp16 speedup) is 1.15x

| BW (TB/s) |       |       |              | seqlen  |         |         |         |         |         |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| precision | MTile | NTile | KernSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   |
| fp8       | 64    | 64    | UMMA_I       | 1.94386 | 2.51525 | 3.24701 | 3.89887 | 4.01632 | 4.08602 |
| fp8       | 64    | 64    | UMMA_P       | 2.50369 | 3.11726 | 3.6986  | 4.13179 | 4.11397 | 4.18423 |
| fp8       | 64    | 128   | UMMA_I       | 2.34714 | 3.10015 | 4.13021 | 4.96551 | 5.14912 | 5.06843 |
| fp8       | 64    | 128   | UMMA_P       | 2.75716 | 3.8988  | 4.97269 | 5.42379 | 5.16682 | 4.95883 |
| fp8       | 64    | 256   | UMMA_I       | 2.27186 | 3.23571 | 4.31995 | 5.01347 | 5.44343 | 5.26331 |
| fp8       | 64    | 256   | UMMA_P       | 2.65845 | 3.76536 | 4.76387 | 5.24293 | 5.34911 | 5.25732 |
| fp8       | 128   | 64    | UMMA_I       | 1.44813 | 1.86377 | 2.33225 | 2.84555 | 3.10332 | 3.23485 |
| fp8       | 128   | 64    | UMMA_P       | 1.96491 | 2.43062 | 2.96345 | 3.32285 | 3.51813 | 3.56919 |
| fp8       | 128   | 128   | UMMA_I       | 1.73366 | 2.36751 | 3.21492 | 3.74087 | 3.89075 | 3.96716 |
| fp8       | 128   | 128   | UMMA_P       | 2.36905 | 3.24425 | 3.88743 | 4.00828 | 4.14043 | 4.0752  |
| fp8       | 128   | 256   | UMMA_I       | 1.60114 | 2.44246 | 3.45299 | 4.0291  | 4.25753 | 4.25951 |
| fp8       | 128   | 256   | UMMA_P       | 2.00025 | 2.96163 | 3.81701 | 4.19517 | 4.24847 | 4.20399 |

fp8 speedup: MTile=64 / MTile=128

| MTile  | NTile | KernelSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   | geomean(row) |
| ------ | ----- | -------------- | ------- | ------- | ------- | ------- | ------- | ------- | ------------ |
| 64/128 | 64    | UMMA_I         | 134.23% | 134.95% | 139.22% | 137.02% | 129.42% | 126.31% | 133.45%      |
| 64/128 | 64    | UMMA_P         | 127.42% | 128.25% | 124.81% | 124.34% | 116.94% | 117.23% | 123.08%      |
| 64/128 | 128   | UMMA_I         | 135.39% | 130.95% | 128.47% | 132.74% | 132.34% | 127.76% | 131.25%      |
| 64/128 | 128   | UMMA_P         | 116.38% | 120.18% | 127.92% | 135.31% | 124.79% | 121.68% | 124.23%      |
| 64/128 | 256   | UMMA_I         | 141.89% | 132.48% | 125.11% | 124.43% | 127.85% | 123.57% | 129.07%      |
| 64/128 | 256   | UMMA_P         | 132.91% | 127.14% | 124.81% | 124.98% | 125.91% | 125.06% | 126.77%      |
|        |       | geomean(col)   | 131.12% | 128.90% | 128.29% | 129.69% | 126.11% | 123.55% | 127.92%      |

overall geomean(fp8 speedup) is 1.28x

hwu36 (Collaborator) commented Dec 2, 2025

@alihassanijr, could you please take a look?

CC += @richardmcai

Comment on lines 326 to 330
#if defined(FP16)
using ElementType = cutlass::half_t;
#elif defined(FP8)
using ElementType = cutlass::float_e5m2_t;
#else
using ElementType = cutlass::float_e5m2_t;
#endif
Contributor

Seems like the else is extra?

Suggested change:
#if defined(FP16)
using ElementType = cutlass::half_t;
#else
using ElementType = cutlass::float_e5m2_t;
#endif

Author

Sure. Updated in force-push ending with ab64886.

Comment on lines +130 to +133
* The following utility type_traits allow mapping a constexpr value to a type at
* compile time.
* The default type defined for each map is returned if the queried key
* does not exist in the map.
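As a reader's aid, a minimal sketch of the kind of constexpr-keyed type map described in the quoted comment (all names here, e.g. `TypeMap`/`kv`, are hypothetical and not the PR's actual utility):

```
#include <type_traits>

// One (key, type) entry of the map.
template <int Key, class T>
struct kv {};

// Primary template: the queried key was not found, so return the default type.
template <int Key, class Default, class... Entries>
struct TypeMap { using type = Default; };

// First entry matches the queried key: return its mapped type.
template <int Key, class Default, class T, class... Rest>
struct TypeMap<Key, Default, kv<Key, T>, Rest...> { using type = T; };

// First entry does not match: keep scanning the remaining entries.
template <int Key, class Default, int OtherKey, class T, class... Rest>
struct TypeMap<Key, Default, kv<OtherKey, T>, Rest...> : TypeMap<Key, Default, Rest...> {};

// Example: pick a (hypothetical) TMEM load tag from the M tile size,
// falling back to the default when the key is absent.
struct Load16dp {};
struct Load32dp {};
template <int MTile>
using TmemLoadOp = typename TypeMap<MTile, Load32dp, kv<64, Load16dp>, kv<128, Load32dp>>::type;

static_assert(std::is_same_v<TmemLoadOp<64>, Load16dp>);
static_assert(std::is_same_v<TmemLoadOp<256>, Load32dp>);  // missing key -> default
```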
Contributor

I think cutlass might already have a solution for these types of mappings, but I can't remember where it was.

But either way, I think the ideal way to handle the stage counts is through the Collective API; that way it's consistent with the rest of CUTLASS.

What do you think @hwu36 ?

Author

I had the same thought that CUTLASS must have something similar.
In the end I couldn't find anything and gave up searching after seeing the sm100 traits code use nested if constexpr (example 1: op_repeater, example 2: the epilogue collective builder).

I am more than happy to switch to a CUTLASS internal utility that I overlooked, since I dislike reinventing the wheel.
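For comparison, the nested `if constexpr` pattern referred to above looks roughly like the following (an illustrative value-to-value dispatch; the stage counts are made up and this is not the actual sm100 traits code):

```
// Select a pipeline stage count from the tile extents via nested if constexpr,
// in the style of the sm100 traits code mentioned above (values are placeholders).
template <int MTile, int NTile>
constexpr int select_stage_count() {
  if constexpr (MTile == 64) {
    if constexpr (NTile <= 128) { return 4; }
    else                        { return 3; }
  } else {
    if constexpr (NTile <= 128) { return 3; }
    else                        { return 2; }
  }
}

static_assert(select_stage_count<64, 128>() == 4);
static_assert(select_stage_count<128, 256>() == 2);
```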

Contributor

You're right; for a second I thought example 77 also had a collective builder. I must have confused it with 88.

Tagging more folks to see what they think about this approach: @thakkarV @richardmcai @IonThruster

Alkaid-Benetnash force-pushed the example77_decode_optimization branch from 31ea13a to ab64886 on December 9, 2025
Compile-time static/const mapping utilities for:
1. constexpr value -> constexpr value
2. constexpr value -> type

Useful when developing template-heavy cutlass code.
The fmha_gen example:
1. Previously ignored the precision macros; it now supports FP16 and FP8 like
   the other examples.
2. Previously failed to run FP16 because smem exceeded capacity; it now works.
   Also added a static_assert to report smem issues at compile time rather
   than at runtime (see the sketch after this list).
3. Previously used "ThreadShape=1x1x1" but never passed it into the kernel
   instantiation, so it was discarded. Now updated to reflect what is
   actually used (the default 1x2x1).
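A minimal sketch of the kind of compile-time check described in item 2 above (the capacity constant and the `SharedStorage` placeholder are assumptions for illustration, not the example's actual code):

```
#include <cstddef>

// Placeholder for the kernel's shared-memory layout (Q/K/V tiles, softmax scratch, ...).
struct SharedStorage { /* ... */ };

// Assumed per-CTA shared-memory budget; the real check should use the
// architecture constant that CUTLASS provides for the target SM.
constexpr std::size_t kSmemCapacityBytes = 227 * 1024;

// Fail at compile time instead of at occupancy-query/launch time.
static_assert(sizeof(SharedStorage) <= kSmemCapacityBytes,
              "fmha_gen: SharedStorage exceeds the shared-memory capacity; "
              "reduce tile sizes or stage count.");
```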

Tested with
```
make test_examples_77_blackwell_fmha_gen_fp16
make test_examples_77_blackwell_fmha_gen_fp8
```
Summary:
This commit introduces a smaller tile size along the M (i.e., query)
dimension, reducing MTile from 128 to 64, which reduces wasted compute
when the number of valid queries is much smaller than MTile.

The most important change is how each row of the Q/K/V matrix is split
among threads, which in turn changes the MMA atom and the TMEM ops used.
The existing MTile=128 configuration assigns a single thread to process a
single row using 32dp TMEM load/store instructions, whereas MTile=64
assigns two threads to collaborate on a single row using 16dp TMEM
load/store instructions (with additional synchronization/reductions, as
sketched below).
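
To illustrate the extra synchronization/reduction when two threads share one row, the partial row statistics can be merged between the paired lanes roughly as below (a CUDA sketch assuming the two collaborating threads are adjacent lanes; the actual kernel uses its own TMEM/warp primitives rather than this helper):

```
// Each of the two lanes sharing a row holds a partial row max and row sum for
// its half of the columns. Exchange with the partner lane (lane_id ^ 1) and
// combine so both lanes end up with the full-row statistics (online softmax).
__device__ void combine_row_stats(float& row_max, float& row_sum) {
  constexpr unsigned kFullMask = 0xffffffffu;
  float other_max = __shfl_xor_sync(kFullMask, row_max, 1);
  float other_sum = __shfl_xor_sync(kFullMask, row_sum, 1);
  float new_max   = fmaxf(row_max, other_max);
  // Rescale both partial sums to the common max before adding.
  row_sum = row_sum * __expf(row_max - new_max) + other_sum * __expf(other_max - new_max);
  row_max = new_max;
}
```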

This commit also clarifies the code a bit with more annotations.

Tested correctness with:
```
make test_examples_77_blackwell_fmha_gen_fp8
make test_examples_77_blackwell_fmha_gen_fp16
```

Performance numbers collected on B200:
```
for prec in fp16 fp8; do
for k in 512 1024 2048 4096 8192 16384; do
  ./77_blackwell_fmha_gen_${prec} --b=$((148*4)) --h=8 --h_k=1 --k=${k} --d=128 --iterations=2000 --verify;
done
done
```

Results:

| BW (TB/s) |       |       |              | --k     |         |         |         |         |         |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| precision | MTile | NTile | KernSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   |
| fp16      | 64    | 64    | UMMA_I       | 3.17476 | 4.49264 | 5.19867 | 5.28804 | 4.85761 | 4.8774  |
| fp16      | 64    | 64    | UMMA_P       | 3.72582 | 4.61441 | 4.62711 | 4.74302 | 4.70932 | 4.64647 |
| fp16      | 64    | 128   | UMMA_I       | 3.31873 | 4.53742 | 5.1661  | 5.5619  | 5.10862 | 4.97582 |
| fp16      | 64    | 128   | UMMA_P       | 3.79117 | 4.80919 | 5.28695 | 5.41161 | 5.11698 | 5.01131 |
| fp16      | 64    | 256   | UMMA_I       | 3.0013  | 4.31719 | 4.80408 | 5.26672 | 5.16098 | 4.88833 |
| fp16      | 64    | 256   | UMMA_P       | 3.30899 | 4.43762 | 5.06723 | 5.40213 | 5.16288 | 5.0174  |
| fp16      | 128   | 64    | UMMA_I       | 2.53065 | 3.41101 | 3.83049 | 4.17756 | 4.25571 | 4.2076  |
| fp16      | 128   | 64    | UMMA_P       | 3.25735 | 3.94733 | 4.0979  | 4.27806 | 4.19192 | 4.17491 |
| fp16      | 128   | 128   | UMMA_I       | 2.79024 | 3.86101 | 4.48453 | 4.73089 | 4.42713 | 4.36846 |
| fp16      | 128   | 128   | UMMA_P       | 3.46795 | 4.42207 | 4.74267 | 4.5842  | 4.56293 | 4.47918 |
| fp16      | 128   | 256   | UMMA_I       | 2.54354 | 3.68322 | 4.47963 | 4.62111 | 4.62095 | 4.51716 |
| fp16      | 128   | 256   | UMMA_P       | 2.92147 | 4.09284 | 4.47127 | 4.66789 | 4.55911 | 4.51602 |
geomean for fp16 TileM=64/TileM=128 == 1.15x

| BW (TB/s) |       |       |              | --k     |         |         |         |         |         |
| --------- | ----- | ----- | ------------ | ------- | ------- | ------- | ------- | ------- | ------- |
| precision | MTile | NTile | KernSchedule | 512     | 1024    | 2048    | 4096    | 8192    | 16384   |
| fp8       | 64    | 64    | UMMA_I       | 1.94386 | 2.51525 | 3.24701 | 3.89887 | 4.01632 | 4.08602 |
| fp8       | 64    | 64    | UMMA_P       | 2.50369 | 3.11726 | 3.6986  | 4.13179 | 4.11397 | 4.18423 |
| fp8       | 64    | 128   | UMMA_I       | 2.34714 | 3.10015 | 4.13021 | 4.96551 | 5.14912 | 5.06843 |
| fp8       | 64    | 128   | UMMA_P       | 2.75716 | 3.8988  | 4.97269 | 5.42379 | 5.16682 | 4.95883 |
| fp8       | 64    | 256   | UMMA_I       | 2.27186 | 3.23571 | 4.31995 | 5.01347 | 5.44343 | 5.26331 |
| fp8       | 64    | 256   | UMMA_P       | 2.65845 | 3.76536 | 4.76387 | 5.24293 | 5.34911 | 5.25732 |
| fp8       | 128   | 64    | UMMA_I       | 1.44813 | 1.86377 | 2.33225 | 2.84555 | 3.10332 | 3.23485 |
| fp8       | 128   | 64    | UMMA_P       | 1.96491 | 2.43062 | 2.96345 | 3.32285 | 3.51813 | 3.56919 |
| fp8       | 128   | 128   | UMMA_I       | 1.73366 | 2.36751 | 3.21492 | 3.74087 | 3.89075 | 3.96716 |
| fp8       | 128   | 128   | UMMA_P       | 2.36905 | 3.24425 | 3.88743 | 4.00828 | 4.14043 | 4.0752  |
| fp8       | 128   | 256   | UMMA_I       | 1.60114 | 2.44246 | 3.45299 | 4.0291  | 4.25753 | 4.25951 |
| fp8       | 128   | 256   | UMMA_P       | 2.00025 | 2.96163 | 3.81701 | 4.19517 | 4.24847 | 4.20399 |
geomean for fp8 TileM=64/TileM=128 == 1.28x
Test:
```
./77_blackwell_fmha_gen_fp16 --b=256 --h=8 --h_k=1 --k=65536 --d=128 --iterations=2 --cache-only --verify
./77_blackwell_fmha_gen_fp8 --b=256 --h=8 --h_k=1 --k=65536 --d=128 --iterations=2 --cache-only --verify
```

Before:
```
Failed to initialize the CUTLASS kernel. Last CUDA error is: no error
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to query occupancy.
  Aborted (core dumped)
```

After: runs correctly.
Alkaid-Benetnash force-pushed the example77_decode_optimization branch from ab64886 to 68e1d04 on December 15, 2025
Alkaid-Benetnash (Author)

Rebased onto the recent v4.3.3 release. cc @alihassanijr @thakkarV @richardmcai @IonThruster

Still awaiting suggestions on the mapping utility.
Please also let me know if there is any other feedback; after all, the mapping utility is more of a side discussion, not the core change.

Aya-ZIbra (Contributor)

@hwu36 We have been using the smaller Q tile size supported here (pytorch/FBGEMM#5072) for a while now, and it shows good performance. It would be great to have this PR merged in the next version to simplify updating from upstream. Please let us know if there are any changes we can make on our side. Thanks.

Alkaid-Benetnash added a commit to Alkaid-Benetnash/FBGEMM-1 that referenced this pull request Dec 16, 2025
…4, 256, 128>

Summary:
Port some new changes in NVIDIA/cutlass#2816 back to FBGEMM.
Update softmax warp TMEM LOAD/STORE OP selection logic.

Reviewed By: Aya-ZIbra

Differential Revision: D89217227
meta-codesync bot pushed a commit to pytorch/FBGEMM that referenced this pull request Dec 17, 2025
…4, 256, 128> (#5232)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2227

Pull Request resolved: #5232

Port some new changes in NVIDIA/cutlass#2816 back to FBGEMM.
Update softmax warp TMEM LOAD/STORE OP selection logic.

Reviewed By: Aya-ZIbra

Differential Revision: D89217227

fbshipit-source-id: b75260d328c1fead18f363516389c369c6e7cb6e