Commit f7322d9
perf: Performance tune cute dsl RMSNorm variants (#2777)
## 📌 Description
Rewrites all CuTe-DSL RMSNorm kernel variants (`rmsnorm`,
`gemma_rmsnorm`, `fused_add_rmsnorm`, `gemma_fused_add_rmsnorm`,
`rmsnorm_quant`, `fused_add_rmsnorm_quant`, `qk_rmsnorm`,
`gemma_qk_rmsnorm`) for performance.
**Key changes:**
* Multi-row blocks with async global-to-shared copies (`cp.async`): Each
thread block processes multiple rows, improving wave utilization and
hiding memory latency. Falls back to synchronous copies when alignment
or shared-memory constraints prevent async usage.
* Cluster reduction on SM90+: For large hidden sizes (H > max single-CTA
capacity), the workload is split across a CTA cluster that reduces
partial sums via shared memory, avoiding the need for a single CTA to
handle the full row.
* Vectorized FP8 convert+store via the PTX instruction
`cvt.rn.satfinite.e4m3x2.f32`, dramatically improving quantization
kernel throughput.
* Occupancy-aware shared-memory management.
* Non-contiguous tensor support without performance loss: Uses dual-path
compilation, a compact kernel for contiguous inputs (optimal codegen)
and a strided kernel for non-contiguous inputs (symbolic row strides).
Runtime dispatch via `is_contiguous()` ensures zero overhead for the
common contiguous case.
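The async-copy fallback described above can be sketched as a host-side feasibility check. The function name and the specific thresholds below are illustrative assumptions, not FlashInfer's actual API; the general constraints (aligned 16-byte transfers, shared-memory budget) are what `cp.async` favors:

```python
def can_use_async_copy(base_addr: int, row_bytes: int,
                       smem_needed: int, smem_budget: int) -> bool:
    """Hypothetical predicate for choosing the cp.async path.

    cp.async transfers are most efficient as 16-byte aligned,
    16-byte chunks, and the staged multi-row tile must fit in the
    shared-memory budget; otherwise fall back to synchronous copies.
    """
    aligned = base_addr % 16 == 0 and row_bytes % 16 == 0
    return aligned and smem_needed <= smem_budget

# 16B-aligned rows whose staged tile fits in shared memory -> async path
print(can_use_async_copy(0x1000, 8192 * 2, 4 * 8192 * 2, 228 * 1024))  # True
# misaligned base pointer -> synchronous fallback
print(can_use_async_copy(0x1002, 8192 * 2, 4 * 8192 * 2, 228 * 1024))  # False
```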
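The cluster-reduction scheme can be modeled numerically on the host: the hidden dimension is split into per-CTA slices, each slice contributes a partial sum of squares, and the partials are combined before the normalization pass. This is a minimal NumPy sketch of the math, not the device implementation:

```python
import numpy as np

def rmsnorm_cluster(x, weight, eps=1e-6, cluster_size=4):
    """Host-side model of the SM90 cluster reduction: each of
    `cluster_size` simulated CTAs handles a slice of the hidden
    dimension and produces a partial sum of squares; the partials
    are reduced cluster-wide before normalizing."""
    h = x.shape[-1]
    chunks = np.array_split(x, cluster_size, axis=-1)
    # each "CTA" computes the partial sum of squares for its slice
    partials = [np.sum(c.astype(np.float64) ** 2, axis=-1) for c in chunks]
    total = np.sum(partials, axis=0)  # cluster-wide reduction
    rms = np.sqrt(total / h + eps)
    return (x / rms[..., None]) * weight

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8192)).astype(np.float32)
w = np.ones(8192, dtype=np.float32)
# matches a single-pass reference regardless of how the row is split
ref = x / np.sqrt(np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True) + 1e-6)
assert np.allclose(rmsnorm_cluster(x, w), ref, atol=1e-4)
```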
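The `satfinite` qualifier on the FP8 conversion means out-of-range values saturate to the largest finite E4M3 magnitude (±448; the OCP E4M3 format has no infinities) instead of overflowing. A host-side model of just the saturation behavior:

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def quantize_e4m3_satfinite(x):
    """Models the saturating behavior of cvt.rn.satfinite.e4m3x2.f32:
    inputs beyond the finite E4M3 range clamp to +/-448. Rounding to
    the 3-bit mantissa grid is deliberately omitted; only the
    saturation step is shown."""
    return np.clip(x, -E4M3_MAX, E4M3_MAX)

vals = np.array([1.5, 500.0, -10000.0], dtype=np.float32)
print(quantize_e4m3_satfinite(vals))  # [1.5, 448.0, -448.0]
```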
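The dual-path dispatch amounts to a single contiguity check at call time. Here is a minimal sketch using NumPy's contiguity flag as a stand-in; the path names are illustrative, not the actual FlashInfer entry points:

```python
import numpy as np

def rmsnorm_dispatch(x):
    """Illustrative dual-path dispatch: contiguous rows take the
    compact kernel (strides known at compile time, optimal codegen);
    anything else takes the strided kernel with symbolic row strides."""
    if x.flags["C_CONTIGUOUS"]:
        return "compact"
    return "strided"

dense = np.zeros((4, 1024), dtype=np.float32)
sliced = dense[:, ::2]               # non-unit stride view of the same buffer
print(rmsnorm_dispatch(dense))   # compact
print(rmsnorm_dispatch(sliced))  # strided
```

Because the check is a cheap flag read, the common contiguous case pays no overhead for supporting the strided path.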
<details>
<summary>Click to see B200 performance comparison data (Peak 8
TB/s)</summary>
**RMSNorm**
Before:
<img width="1905" height="1680"
alt="before_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/15582140-f6df-4794-a4b4-2cc19d252dbb"
/>
After:
<img width="1905" height="1680"
alt="after_heatmap_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/0d306806-36d2-4576-a6c2-9f4629f277f8"
/>
**QK RMSNorm**
Before:
<img width="1905" height="1680"
alt="before_qk_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/71540b32-1df7-4772-94a7-b6b8c71080ee"
/>
After:
<img width="1905" height="1680"
alt="after_qk_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/04e95f62-73fe-43f4-b1a1-95eff234e379"
/>
**Add + RMSNorm + FP8 Quantize**
Before:
<img width="1905" height="1680"
alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/7bdda617-2d20-4a05-b7fd-2e9e489acba7"
/>
After:
<img width="1905" height="1680"
alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/663fb2a5-45cf-4fab-a74b-dc338d7d8bd0"
/>
</details>
<details>
<summary>Click to see H200 performance comparison data (Peak 4.8
TB/s)</summary>
**RMSNorm**
Before:
<img width="1905" height="1680"
alt="before_rmsnorm_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/42f63c06-8f6f-4ada-b6fd-e19de4ee32cc"
/>
After:
<img width="1905" height="1680" alt="after_rmsnorm_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/ae30fc58-159e-43b6-b108-850bf1711cad"
/>
**RMSNorm + FP8 Quantize**
Before:
<img width="1905" height="1680"
alt="before_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/52469123-6a5f-459a-ae0b-586a11370ac9"
/>
After:
<img width="1905" height="1680"
alt="after_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/4a229d4a-10ea-4d89-985f-c0378c6554d4"
/>
**Add + RMSNorm + FP8 Quantize**
Before:
<img width="1905" height="1680"
alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/78ac50aa-ae6a-4ea6-a585-0b326279e96b"
/>
After:
<img width="1905" height="1680"
alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/8268ffb8-0ee0-49b7-9353-8d0151002329"
/>
</details>
## 🔍 Related Issues
#2396
#2771
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
  * SM-version-aware kernels and cluster-based tiling for multi-CTA execution
  * Contiguity-aware selection between compact and strided tensor paths
  * Hardware-accelerated FP8/E4M3 conversion and packed-storage routines
  * Newly exposed utilities for device SM queries and cluster-backed reductions
* **Improvements**
  * Async copy paths, expanded shared-memory and cluster-reduction support
  * Per-cluster memory/tiling estimation and improved multi-cluster reduction handling
  * Public APIs now accept an optional SM-version hint and infer/preserve contiguity
<!-- end of auto-generated comment: release notes by coderabbit.ai -->