Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2315,3 +2315,63 @@ megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1
megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0
megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0
megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0
megatron_rms_norm,liger,forward,speed,ms,H,hidden size,1024,0.012160000391304493,0.011935999616980553,0.012460799887776375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0
megatron_rms_norm,liger,forward,speed,ms,H,hidden size,2048,0.01756799966096878,0.01740800030529499,0.017920000478625298,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0
megatron_rms_norm,liger,forward,speed,ms,H,hidden size,4096,0.029184000566601753,0.028863999992609024,0.029888000339269638,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0
megatron_rms_norm,liger,forward,speed,ms,H,hidden size,8192,0.050912000238895416,0.05020799860358238,0.05167999863624573,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0
megatron_rms_norm,liger,forward,speed,ms,H,hidden size,16384,0.0950080007314682,0.09459199756383896,0.09532800316810608,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:37,0.8.0
megatron_rms_norm,torch,forward,speed,ms,H,hidden size,1024,0.012768000364303589,0.012543999589979649,0.012992000207304955,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,torch,forward,speed,ms,H,hidden size,2048,0.019039999693632126,0.01881599985063076,0.019392000511288643,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,torch,forward,speed,ms,H,hidden size,4096,0.03248000144958496,0.03208959847688675,0.03308799862861633,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,torch,forward,speed,ms,H,hidden size,8192,0.07449600100517273,0.0737600028514862,0.07487999647855759,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,torch,forward,speed,ms,H,hidden size,16384,0.14924800395965576,0.14870400726795197,0.14982399344444275,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,1024,0.012736000120639801,0.012480000033974648,0.012992000207304955,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,2048,0.019007999449968338,0.018783999606966972,0.019392000511288643,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,4096,0.03248000144958496,0.03206399828195572,0.03308799862861633,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,8192,0.07449600100517273,0.07381760329008102,0.07487999647855759,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,megatron,forward,speed,ms,H,hidden size,16384,0.14927999675273895,0.14876799285411835,0.14988799393177032,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:38,0.8.0
megatron_rms_norm,liger,backward,speed,ms,H,hidden size,1024,0.2996159940958023,0.2936832010746002,0.3068287968635559,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0
megatron_rms_norm,liger,backward,speed,ms,H,hidden size,2048,0.29683199524879456,0.29320319890975954,0.3028415977954865,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0
megatron_rms_norm,liger,backward,speed,ms,H,hidden size,4096,0.3018079996109009,0.297132807970047,0.31088640093803405,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0
megatron_rms_norm,liger,backward,speed,ms,H,hidden size,8192,0.3019679933786392,0.2977280020713806,0.308896005153656,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0
megatron_rms_norm,liger,backward,speed,ms,H,hidden size,16384,0.45558400452136993,0.4541440010070801,0.4569920003414154,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:39,0.8.0
megatron_rms_norm,torch,backward,speed,ms,H,hidden size,1024,0.07440000027418137,0.07275520265102386,0.07758720219135284,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0
megatron_rms_norm,torch,backward,speed,ms,H,hidden size,2048,0.09468799829483032,0.09400960057973862,0.09536000341176987,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0
megatron_rms_norm,torch,backward,speed,ms,H,hidden size,4096,0.1629280000925064,0.16211199760437012,0.1637440025806427,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0
megatron_rms_norm,torch,backward,speed,ms,H,hidden size,8192,0.3341119885444641,0.33350399136543274,0.33475199341773987,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0
megatron_rms_norm,torch,backward,speed,ms,H,hidden size,16384,0.643455982208252,0.6425664067268371,0.6444415807724,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:40,0.8.0
megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,1024,0.07462400197982788,0.07313919812440872,0.07705599814653397,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,2048,0.0931520015001297,0.0926399976015091,0.09366399794816971,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,4096,0.162432000041008,0.1616320013999939,0.16316799819469452,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,8192,0.33664000034332275,0.335999995470047,0.3373759984970093,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,megatron,backward,speed,ms,H,hidden size,16384,0.6427839994430542,0.641862416267395,0.6437952041625977,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,liger,full,speed,ms,H,hidden size,1024,0.2959360033273697,0.2916415989398956,0.3015359938144684,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,liger,full,speed,ms,H,hidden size,2048,0.30425599217414856,0.30000640749931334,0.31065601110458374,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,liger,full,speed,ms,H,hidden size,4096,0.3089919984340668,0.3035840094089508,0.3177344083786011,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,liger,full,speed,ms,H,hidden size,8192,0.32104000449180603,0.3198783934116363,0.32229760885238645,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,liger,full,speed,ms,H,hidden size,16384,0.5876799821853638,0.5862079858779907,0.5891199707984924,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:41,0.8.0
megatron_rms_norm,torch,full,speed,ms,H,hidden size,1024,0.07609599828720093,0.07417599856853485,0.07968000322580338,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,torch,full,speed,ms,H,hidden size,2048,0.10864000022411346,0.10764160007238388,0.1096000000834465,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,torch,full,speed,ms,H,hidden size,4096,0.19510400295257568,0.1937599927186966,0.19665920436382295,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,torch,full,speed,ms,H,hidden size,8192,0.3964959979057312,0.3953407883644104,0.397651207447052,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,torch,full,speed,ms,H,hidden size,16384,0.7717440128326416,0.7703359723091125,0.77292799949646,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,megatron,full,speed,ms,H,hidden size,1024,0.0745600014925003,0.0729919970035553,0.07756800204515457,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,megatron,full,speed,ms,H,hidden size,2048,0.1085439994931221,0.10764800012111664,0.1096000000834465,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,megatron,full,speed,ms,H,hidden size,4096,0.19526399672031403,0.19379200041294098,0.19660800695419312,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,megatron,full,speed,ms,H,hidden size,8192,0.39667201042175293,0.3954240083694458,0.3978111982345581,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,megatron,full,speed,ms,H,hidden size,16384,0.7716799974441528,0.7703359723091125,0.7728319764137268,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:42,0.8.0
megatron_rms_norm,liger,full,memory,MB,H,hidden size,1024,40.5419921875,40.5419921875,40.5419921875,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,liger,full,memory,MB,H,hidden size,2048,81.0673828125,81.0673828125,81.0673828125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,liger,full,memory,MB,H,hidden size,4096,162.1181640625,162.1181640625,162.1181640625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,liger,full,memory,MB,H,hidden size,8192,324.2197265625,324.2197265625,324.2197265625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,liger,full,memory,MB,H,hidden size,16384,648.4228515625,648.4228515625,648.4228515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,torch,full,memory,MB,H,hidden size,1024,40.0224609375,40.0224609375,40.0224609375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,torch,full,memory,MB,H,hidden size,2048,80.0283203125,80.0283203125,80.0283203125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,torch,full,memory,MB,H,hidden size,4096,160.0400390625,160.0400390625,160.0400390625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,torch,full,memory,MB,H,hidden size,8192,320.0634765625,320.0634765625,320.0634765625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,torch,full,memory,MB,H,hidden size,16384,640.1103515625,640.1103515625,640.1103515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,megatron,full,memory,MB,H,hidden size,1024,40.0224609375,40.0224609375,40.0224609375,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,megatron,full,memory,MB,H,hidden size,2048,80.0283203125,80.0283203125,80.0283203125,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,megatron,full,memory,MB,H,hidden size,4096,160.0400390625,160.0400390625,160.0400390625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,megatron,full,memory,MB,H,hidden size,8192,320.0634765625,320.0634765625,320.0634765625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
megatron_rms_norm,megatron,full,memory,MB,H,hidden size,16384,640.1103515625,640.1103515625,640.1103515625,"{""S"": 4096, ""B"": 1}",NVIDIA H100 80GB HBM3,2026-06-10 21:03:43,0.8.0
175 changes: 175 additions & 0 deletions benchmark/scripts/benchmark_megatron_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Benchmark Liger's Megatron-LM RMSNorm wrapper.

