diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index a578d2622..f12ee9830 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -2315,3 +2315,63 @@ megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1 megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_rms_norm,liger,forward,speed,ms,H,hidden size,1024,0.012160000391304493,0.011935999616980553,0.012460799887776375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0 +megatron_rms_norm,liger,forward,speed,ms,H,hidden size,2048,0.01756799966096878,0.01740800030529499,0.017920000478625298,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0 +megatron_rms_norm,liger,forward,speed,ms,H,hidden size,4096,0.029184000566601753,0.028863999992609024,0.029888000339269638,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0 +megatron_rms_norm,liger,forward,speed,ms,H,hidden size,8192,0.050912000238895416,0.05020799860358238,0.05167999863624573,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0 +megatron_rms_norm,liger,forward,speed,ms,H,hidden size,16384,0.0950080007314682,0.09459199756383896,0.09532800316810608,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0 +megatron_rms_norm,torch,forward,speed,ms,H,hidden 
size,1024,0.012768000364303589,0.012543999589979649,0.012992000207304955,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,torch,forward,speed,ms,H,hidden size,2048,0.019039999693632126,0.01881599985063076,0.019392000511288643,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,torch,forward,speed,ms,H,hidden size,4096,0.03248000144958496,0.03208959847688675,0.03308799862861633,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,torch,forward,speed,ms,H,hidden size,8192,0.07449600100517273,0.0737600028514862,0.07487999647855759,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,torch,forward,speed,ms,H,hidden size,16384,0.14924800395965576,0.14870400726795197,0.14982399344444275,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,1024,0.012736000120639801,0.012480000033974648,0.012992000207304955,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,2048,0.019007999449968338,0.018783999606966972,0.019392000511288643,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,4096,0.03248000144958496,0.03206399828195572,0.03308799862861633,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,8192,0.07449600100517273,0.07381760329008102,0.07487999647855759,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,16384,0.14927999675273895,0.14876799285411835,0.14988799393177032,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0 +megatron_rms_norm,liger,backward,speed,ms,H,hidden 
size,1024,0.2996159940958023,0.2936832010746002,0.3068287968635559,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0 +megatron_rms_norm,liger,backward,speed,ms,H,hidden size,2048,0.29683199524879456,0.29320319890975954,0.3028415977954865,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0 +megatron_rms_norm,liger,backward,speed,ms,H,hidden size,4096,0.3018079996109009,0.297132807970047,0.31088640093803405,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0 +megatron_rms_norm,liger,backward,speed,ms,H,hidden size,8192,0.3019679933786392,0.2977280020713806,0.308896005153656,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0 +megatron_rms_norm,liger,backward,speed,ms,H,hidden size,16384,0.45558400452136993,0.4541440010070801,0.4569920003414154,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0 +megatron_rms_norm,torch,backward,speed,ms,H,hidden size,1024,0.07440000027418137,0.07275520265102386,0.07758720219135284,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0 +megatron_rms_norm,torch,backward,speed,ms,H,hidden size,2048,0.09468799829483032,0.09400960057973862,0.09536000341176987,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0 +megatron_rms_norm,torch,backward,speed,ms,H,hidden size,4096,0.1629280000925064,0.16211199760437012,0.1637440025806427,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0 +megatron_rms_norm,torch,backward,speed,ms,H,hidden size,8192,0.3341119885444641,0.33350399136543274,0.33475199341773987,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0 +megatron_rms_norm,torch,backward,speed,ms,H,hidden size,16384,0.643455982208252,0.6425664067268371,0.6444415807724,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0 +megatron_rms_norm,megatron,backward,speed,ms,H,hidden 
size,1024,0.07462400197982788,0.07313919812440872,0.07705599814653397,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,2048,0.0931520015001297,0.0926399976015091,0.09366399794816971,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,4096,0.162432000041008,0.1616320013999939,0.16316799819469452,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,8192,0.33664000034332275,0.335999995470047,0.3373759984970093,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,16384,0.6427839994430542,0.641862416267395,0.6437952041625977,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,liger,full,speed,ms,H,hidden size,1024,0.2959360033273697,0.2916415989398956,0.3015359938144684,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,liger,full,speed,ms,H,hidden size,2048,0.30425599217414856,0.30000640749931334,0.31065601110458374,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,liger,full,speed,ms,H,hidden size,4096,0.3089919984340668,0.3035840094089508,0.3177344083786011,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,liger,full,speed,ms,H,hidden size,8192,0.32104000449180603,0.3198783934116363,0.32229760885238645,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,liger,full,speed,ms,H,hidden size,16384,0.5876799821853638,0.5862079858779907,0.5891199707984924,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0 +megatron_rms_norm,torch,full,speed,ms,H,hidden 
size,1024,0.07609599828720093,0.07417599856853485,0.07968000322580338,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,torch,full,speed,ms,H,hidden size,2048,0.10864000022411346,0.10764160007238388,0.1096000000834465,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,torch,full,speed,ms,H,hidden size,4096,0.19510400295257568,0.1937599927186966,0.19665920436382295,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,torch,full,speed,ms,H,hidden size,8192,0.3964959979057312,0.3953407883644104,0.397651207447052,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,torch,full,speed,ms,H,hidden size,16384,0.7717440128326416,0.7703359723091125,0.77292799949646,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,megatron,full,speed,ms,H,hidden size,1024,0.0745600014925003,0.0729919970035553,0.07756800204515457,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,megatron,full,speed,ms,H,hidden size,2048,0.1085439994931221,0.10764800012111664,0.1096000000834465,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,megatron,full,speed,ms,H,hidden size,4096,0.19526399672031403,0.19379200041294098,0.19660800695419312,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,megatron,full,speed,ms,H,hidden size,8192,0.39667201042175293,0.3954240083694458,0.3978111982345581,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,megatron,full,speed,ms,H,hidden size,16384,0.7716799974441528,0.7703359723091125,0.7728319764137268,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0 +megatron_rms_norm,liger,full,memory,MB,H,hidden size,1024,40.5419921875,40.5419921875,40.5419921875,"{""S"": 4096, ""B"": 
1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,liger,full,memory,MB,H,hidden size,2048,81.0673828125,81.0673828125,81.0673828125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,liger,full,memory,MB,H,hidden size,4096,162.1181640625,162.1181640625,162.1181640625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,liger,full,memory,MB,H,hidden size,8192,324.2197265625,324.2197265625,324.2197265625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,liger,full,memory,MB,H,hidden size,16384,648.4228515625,648.4228515625,648.4228515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,torch,full,memory,MB,H,hidden size,1024,40.0224609375,40.0224609375,40.0224609375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,torch,full,memory,MB,H,hidden size,2048,80.0283203125,80.0283203125,80.0283203125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,torch,full,memory,MB,H,hidden size,4096,160.0400390625,160.0400390625,160.0400390625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,torch,full,memory,MB,H,hidden size,8192,320.0634765625,320.0634765625,320.0634765625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,torch,full,memory,MB,H,hidden size,16384,640.1103515625,640.1103515625,640.1103515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,megatron,full,memory,MB,H,hidden size,1024,40.0224609375,40.0224609375,40.0224609375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0 +megatron_rms_norm,megatron,full,memory,MB,H,hidden size,2048,80.0283203125,80.0283203125,80.0283203125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 
21:03:43,0.8.0
+megatron_rms_norm,megatron,full,memory,MB,H,hidden size,4096,160.0400390625,160.0400390625,160.0400390625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
+megatron_rms_norm,megatron,full,memory,MB,H,hidden size,8192,320.0634765625,320.0634765625,320.0634765625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
+megatron_rms_norm,megatron,full,memory,MB,H,hidden size,16384,640.1103515625,640.1103515625,640.1103515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
diff --git a/benchmark/scripts/benchmark_megatron_rms_norm.py b/benchmark/scripts/benchmark_megatron_rms_norm.py
new file mode 100644
index 000000000..af608f270
--- /dev/null
+++ b/benchmark/scripts/benchmark_megatron_rms_norm.py
@@ -0,0 +1,196 @@
+"""Benchmark Liger's Megatron-LM RMSNorm wrapper.
+
+Compares three providers on the per-token RMSNorm call shape ``[seq, batch, hidden]``:
+
+  - **torch**: vanilla ``torch.nn.RMSNorm`` — the raw PyTorch reference
+  - **megatron**: Megatron's ``WrappedTorchNorm`` (the symbol Liger displaces in the
+    local-backend path; structurally a factory that returns ``nn.RMSNorm``, so its
+    timing should be indistinguishable from ``torch`` — included for explicit parity
+    confirmation, since the *point* of the patch is replacing this specific symbol)
+  - **liger**: ``LigerMegatronRMSNorm`` — Liger's Triton RMSNorm in the Megatron-shaped
+    wrapper (per-layer + final_layernorm slot).
+
+Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not
+installed, the ``megatron`` provider is silently dropped and the run proceeds with
+``liger`` + ``torch``.
+
+Output goes to the shared ``benchmark/data/all_benchmark_data.csv`` — rows are
+tagged with ``kernel_name="megatron_rms_norm"`` and the standard visualizer renders
+them via:
+
+    python benchmark/benchmarks_visualizer.py \\
+        --kernel-name megatron_rms_norm --metric-name speed
+    python benchmark/benchmarks_visualizer.py \\
+        --kernel-name megatron_rms_norm --metric-name memory
+"""
+
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.megatron import LigerMegatronRMSNorm
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+try:
+    from megatron.core.transformer.torch_norm import WrappedTorchNorm
+
+    _MEGATRON_AVAILABLE = True
+except ImportError:
+    WrappedTorchNorm = None
+    _MEGATRON_AVAILABLE = False
+
+
+def _make_config():
+    """Duck-typed TransformerConfig accepted by both LigerMegatronRMSNorm and WrappedTorchNorm.
+
+    WrappedTorchNorm asserts a handful of attributes are False; LigerMegatronRMSNorm reads
+    ``normalization``, ``sequence_parallel``, and ``layernorm_zero_centered_gamma``.
+    """
+    return SimpleNamespace(
+        normalization="RMSNorm",
+        sequence_parallel=False,
+        layernorm_zero_centered_gamma=False,
+        persist_layer_norm=False,
+        memory_efficient_layer_norm=False,
+    )
+
+
+def _make_layer(provider: str, hidden_size: int, eps: float = 1e-6) -> nn.Module:
+    """Instantiate the RMSNorm module for ``provider`` on ``device``, cast to bfloat16.
+
+    Args:
+        provider: One of ``"liger"``, ``"torch"`` or ``"megatron"``.
+        hidden_size: Forwarded to the norm constructor as the normalized dimension size.
+        eps: Epsilon forwarded to the norm constructor.
+
+    Raises:
+        RuntimeError: ``provider`` is ``"megatron"`` but megatron-core failed to import.
+        ValueError: ``provider`` is not one of the names above.
+    """
+    config = _make_config()
+    if provider == "liger":
+        layer = LigerMegatronRMSNorm(config=config, hidden_size=hidden_size, eps=eps)
+    elif provider == "torch":
+        layer = nn.RMSNorm(normalized_shape=hidden_size, eps=eps)
+    elif provider == "megatron":
+        if not _MEGATRON_AVAILABLE:
+            raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
+        # WrappedTorchNorm.__new__ returns an nn.RMSNorm instance directly.
+        layer = WrappedTorchNorm(config=config, hidden_size=hidden_size, eps=eps)
+    else:
+        raise ValueError(f"unknown provider: {provider!r}")
+    return layer.to(device).to(torch.bfloat16)
+
+
+def _make_input(s: int, b: int, h: int, requires_grad: bool = True) -> torch.Tensor:
+    """Return a random bfloat16 tensor of shape ``[s, b, h]`` on ``device``."""
+    return torch.randn(s, b, h, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
+
+
+def bench_speed_megatron_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    """Time one provider's RMSNorm at a single hidden size and operation mode.
+
+    ``input.x`` is the hidden size; ``S`` and ``B`` come from ``input.extra_benchmark_config``.
+
+    Modes:
+      - ``"forward"``: the forward pass only.
+      - ``"backward"`` and ``"full"``: forward followed by ``out.sum().backward()``. The
+        forward is rerun inside the timed call so each backward sees a fresh graph
+        (mirrors the megatron CE benchmark's "backward includes forward" convention —
+        subtract the "forward" measurement to derive backward-only timing).
+
+    Raises:
+        ValueError: ``input.kernel_operation_mode`` is not one of the modes above.
+    """
+    h = input.x
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+    s = input.extra_benchmark_config["S"]
+    b = input.extra_benchmark_config["B"]
+
+    layer = _make_layer(provider, h)
+    x = _make_input(s, b, h)
+
+    def fwd():
+        return layer(x)
+
+    def fwd_bwd():
+        fwd().sum().backward()
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode in ("backward", "full"):
+        # ``grad_to_none`` has do_bench reset ``x.grad`` before each timed call. Without
+        # it ``x.grad`` accumulates across iterations and the timing also includes that
+        # extra gradient-accumulation work, which is not part of the RMSNorm itself.
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_bwd, rep=100, grad_to_none=[x], quantiles=QUANTILES)
+    else:
+        raise ValueError(f"unknown mode: {mode!r}")
+
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_megatron_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    """Measure memory for one forward + backward step of a provider's RMSNorm.
+
+    The measurement itself is delegated to ``utils._test_memory``; the three values it
+    returns are stored as ``y_50`` / ``y_20`` / ``y_80``.
+    """
+    h = input.x
+    provider = input.kernel_provider
+    s = input.extra_benchmark_config["S"]
+    b = input.extra_benchmark_config["B"]
+
+    layer = _make_layer(provider, h)
+    x = _make_input(s, b, h)
+
+    def full():
+        y = layer(x)
+        y.sum().backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    providers = ["liger", "torch"]
+    if _MEGATRON_AVAILABLE:
+        providers.append("megatron")
+
+    common_configs = {
+        "kernel_name": "megatron_rms_norm",
+        "x_name": "H",
+        "x_label": "hidden size",
+        "x_values": [2**i for i in range(10, 15)],  # 1024 → 16384
+        "kernel_providers": providers,
+        "extra_benchmark_configs": [{"S": 4096, "B": 1}],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_megatron_rms_norm,
+        kernel_operation_modes=["forward", "backward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_megatron_rms_norm,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )