[Megatron] Add RMSNorm benchmark + measured H100 results#1257
Merged
Conversation
Adds benchmark/scripts/benchmark_megatron_rms_norm.py — the RMSNorm parallel to the megatron CE benchmark landed in #1207. Compares three providers on the [seq, batch, hidden] shape used by Megatron's TransformerBlock: - **liger** LigerMegatronRMSNorm (Liger Triton kernel via the Megatron-shaped wrapper) - **torch** vanilla torch.nn.RMSNorm - **megatron** Megatron's WrappedTorchNorm — its __new__ returns torch.nn.RMSNorm, so timings should match `torch`; included for explicit parity confirmation since WrappedTorchNorm is the specific symbol Liger displaces in the local-backend path If megatron-core is not installed, the `megatron` provider is silently dropped and the run proceeds with liger + torch. Output goes to the shared benchmark/data/all_benchmark_data.csv, tagged kernel_name="megatron_rms_norm". Standard visualizer renders the plots: python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_rms_norm --metric-name speed python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_rms_norm --metric-name memory H100 results (S=4096, B=1, bf16; 60 rows committed in the CSV): Forward (Liger wins, gap widens with H): H=1024: liger 0.012 ms vs torch 0.013 ms ≈ flat H=4096: liger 0.029 ms vs torch 0.033 ms ~12% faster H=16384: liger 0.095 ms vs torch 0.149 ms ~36% faster Full (fwd+bwd) — crossover around H≈6K: H=1024: liger 0.30 ms vs torch 0.075 ms Liger SLOWER (Triton launch overhead dominates at tiny hidden) H=8192: roughly equal ~0.33 ms H=16384: liger 0.59 ms vs torch 0.77 ms ~24% faster Memory: ~neutral across all H (Liger 0.5–1% higher than torch). nn.RMSNorm is already a single fused CUDA kernel, so Liger doesn't reduce activation memory in this comparison — speed is the win. Parity check: `torch` and `megatron` rows are bit-identical, as expected (WrappedTorchNorm returns nn.RMSNorm). Production LLMs run at H >= 4096, where Liger wins on forward and breaks even / wins on full. The tiny-H regression is launch overhead, not a math regression — Liger's backward launches multiple Triton kernels while PyTorch's fused C++ backward is a single launch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kolehma8
approved these changes
Jun 10, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
benchmark/scripts/benchmark_megatron_rms_norm.py— the RMSNorm parallel to the Megatron cross-entropy benchmark landed in #1207. Provides empirical speed/memory numbers so we can point at concrete data when claiming RMSNorm wins, instead of just citing the kernel sweep.Output goes to the shared
benchmark/data/all_benchmark_data.csv(taggedkernel_name="megatron_rms_norm"), so the standard visualizer renders the plots:python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_rms_norm --metric-name speed python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_rms_norm --metric-name memoryProviders compared
LigerMegatronRMSNorm(Liger's Triton RMSNorm via the Megatron-shaped wrapper from [Megatron] Add RMSNorm integration #1254)torch.nn.RMSNormWrappedTorchNorm— its__new__returnstorch.nn.RMSNorm, so timings should matchtorchexactly. Included for explicit parity confirmation sinceWrappedTorchNormis the symbol Liger displaces in the local-backend path.If
megatron-coreis not installed, themegatronprovider is silently dropped and the run proceeds withliger+torch.Results on H100 80GB (S=4096, B=1, bf16)
60 rows committed in the CSV (5 hidden sizes × 3 providers × 4 measurements).
Speed — forward
Liger forward wins, and the gap widens with hidden size.
Speed — full (fwd + bwd)
The flat Liger curve at small H is the giveaway: kernel launch overhead dominates.
nn.RMSNorm's backward is a single fused C++/CUDA kernel; Liger's backward launches multiple Triton kernels (dx + dw reduction + element_mul). At small H the actual compute is tiny relative to per-launch overhead, so the math wins from Triton get drowned. At H ≥ ~6K the compute dominates and Liger wins.Memory — full
Approximately neutral across all H — Liger uses 0.5–1% more than torch, within measurement noise.
nn.RMSNormis already a single fused CUDA kernel with minimal intermediates, so Liger doesn't get the activation-memory win it gets when replacing eager-PyTorch RMSNorm (which materializes variance + rsqrt + scale separately). Speed is the win in this comparison, not memory.Parity sanity check
torchandmegatronrows are bit-identical across the entire sweep — exactly what we'd expect sinceWrappedTorchNormis a factory that returnsnn.RMSNorm. Confirms Liger is replacing the right baseline.Honest read
Production LLMs run at H ≥ 4096. At those sizes:
So Liger's RMSNorm is a real speed improvement for typical training shapes. The tiny-H regression on full is launch overhead, not a numerical/math issue.
Plots
Memory
Backward speed
Forward speed
Forward + Backward speed
Testing Done
torch/megatronproviders produce bit-identical numbers (parity confirmation)make checkstylepasses