krishs0404/tk-rmsnorm


RMSNorm TK — Speed-of-Light RMSNorm Kernel

A high-performance RMSNorm kernel written using ThunderKittens-style primitives, targeting near-theoretical-peak HBM bandwidth utilization on NVIDIA H100 GPUs.

What is this?

RMSNorm (Root Mean Square Layer Normalization) is a critical operation in modern LLM inference. Every transformer layer applies it at least twice (pre-attention and pre-FFN). At decode time (batch=1), RMSNorm is memory-bandwidth-bound: it performs only a few FLOPs per element loaded, so its arithmetic intensity is far too low to keep the GPU's compute units busy. This makes it a natural target for a "speed-of-light" kernel, one that saturates HBM bandwidth.
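
For reference, the operation on one row x of length N with scale vector gamma can be written as below. The epsilon term is an assumption: it is the usual small stabilizer in RMSNorm implementations, but this README does not state whether or how the kernel applies it.

```math
y_i = \frac{x_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}} \cdot \gamma_i, \qquad i = 1, \dots, N
```

The sum of squares is the only reduction in the formula; everything else is elementwise.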

Design Philosophy

This kernel follows ThunderKittens' core principles, adapted from studying the TK GEMM kernel patterns:

From the GEMM kernel: what carries over

| TK GEMM Pattern | RMSNorm Adaptation |
| --- | --- |
| 16×16 tiles as fundamental data structure | Vectorized bf16×8 loads (128-bit), tile-aligned |
| Swizzled shared memory layouts (128-byte) | Bank-conflict-free SMEM staging for gamma weights |
| `global_layout` / `gl<bf16,...>` descriptors | Row-major `bf16*` with compile-time stride |
| `warpgroup::mma_AB` for compute | `warp_reduce_sum` for the reduction (no tensor cores needed; this is bandwidth-bound) |
| `prototype::lcf::kernel` template (Load-Compute-Store-Finish) | Two-phase: (1) Load+Reduce, (2) Load+Normalize+Store |
| `tma::load_async` with semaphores | Vectorized `float4` loads with `__syncthreads` barriers |
| fp32 accumulators for numerical stability | fp32 accumulation for sum-of-squares, bf16 I/O |
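
The table names `warp_reduce_sum` but does not show how it is implemented. A common way to sum across a warp is an xor-shuffle (butterfly) reduction built on `__shfl_xor_sync`. The sketch below emulates that pattern on the CPU so it can be read and run without a GPU. The function name `warp_reduce_sum_emulated` is hypothetical, and whether `rmsnorm_tk.cuh` uses this exact scheme is an assumption.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kWarpSize = 32;

// CPU emulation of a butterfly all-reduce over one 32-lane warp.
// On the GPU, the inner "read lane ^ offset" step would be a single
// __shfl_xor_sync call. Assumption: this mirrors, but is not copied from,
// the kernel's warp_reduce_sum.
inline std::array<float, kWarpSize> warp_reduce_sum_emulated(
    std::array<float, kWarpSize> lanes) {
    for (std::size_t offset = kWarpSize / 2; offset > 0; offset /= 2) {
        std::array<float, kWarpSize> partner{};
        // Each lane reads the value held by the lane whose index differs
        // in exactly one bit.
        for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
            partner[lane] = lanes[lane ^ offset];
        }
        // Each lane adds its partner's value to its own.
        for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
            lanes[lane] += partner[lane];
        }
    }
    return lanes;  // after log2(32) = 5 steps every lane holds the full sum
}
```

Because every lane ends up with the total, no extra broadcast is needed before the normalization pass.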

Key design decisions

  1. One threadblock per row. Each block independently normalizes one row of the input. This eliminates inter-block synchronization entirely — critical for decode latency.

  2. Vectorized 128-bit loads. Each thread loads 8 bf16 values (= one float4 = one LDG.128 instruction) per memory transaction, maximizing memory bus utilization.

  3. Shared memory gamma staging. The learnable scale vector gamma is loaded once into SMEM and reused during the normalization pass. For typical LLM hidden dims (4096-8192), this fits comfortably in SMEM.

  4. Two-pass design with HBM re-read. Pass 1 reads x and computes sum-of-squares. Pass 2 re-reads x from HBM, normalizes, and writes y. This re-read costs bandwidth but avoids SMEM pressure for large hidden dims. The total data movement is about 3·M·N·sizeof(bf16): two reads of x and one write of y, plus a single N·sizeof(bf16) read of gamma.

  5. Fused residual variant. In transformers, RMSNorm is almost always preceded by a residual add. The fused kernel (rmsnorm_residual_tk_kernel) does x = x + residual and RMSNorm in a single launch, saving a full read+write pass.
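
To make decisions 2 and 4 concrete, here is a minimal CPU sketch of the two-pass structure for one row, processing eight bf16 values (16 bytes) at a time with fp32 accumulation. It is not the kernel itself: the names `Bf16x8` and `rmsnorm_row_two_pass` and the `eps` parameter are hypothetical, bf16 is stored as raw 16-bit patterns, and the residual-fused variant is not covered.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Raw bf16 storage: the upper 16 bits of an IEEE-754 fp32 value.
using bf16_bits = std::uint16_t;

inline float bf16_to_float(bf16_bits b) {
    const std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Round-to-nearest-even conversion; NaN handling is omitted for brevity.
inline bf16_bits float_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    u += 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<bf16_bits>(u >> 16);
}

// Eight bf16 values occupy 128 bits, the same size as one float4 load.
struct alignas(16) Bf16x8 { bf16_bits v[8]; };
static_assert(sizeof(Bf16x8) == 16, "8 bf16 values must fill 128 bits");

// Hypothetical helper: normalizes one row of n_vec * 8 elements.
inline void rmsnorm_row_two_pass(const Bf16x8* x, const Bf16x8* gamma,
                                 Bf16x8* y, std::size_t n_vec, float eps) {
    // Pass 1: read x, accumulate the sum of squares in fp32.
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < n_vec; ++i) {
        for (int k = 0; k < 8; ++k) {
            const float v = bf16_to_float(x[i].v[k]);
            sum_sq += v * v;
        }
    }
    const float inv_rms =
        1.0f / std::sqrt(sum_sq / static_cast<float>(n_vec * 8) + eps);

    // Pass 2: re-read x, scale by inv_rms and gamma, write y as bf16.
    for (std::size_t i = 0; i < n_vec; ++i) {
        for (int k = 0; k < 8; ++k) {
            const float v = bf16_to_float(x[i].v[k]);
            const float g = bf16_to_float(gamma[i].v[k]);
            y[i].v[k] = float_to_bf16(v * inv_rms * g);
        }
    }
}
```

In the real kernel the outer loops would be split across the threads of one block, with the partial sums combined by a block-wide reduction between the two passes.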

Speed-of-light analysis

For a [M, N] input in bf16:

Total HBM bytes = 3 × M × N × 2  +  N × 2
                  ↑ read x twice    ↑ read gamma once
                  + write y once

H100 SXM peak bandwidth = 3.35 TB/s

Theoretical minimum time (M=1, N=4096):
  = (3 × 4096 × 2 + 4096 × 2) / 3.35e12
  = 32,768 / 3.35e12
  ≈ 9.8 ns

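A kernel achieving >80% of peak bandwidth is treated here as operating at speed-of-light. For the M=1 example above, the 9.8 ns bound is far below kernel launch latency, so this target is only reachable at larger M.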
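
The arithmetic above is easy to reproduce for other shapes. The following sketch encodes the same byte count and the 3.35 TB/s peak quoted in this section; the function names are hypothetical and chosen for illustration only.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

constexpr double kH100PeakBytesPerSec = 3.35e12;  // H100 SXM peak, from the text
constexpr std::uint64_t kBf16Bytes = 2;

// Bytes moved through HBM: x read twice, y written once, gamma read once.
constexpr std::uint64_t rmsnorm_hbm_bytes(std::uint64_t m, std::uint64_t n) {
    return 3 * m * n * kBf16Bytes + n * kBf16Bytes;
}

// Lower bound on runtime in nanoseconds if HBM ran at its peak rate.
constexpr double rmsnorm_min_time_ns(std::uint64_t m, std::uint64_t n) {
    return static_cast<double>(rmsnorm_hbm_bytes(m, n)) /
           kH100PeakBytesPerSec * 1e9;
}
```

For example, `rmsnorm_hbm_bytes(1, 4096)` gives the 32,768 bytes used in the calculation above.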

Building

Requires CUDA 12.3+ and a C++20 capable host compiler.

# For H100
make

# For A100
make GPU=A100

# Build and run tests + benchmarks
make test

Files

| File | Description |
| --- | --- |
| `rmsnorm_tk.cuh` | Kernel implementation (header-only, following TK convention) |
| `test_rmsnorm.cu` | Correctness tests + bandwidth benchmarks |
| `Makefile` | Self-contained build (TK pattern: one Makefile per kernel) |

Results — NVIDIA H100 80GB HBM3 (April 2026)

Tested on H100 SXM (132 SMs, 3350 GB/s peak HBM, CUDA 12.8).

Correctness

All 9 test cases pass. Max absolute error stays within bf16 precision (~0.02), well within the expected rounding envelope for a two-pass fp32-accumulation kernel.

| Shape | Result | Max abs error |
| --- | --- | --- |
| M=1, N=128 | PASS | 1.61e-02 |
| M=4, N=768 | PASS | 1.58e-02 |
| M=8, N=1024 | PASS | 1.70e-02 |
| M=16, N=2048 | PASS | 1.75e-02 |
| M=32, N=4096 | PASS | 1.78e-02 |
| M=64, N=5120 | PASS | 1.81e-02 |
| M=32, N=8192 | PASS | 1.83e-02 |
| M=1, N=4096 | PASS | 1.41e-02 |
| M=2048, N=4096 | PASS | 1.87e-02 |

Bandwidth benchmarks

| Workload | M | N | Time (µs) | BW (GB/s) | % Peak |
| --- | --- | --- | --- | --- | --- |
| decode-7B | 1 | 4096 | 3.51 | 9.3 | 0.3% |
| decode-70B | 1 | 8192 | 3.74 | 17.5 | 0.5% |
| decode-180B | 1 | 16384 | 5.14 | 25.5 | 0.8% |
| batch8-7B | 8 | 4096 | 3.54 | 57.9 | 1.7% |
| batch8-70B | 8 | 8192 | 3.77 | 108.8 | 3.2% |
| prefill-7B | 256 | 4096 | 3.91 | 1611.8 | 48.1% |
| prefill-70B | 256 | 8192 | 4.67 | 2700.1 | 80.6% |
| 1K-7B | 1024 | 4096 | 6.65 | 3786.3 | 113.0%* |
| 2K-7B | 2048 | 4096 | 10.75 | 4684.0 | 139.8%* |
| 4K-7B | 4096 | 4096 | 27.57 | 3651.0 | 109.0%* |
| 2K-70B | 2048 | 8192 | 28.56 | 3525.8 | 105.2%* |

* Values >100% indicate L2 cache is serving the second pass of x rather than HBM — the bandwidth formula assumes both passes go to HBM. For large-batch prefill the actual data movement is lower than 3·M·N·2 bytes, so % peak is artificially inflated. The kernel is genuinely fast in this regime.

Key observations:

  • Decode (M=1): about 3.5 µs for N ≤ 8192, rising to 5.1 µs at N=16384. Time is dominated by kernel launch overhead rather than HBM traffic, which is expected for a batch=1 memory-bound op.
  • Prefill (M=256, N=8192): 2700 GB/s, 80.6% of peak, which is near speed-of-light.
  • Large prefill (M≥1024): L2 cache effects kick in because x (8 to 32 MB for the shapes benchmarked) fits in the 50 MB L2, inflating the reported number above 100%.
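
The BW and % Peak columns can be recomputed from the measured time, and the L2 explanation can be sanity-checked by comparing the size of x against the 50 MB L2. The sketch below does both. The function names are hypothetical, and because the published times are rounded to two decimals, recomputed values can differ from the table in the last digit or two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Achieved bandwidth in GB/s, using the same 3*M*N*2 + N*2 byte model
// as the speed-of-light analysis.
inline double achieved_gbps(std::uint64_t m, std::uint64_t n, double time_us) {
    const double bytes = static_cast<double>(3 * m * n * 2 + n * 2);
    return bytes / (time_us * 1e-6) / 1e9;
}

// Percent of the 3350 GB/s peak quoted for the H100 SXM.
inline double percent_of_peak(double gbps, double peak_gbps = 3350.0) {
    return 100.0 * gbps / peak_gbps;
}

// True if the bf16 input x alone fits in L2, which is when the second
// read of x may be served from cache rather than HBM.
inline bool x_fits_in_l2(std::uint64_t m, std::uint64_t n,
                         std::uint64_t l2_bytes = 50ull * 1024 * 1024) {
    return m * n * 2 <= l2_bytes;
}
```

For the 2K-7B row, x is 16 MB, so `x_fits_in_l2(2048, 4096)` holds and a figure above 100% is plausible under this byte model.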

Benchmarking methodology

Following TK 2.0 conventions:

  • Bitwise-identical random inputs across runs
  • 500 warmup iterations to reach power steady-state
  • 100 profiling iterations measured with CUDA events
  • L2 cache awareness (multiple input groups for small workloads)
  • Reports achieved bandwidth as % of theoretical HBM peak
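
The warmup-then-measure loop can be sketched in host code as follows. Note the substitution: the benchmark in this repo times with CUDA events, while this sketch uses `std::chrono` so it runs without a GPU. The helper name `mean_time_us` is hypothetical, and the defaults mirror the 500 warmup and 100 timed iterations listed above.

```cpp
#include <cassert>
#include <chrono>

// Runs `launch` warmup_iters times untimed, then timed_iters times under
// a wall-clock timer, and returns the mean time per timed call in µs.
// Assumption: for a real GPU kernel, `launch` would need to include or be
// followed by device synchronization, or CUDA events would be used instead.
template <typename F>
double mean_time_us(F&& launch, int warmup_iters = 500, int timed_iters = 100) {
    for (int i = 0; i < warmup_iters; ++i) {
        launch();  // reach steady state before measuring
    }
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < timed_iters; ++i) {
        launch();
    }
    const auto stop = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::micro> elapsed = stop - start;
    return elapsed.count() / timed_iters;
}
```

Averaging over many back-to-back launches amortizes timer overhead, which matters when a single call takes only a few microseconds.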

Relation to the inference problem

RMSNorm is one piece of the broader LLM inference puzzle. At decode time, the full transformer layer consists of:

  1. RMSNorm (this kernel) — memory-bound, ~microseconds
  2. QKV projection — GEMV, memory-bound at batch=1
  3. Attention — memory-bound (KV cache read)
  4. Output projection — GEMV
  5. RMSNorm — memory-bound
  6. FFN up+gate — GEMV
  7. FFN down — GEMV

Every one of these is memory-bandwidth-bound at small batch sizes. The overarching challenge is: can we fuse or overlap these operations to reduce total memory traffic? Ideas like megakernels (fusing the entire layer into one persistent kernel) and speculative decoding (increasing arithmetic intensity by batching speculative tokens) attack this from different angles.