Compares three providers on the per-token RMSNorm call shape ``[seq, batch, hidden]``:

- **torch**: vanilla ``torch.nn.RMSNorm`` — the raw PyTorch reference
- **megatron**: Megatron's ``WrappedTorchNorm`` (the symbol Liger displaces in the
local-backend path; structurally a factory that returns ``nn.RMSNorm``, so its
timing should be indistinguishable from ``torch`` — included for explicit parity
confirmation, since the *point* of the patch is replacing this specific symbol)
- **liger**: ``LigerMegatronRMSNorm`` — Liger's Triton RMSNorm in the Megatron-shaped
wrapper (per-layer + final_layernorm slot).

Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not
installed, the ``megatron`` provider is silently dropped and the run proceeds with
``liger`` + ``torch``.

Output goes to the shared ``benchmark/data/all_benchmark_data.csv`` — rows are
tagged with ``kernel_name="megatron_rms_norm"`` and the standard visualizer renders
them via:

python benchmark/benchmarks_visualizer.py \\
--kernel-name megatron_rms_norm --metric-name speed
python benchmark/benchmarks_visualizer.py \\
--kernel-name megatron_rms_norm --metric-name memory
"""

from types import SimpleNamespace

import torch
import torch.nn as nn
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.megatron import LigerMegatronRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()

try:
from megatron.core.transformer.torch_norm import WrappedTorchNorm

_MEGATRON_AVAILABLE = True
except ImportError:
WrappedTorchNorm = None
_MEGATRON_AVAILABLE = False


def _make_config():
"""Duck-typed TransformerConfig accepted by both LigerMegatronRMSNorm and WrappedTorchNorm.

