
Commit f7322d9

perf: Performance tune cute dsl RMSNorm variants (#2777)
## 📌 Description

Rewrites all CuTe-DSL RMSNorm kernel variants (`rmsnorm`, `gemma_rmsnorm`, `fused_add_rmsnorm`, `gemma_fused_add_rmsnorm`, `rmsnorm_quant`, `fused_add_rmsnorm_quant`, `qk_rmsnorm`, `gemma_qk_rmsnorm`).

**Key changes:**

* Multi-row blocks with async global-to-shared copies (`cp.async`): each thread block processes multiple rows, improving wave utilization and hiding memory latency. Falls back to synchronous copies when alignment or shared-memory constraints prevent async usage.
* Cluster reduction on SM90+: for large hidden sizes (H > max single-CTA capacity), the workload is split across a CTA cluster that reduces partial sums via shared memory, avoiding the need for a single CTA to handle the full row.
* Vectorized FP8 convert+store via the PTX intrinsic `cvt.rn.satfinite.e4m3x2.f32`, dramatically improving quantization kernel throughput.
* Occupancy-aware shared memory management.
* Non-contiguous tensor support without performance loss: uses dual-path compilation, with a compact kernel for contiguous inputs (optimal codegen) and a strided kernel for non-contiguous inputs (symbolic row strides). Runtime dispatch via `is_contiguous()` ensures zero overhead for the common contiguous case.
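For reference, the tuned variants share the same core math. The following is a minimal NumPy sketch of the usual semantics; the `eps` default, fp32 accumulation, and the Gemma-style `1 + weight` scaling convention are assumptions based on common definitions, not taken from this PR's code:

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    # y = x / sqrt(mean(x^2) + eps) * weight, accumulated per row in fp32
    xf = x.astype(np.float32)
    rms = np.sqrt(np.mean(xf * xf, axis=-1, keepdims=True) + eps)
    return (xf / rms) * weight

def gemma_rmsnorm(x, weight, eps=1e-6):
    # Gemma-style variant: scales by (1 + weight) instead of weight
    return rmsnorm(x, 1.0 + np.asarray(weight, dtype=np.float32), eps)

def fused_add_rmsnorm(x, residual, weight, eps=1e-6):
    # Fused variant: form s = x + residual (which the fused kernel also
    # writes back as the updated residual), then normalize s
    s = x.astype(np.float32) + residual.astype(np.float32)
    return rmsnorm(s, weight, eps), s
```

The `*_quant` variants additionally convert the normalized output to FP8 before storing, and the `qk_*` variants apply the norm to query/key heads.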
<details>
<summary>Click to see B200 performance comparison data (Peak 8 TB/s)</summary>

**RMSNorm**

Before:
<img width="1905" height="1680" alt="before_rmsnorm_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/15582140-f6df-4794-a4b4-2cc19d252dbb" />

After:
<img width="1905" height="1680" alt="after_heatmap_rmsnorm_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/0d306806-36d2-4576-a6c2-9f4629f277f8" />

**QK RMSNorm**

Before:
<img width="1905" height="1680" alt="before_qk_rmsnorm_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/71540b32-1df7-4772-94a7-b6b8c71080ee" />

After:
<img width="1905" height="1680" alt="after_qk_rmsnorm_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/04e95f62-73fe-43f4-b1a1-95eff234e379" />

**Add + RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680" alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/7bdda617-2d20-4a05-b7fd-2e9e489acba7" />

After:
<img width="1905" height="1680" alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200" src="https://github.com/user-attachments/assets/663fb2a5-45cf-4fab-a74b-dc338d7d8bd0" />
</details>

<details>
<summary>Click to see H200 performance comparison data (Peak 4.8 TB/s)</summary>

**RMSNorm**

Before:
<img width="1905" height="1680" alt="before_rmsnorm_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/42f63c06-8f6f-4ada-b6fd-e19de4ee32cc" />

After:
<img width="1905" height="1680" alt="after_rmsnorm_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/ae30fc58-159e-43b6-b108-850bf1711cad" />

**RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680" alt="before_rmsnorm_quant_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/52469123-6a5f-459a-ae0b-586a11370ac9" />

After:
<img width="1905" height="1680" alt="after_rmsnorm_quant_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/4a229d4a-10ea-4d89-985f-c0378c6554d4" />

**Add + RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680" alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/78ac50aa-ae6a-4ea6-a585-0b326279e96b" />

After:
<img width="1905" height="1680" alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200" src="https://github.com/user-attachments/assets/8268ffb8-0ee0-49b7-9353-8d0151002329" />
</details>

## 🔍 Related Issues

#2396 #2771

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * SM-version aware kernels and cluster-based tiling for multi-CTA execution
  * Contiguity-aware selection for compact vs. strided tensor paths
  * Hardware-accelerated FP8/E4M3 conversion and packed storage routines
  * New exposed utilities for device SM queries and cluster-backed reductions
* **Improvements**
  * Async copy paths, expanded shared-memory and cluster-reduction support
  * Per-cluster memory/tiling estimation and improved multi-cluster reduction handling
  * Public APIs now accept an optional SM-version hint and infer/preserve contiguity
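To clarify the FP8/E4M3 conversion mentioned above: the following pure-Python sketch emulates, per input lane, the numeric effect of `cvt.rn.satfinite.e4m3x2.f32` under the standard OCP E4M3 format (3 mantissa bits, max finite value 448, saturating rather than producing NaN/Inf). The real instruction converts two f32 values at once and packs two E4M3 bytes into a 16-bit result; this function only models the rounding/saturation of a single value and is not taken from the PR's code:

```python
import math

def quantize_e4m3_satfinite(x: float) -> float:
    # satfinite: saturate out-of-range inputs to the max finite value
    x = max(-448.0, min(448.0, x))
    if x == 0.0:
        return 0.0
    sign, a = math.copysign(1.0, x), abs(x)
    # Exponent of the containing binade, floored at -6 (subnormal region)
    e = max(math.floor(math.log2(a)), -6)
    step = 2.0 ** (e - 3)            # grid spacing for a 3-bit mantissa
    # .rn = round-to-nearest-even; Python's round() is ties-to-even
    return sign * round(a / step) * step
```

For example, `quantize_e4m3_satfinite(1000.0)` saturates to `448.0`, and `0.3` rounds to the nearest representable value `0.3125`.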
1 parent abf080a · commit f7322d9

3 files changed

Lines changed: 1806 additions & 512 deletions


0 commit comments
