A high-performance RMSNorm kernel written using ThunderKittens-style primitives, targeting near-theoretical-peak HBM bandwidth utilization on NVIDIA H100 GPUs.
RMSNorm (Root Mean Square Layer Normalization) is a critical operation in modern LLM inference. Every transformer layer applies it at least twice (pre-attention and pre-FFN). At decode time (batch=1), RMSNorm is purely memory-bandwidth-bound — the arithmetic intensity is essentially zero. This makes it a perfect target for a "speed-of-light" kernel: one that saturates HBM bandwidth.
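For reference, this is the per-row computation in the notation of the RMSNorm paper. It adds the small stabilizing constant ε that implementations typically place inside the square root; the value this kernel uses is not stated here.

$$
y_i = \gamma_i \cdot \frac{x_i}{\sqrt{\dfrac{1}{N}\sum_{j=1}^{N} x_j^{2} + \varepsilon}}, \qquad i = 1, \dots, N
$$

Per element, that is one multiply-add for the sum of squares plus two multiplies (by the reciprocal RMS and by γ_i). In the two-pass design this is roughly 4 FLOPs against 6 bytes of bf16 traffic, which is well under one FLOP per byte.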
This kernel follows ThunderKittens' core principles, adapted from studying the TK GEMM kernel patterns:
| TK GEMM Pattern | RMSNorm Adaptation |
|---|---|
| 16×16 tiles as fundamental data structure | Vectorized bf16×8 loads (128-bit), tile-aligned |
| Swizzled shared memory layouts (128-byte) | Bank-conflict-free SMEM staging for gamma weights |
| `global_layout` / `gl<bf16,...>` descriptors | Row-major `bf16*` with compile-time stride |
| `warpgroup::mma_AB` for compute | `warp_reduce_sum` for the reduction (no tensor cores needed; this is bandwidth-bound) |
| `prototype::lcf::kernel` template (Load-Compute-Store-Finish) | Two-phase: (1) Load+Reduce, (2) Load+Normalize+Store |
| `tma::load_async` with semaphores | Vectorized `float4` loads with `__syncthreads` barriers |
| fp32 accumulators for numerical stability | fp32 accumulation for sum-of-squares, bf16 I/O |
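The `warp_reduce_sum` step in the table is a standard butterfly reduction. At each of log2(32) = 5 steps, every lane adds the value held by the lane whose index differs in one bit, so all 32 lanes end up with the total. The sketch below simulates that on the host in C++ so it runs without a GPU. `warp_reduce_sum_sim` is a hypothetical name; on the device, each step would be a `__shfl_xor_sync` rather than an array read.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;

// Host-side simulation of a butterfly (XOR-shuffle) warp reduction.
// On the device each step would be: v += __shfl_xor_sync(0xffffffff, v, offset);
// Here the 32 "lanes" are array slots, and each step reads a snapshot of the
// previous values so that all lanes exchange simultaneously, as a shuffle does.
std::array<float, kWarpSize> warp_reduce_sum_sim(std::array<float, kWarpSize> lane) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        const std::array<float, kWarpSize> prev = lane;  // values before this step
        for (int i = 0; i < kWarpSize; ++i) {
            lane[i] = prev[i] + prev[i ^ offset];         // partner lane differs in one bit
        }
    }
    return lane;  // after 5 steps every lane holds the sum of all 32 inputs
}
```

Because every lane ends with the full sum, each thread can compute the reciprocal RMS locally without an extra broadcast. A block with several warps would still need to combine the per-warp partials, for example through shared memory.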
- One threadblock per row. Each block independently normalizes one row of the input. This eliminates inter-block synchronization entirely, which is critical for decode latency.
- Vectorized 128-bit loads. Each thread loads 8 bf16 values (one `float4`, i.e. one `LDG.128` instruction) per memory transaction, maximizing memory bus utilization.
- Shared memory gamma staging. The learnable scale vector `gamma` is loaded once into SMEM and reused during the normalization pass. For typical LLM hidden dims (4096-8192), this fits comfortably in SMEM.
- Two-pass design with HBM re-read. Pass 1 reads `x` and computes the sum of squares. Pass 2 re-reads `x` from HBM, normalizes it, and writes `y`. This re-read costs bandwidth but avoids SMEM pressure for large hidden dims. The total data movement is `3·M·N·sizeof(bf16)`, plus one read of `gamma`.
- Fused residual variant. In transformers, RMSNorm is almost always preceded by a residual add. The fused kernel (`rmsnorm_residual_tk_kernel`) does `x = x + residual` and RMSNorm in a single launch. This saves the write and re-read of the intermediate `x + residual` that a separate add kernel would require.
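Here is a host-side C++ reference for the fused residual variant, written to mirror the structure in the list above:

- one loop iteration per row, standing in for one threadblock;
- 8-element chunks, standing in for one 128-bit load;
- an fp32 sum of squares in pass 1;
- a re-read of `x` in pass 2.

It is a sketch, not the kernel. It stores fp32 instead of bf16 and accumulates serially instead of across threads. It also assumes that `x` is updated in place with `x + residual`, that `eps = 1e-6`, and that `N` is a multiple of 8. The function name `rmsnorm_residual_ref` is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr std::size_t kVec = 8;  // 8 bf16 values = 16 bytes = one 128-bit load in the real kernel

// Hypothetical host reference for fused residual + RMSNorm (fp32 storage for simplicity).
// Assumptions: x is overwritten with x + residual, eps = 1e-6, N % 8 == 0.
void rmsnorm_residual_ref(std::vector<float>& x, const std::vector<float>& residual,
                          const std::vector<float>& gamma, std::vector<float>& y,
                          std::size_t M, std::size_t N, float eps = 1e-6f) {
    assert(N % kVec == 0 && x.size() == M * N && residual.size() == M * N);
    assert(gamma.size() == N && y.size() == M * N);
    for (std::size_t row = 0; row < M; ++row) {         // one "threadblock" per row
        float* xr = x.data() + row * N;
        const float* rr = residual.data() + row * N;
        float* yr = y.data() + row * N;

        // Pass 1: residual add (written back to x) and fp32 sum of squares.
        // The kernel would split this across threads and finish with a warp reduction.
        float sumsq = 0.0f;
        for (std::size_t c = 0; c < N; c += kVec) {     // one 8-wide chunk per "load"
            for (std::size_t k = 0; k < kVec; ++k) {
                xr[c + k] += rr[c + k];
                sumsq += xr[c + k] * xr[c + k];
            }
        }
        const float rstd = 1.0f / std::sqrt(sumsq / static_cast<float>(N) + eps);

        // Pass 2: re-read x, scale by the reciprocal RMS and gamma, write y.
        for (std::size_t c = 0; c < N; c += kVec) {
            for (std::size_t k = 0; k < kVec; ++k) {
                yr[c + k] = xr[c + k] * rstd * gamma[c + k];
            }
        }
    }
}
```

A reference like this is the kind of baseline a correctness test can compare a kernel's bf16 output against.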
For a `[M, N]` input in bf16:

```
Total HBM bytes = 3 × M × N × 2   +   N × 2
                  (read x twice,      (read gamma once)
                   write y once)

H100 SXM peak bandwidth = 3.35 TB/s

Theoretical minimum time (M=1, N=4096):
  = (3 × 4096 × 2 + 4096 × 2) / 3.35e12
  = 32,768 / 3.35e12
  ≈ 9.8 ns
```
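A kernel sustaining >80% of peak bandwidth is operating near speed-of-light. At M=1, however, the ~10 ns bound is far below the ~3.5 µs launch-dominated time measured for decode below, so in practice the bound is only approachable at larger M.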
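The byte model and lower bound above can be checked mechanically. This C++ sketch encodes the same formula. `hbm_bytes` and `min_time_ns` are illustrative names, and the 3.35e12 B/s peak is the H100 SXM figure quoted above.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::uint64_t kBf16Bytes = 2;

// Bytes moved under the two-pass model: read x twice + write y once (bf16), plus gamma once.
constexpr std::uint64_t hbm_bytes(std::uint64_t M, std::uint64_t N) {
    return 3 * M * N * kBf16Bytes + N * kBf16Bytes;
}

// Lower bound on runtime in nanoseconds if every byte moves at peak bandwidth.
constexpr double min_time_ns(std::uint64_t M, std::uint64_t N, double peak_bytes_per_s) {
    return static_cast<double>(hbm_bytes(M, N)) / peak_bytes_per_s * 1e9;
}

static_assert(hbm_bytes(1, 4096) == 32768, "matches the worked M=1, N=4096 example");
```

For instance, `min_time_ns(256, 8192, 3.35e12)` is about 3.76 µs. That compares with 4.67 µs measured for the prefill-70B row in the results below, consistent with the ~80.6% of peak reported there.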
Requires CUDA 12.3+ and a C++20-capable host compiler.

```bash
# For H100
make

# For A100
make GPU=A100

# Build and run tests + benchmarks
make test
```

| File | Description |
|---|---|
| `rmsnorm_tk.cuh` | Kernel implementation (header-only, following TK convention) |
| `test_rmsnorm.cu` | Correctness tests + bandwidth benchmarks |
| `Makefile` | Self-contained build (TK pattern: one Makefile per kernel) |
Tested on H100 SXM (132 SMs, 3350 GB/s peak HBM, CUDA 12.8).
All 9 test cases pass. The maximum absolute error is ~0.02, which is on the order of bf16 rounding for outputs of magnitude a few units. That is well within the expected envelope for a two-pass kernel with fp32 accumulation and bf16 output.
| Shape | Result | Max abs error |
|---|---|---|
| M=1, N=128 | PASS | 1.61e-02 |
| M=4, N=768 | PASS | 1.58e-02 |
| M=8, N=1024 | PASS | 1.70e-02 |
| M=16, N=2048 | PASS | 1.75e-02 |
| M=32, N=4096 | PASS | 1.78e-02 |
| M=64, N=5120 | PASS | 1.81e-02 |
| M=32, N=8192 | PASS | 1.83e-02 |
| M=1, N=4096 | PASS | 1.41e-02 |
| M=2048, N=4096 | PASS | 1.87e-02 |
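To put the ~0.02 errors in the table in context: the spacing between adjacent bf16 values depends on magnitude. This C++ sketch converts fp32 to bf16 with round-to-nearest-even (the usual conversion; NaN handling is omitted) and computes that spacing. The helper names are hypothetical and not part of `rmsnorm_tk.cuh`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// fp32 -> bf16 bits with round-to-nearest-even (NaN inputs not handled).
std::uint16_t float_to_bf16_bits(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    const std::uint32_t lsb = (u >> 16) & 1u;  // lowest bit that survives truncation
    u += 0x7FFFu + lsb;                        // round to nearest, ties to even
    return static_cast<std::uint16_t>(u >> 16);
}

float bf16_bits_to_float(std::uint16_t b) {
    const std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Round-trip through bf16, as happens when a kernel stores y in bf16.
float round_to_bf16(float f) { return bf16_bits_to_float(float_to_bf16_bits(f)); }

// Spacing between adjacent bf16 values near |x| (8 significant bits), for normal nonzero x.
float bf16_ulp(float x) {
    assert(x != 0.0f && std::isfinite(x));
    int exp = 0;
    std::frexp(std::fabs(x), &exp);            // |x| = m * 2^exp with m in [0.5, 1)
    return std::ldexp(1.0f, exp - 8);
}
```

`bf16_ulp(1.0f)` is 2^-7 = 0.0078125 and `bf16_ulp(3.0f)` is 2^-6 = 0.015625. An error near 0.02 on outputs of a few units is therefore on the order of a single bf16 step.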
| Workload | M | N | Time (µs) | BW (GB/s) | % Peak |
|---|---|---|---|---|---|
| decode-7B | 1 | 4096 | 3.51 | 9.3 | 0.3% |
| decode-70B | 1 | 8192 | 3.74 | 17.5 | 0.5% |
| decode-180B | 1 | 16384 | 5.14 | 25.5 | 0.8% |
| batch8-7B | 8 | 4096 | 3.54 | 57.9 | 1.7% |
| batch8-70B | 8 | 8192 | 3.77 | 108.8 | 3.2% |
| prefill-7B | 256 | 4096 | 3.91 | 1611.8 | 48.1% |
| prefill-70B | 256 | 8192 | 4.67 | 2700.1 | 80.6% |
| 1K-7B | 1024 | 4096 | 6.65 | 3786.3 | 113.0%* |
| 2K-7B | 2048 | 4096 | 10.75 | 4684.0 | 139.8%* |
| 4K-7B | 4096 | 4096 | 27.57 | 3651.0 | 109.0%* |
| 2K-70B | 2048 | 8192 | 28.56 | 3525.8 | 105.2%* |
* Values >100% indicate L2 cache is serving the second pass of x rather than HBM — the bandwidth formula assumes both passes go to HBM. For large-batch prefill the actual data movement is lower than 3·M·N·2 bytes, so % peak is artificially inflated. The kernel is genuinely fast in this regime.
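The footnote's argument can be made concrete by recomputing % of peak under two byte models:

- the model the table uses, in which `x` is read twice from HBM;
- an alternative in which the pass-2 re-read of `x` is served entirely from L2.

The second model is an assumption for illustration, not a measurement. Confirming actual DRAM traffic would need a profiler such as Nsight Compute. The function names are hypothetical.

```cpp
#include <cassert>

constexpr double kPeakGBs = 3350.0;  // H100 SXM peak HBM bandwidth used in the table

// Bytes counted by the benchmark: x read twice + y written once + gamma once (bf16).
constexpr double bytes_reported(double M, double N) { return 3 * M * N * 2 + N * 2; }

// Assumed alternative: the pass-2 re-read of x hits L2, so HBM sees x once, y once, gamma once.
constexpr double bytes_if_l2_rereads(double M, double N) { return 2 * M * N * 2 + N * 2; }

// Percent of peak for a byte count and a measured time in microseconds.
constexpr double pct_peak(double bytes, double time_us) {
    return bytes / (time_us * 1e-6) / 1e9 / kPeakGBs * 100.0;
}
```

For the 2K-7B row, `pct_peak(bytes_reported(2048, 4096), 10.75)` reproduces the ~139.8% in the table. Under the L2 assumption, `pct_peak(bytes_if_l2_rereads(2048, 4096), 10.75)` gives about 93%, back under the HBM ceiling.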
Key observations:
- Decode (M=1): ~3.5 µs for N up to 8192 and ~5.1 µs at N=16384. This is nearly flat in hidden dim and dominated by kernel launch overhead rather than HBM traffic, which is expected and normal for a batch=1 memory-bound op.
- Prefill (M=256, N=8192): 2700 GB/s, 80.6% of peak — near speed-of-light.
- Large prefill (M≥1024): the pass-2 re-read of x is served from L2 (50 MB on H100), so the reported number exceeds 100% for both N=4096 and N=8192.
The benchmarks follow TK 2.0 conventions:
- Bitwise-identical random inputs across runs
- 500 warmup iterations to reach power steady-state
- 100 profiling iterations measured with CUDA events
- L2 cache awareness (multiple input groups for small workloads)
- Reports achieved bandwidth as % of theoretical HBM peak
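The methodology above can be sketched on the host as a C++ analogue under stated assumptions. It is not the harness in `test_rmsnorm.cu`:

- It times with `std::chrono::steady_clock`, where the CUDA version would use `cudaEventRecord`/`cudaEventElapsedTime`.
- `num_input_groups` encodes one plausible reading of "L2 cache awareness": rotate through enough copies of the inputs that their combined footprint exceeds L2.

Both helper names are hypothetical.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>

// Hypothetical: number of input copies to rotate through so the combined footprint exceeds L2.
std::size_t num_input_groups(std::size_t bytes_per_group, std::size_t l2_bytes) {
    assert(bytes_per_group > 0);
    return l2_bytes / bytes_per_group + 1;  // groups * bytes_per_group > l2_bytes
}

// Host analogue of the benchmark loop: warm up, then time a fixed number of
// iterations, cycling through input groups. Returns mean microseconds per call.
double benchmark_us(const std::function<void(std::size_t group)>& run_once,
                    std::size_t groups, int warmup_iters = 500, int timed_iters = 100) {
    assert(groups > 0 && timed_iters > 0);
    for (int i = 0; i < warmup_iters; ++i) run_once(static_cast<std::size_t>(i) % groups);
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < timed_iters; ++i) run_once(static_cast<std::size_t>(i) % groups);
    const auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::micro>(stop - start).count() / timed_iters;
}
```

Rotating groups matters most for small shapes such as decode. A single `[1, 4096]` problem is a few tens of KB and would otherwise stay resident in the 50 MB L2 across iterations, so the benchmark would measure cache bandwidth rather than HBM.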
RMSNorm is one piece of the broader LLM inference puzzle. At decode time, the full transformer layer consists of:
- RMSNorm (this kernel) — memory-bound, ~microseconds
- QKV projection — GEMV, memory-bound at batch=1
- Attention — memory-bound (KV cache read)
- Output projection — GEMV
- RMSNorm — memory-bound
- FFN up+gate — GEMV
- FFN down — GEMV
Every one of these is memory-bandwidth-bound at small batch sizes. The overarching challenge is: can we fuse or overlap these operations to reduce total memory traffic? Ideas like megakernels (fusing the entire layer into one persistent kernel) and speculative decoding (increasing arithmetic intensity by batching speculative tokens) attack this from different angles.
- ThunderKittens: Simple, Fast, and Adorable AI Kernels (ICLR 2025)
- ThunderKittens 2.0
- Root Mean Square Layer Normalization (NeurIPS 2019)