WrappedTorchNorm asserts a handful of attributes are False; LigerMegatronRMSNorm reads
``normalization``, ``sequence_parallel``, and ``layernorm_zero_centered_gamma``.
"""
return SimpleNamespace(
normalization="RMSNorm",
sequence_parallel=False,
layernorm_zero_centered_gamma=False,
persist_layer_norm=False,
memory_efficient_layer_norm=False,
)


def _make_layer(provider: str, hidden_size: int, eps: float = 1e-6) -> nn.Module:
config = _make_config()
if provider == "liger":
layer = LigerMegatronRMSNorm(config=config, hidden_size=hidden_size, eps=eps)
elif provider == "torch":
layer = nn.RMSNorm(normalized_shape=hidden_size, eps=eps)
elif provider == "megatron":
if not _MEGATRON_AVAILABLE:
raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
# WrappedTorchNorm.__new__ returns an nn.RMSNorm instance directly.
layer = WrappedTorchNorm(config=config, hidden_size=hidden_size, eps=eps)
else:
raise ValueError(f"unknown provider: {provider!r}")
return layer.to(device).to(torch.bfloat16)


def _make_input(s: int, b: int, h: int, requires_grad: bool = True) -> torch.Tensor:
return torch.randn(s, b, h, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)


def bench_speed_megatron_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
h = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
s = input.extra_benchmark_config["S"]
b = input.extra_benchmark_config["B"]

layer = _make_layer(provider, h)
x = _make_input(s, b, h)

def fwd():
return layer(x)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "backward":
# Rerun fwd inside the timed loop so each backward sees a fresh graph (mirrors the
# megatron CE benchmark's "backward includes forward" convention — subtract the
# "forward" measurement to derive backward-only timing).
def _fwd_bwd():
if x.grad is not None:
x.grad = None
out = fwd()
out.sum().backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES)
elif mode == "full":

def full():
y = fwd()
y.sum().backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
else:
raise ValueError(f"unknown mode: {mode!r}")

return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_megatron_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
h = input.x
provider = input.kernel_provider
s = input.extra_benchmark_config["S"]
b = input.extra_benchmark_config["B"]

layer = _make_layer(provider, h)
x = _make_input(s, b, h)

def full():
y = layer(x)
y.sum().backward()

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
args = parse_benchmark_script_args()

providers = ["liger", "torch"]
if _MEGATRON_AVAILABLE:
providers.append("megatron")

common_configs = {
"kernel_name": "megatron_rms_norm",
"x_name": "H",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 15)], # 1024 → 16384
"kernel_providers": providers,
"extra_benchmark_configs": [{"S": 4096, "B": 1}],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_megatron_rms_norm,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_megatron_rms_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)