Skip to content

[Megatron] Add cross-entropy integration#1207

Merged
vaibhavjindal merged 9 commits into
linkedin:mainfrom
PrathyushaPolepalli:megatron-cross-entropy-integration
Jun 10, 2026
Merged

[Megatron] Add cross-entropy integration#1207
vaibhavjindal merged 9 commits into
linkedin:mainfrom
PrathyushaPolepalli:megatron-cross-entropy-integration

Conversation

@PrathyushaPolepalli

@PrathyushaPolepalli PrathyushaPolepalli commented Apr 28, 2026

Copy link
Copy Markdown
Contributor

Summary

This PR adds Liger's Triton cross-entropy kernel to Megatron-LM.

Megatron has two cross-entropy code paths. This PR replaces both with Liger:

  • The fused path: fused_vocab_parallel_cross_entropy (used when cross_entropy_loss_fusion=True).
  • The unfused path: vocab_parallel_cross_entropy (used when cross_entropy_loss_fusion=False).

You get Liger's kernel no matter which path your Megatron config picks.

Liger's kernel uses online softmax. It writes gradients in place. It never builds the full softmax tensor. The net effect is faster training and lower memory use.

This PR also keeps the RMSNorm patch that was added in PR #1254. So one call patches both kernels at once.

How to use

There are two ways. Pick whichever fits your code better.

Mode 1: One-line monkey patch

Best for users who don't want to touch their model code. Call this once before you build the model:

from liger_kernel.megatron import apply_liger_kernel_to_megatron

apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)

That's the whole API. No CE-specific knobs. The defaults match Megatron's native behavior.

Mode 2: Use the classes directly

Best for users who build their model spec by hand or need custom config. Import the classes and wire them in yourself:

from liger_kernel.megatron import LigerMegatronRMSNorm, LigerMegatronCrossEntropy

# In your TransformerBlockSubmodules:
input_layernorm = LigerMegatronRMSNorm
pre_mlp_layernorm = LigerMegatronRMSNorm

# In a GPTModel subclass:
class MyModel(GPTModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.liger_ce = LigerMegatronCrossEntropy(
            ignore_index=-100,
            label_smoothing=0.1,
        )

    def compute_language_model_loss(self, labels, logits):
        labels_sb = labels.transpose(0, 1).contiguous()
        loss_sb = self.liger_ce(logits, labels_sb, self.pg_collection.tp)
        return loss_sb.transpose(0, 1).contiguous()

The class accepts the same knobs Liger's standalone CE accepts (ignore_index, label_smoothing, etc.). Use Mode 2 when you need them.

Scope

This PR supports tensor_model_parallel_size=1 only.

Why? At TP>1, each GPU holds a slice of the vocabulary, not the full vocab. Cross-entropy then needs cross-rank all-reduces to get a correct loss. Liger's kernel does not do those all-reduces. So the math would be wrong at TP>1.

To stop you from training with a broken loss, the patch raises RuntimeError if it sees TP>1. The check runs at patch time and again at call time, so misconfigurations fail loudly.

[TODO]: TP>1 support is planned as follow-up work (a separate "vocab-parallel" Triton kernel).

You can still use this PR with pipeline parallelism, data parallelism, or single-GPU runs. Those are all TP=1 cases.

Testing

All unit tests live in test/megatron/.

Run all megatron tests

python -m pytest test/megatron/ --no-cov

That runs three files:

  • test/megatron/test_rms_norm.py — RMSNorm class correctness (forward, backward, dtypes).
  • test/megatron/test_cross_entropy.pyLigerMegatronCrossEntropy class correctness. Covers shape sweeps, ignore_index, label_smoothing, combined config, out-of-bounds targets, torch.no_grad() path, and gradient parity vs. PyTorch's F.cross_entropy.
  • test/megatron/test_monkey_patch.pyapply_liger_kernel_to_megatron() plumbing. Covers symbol replacement for both CE paths and both RMSNorm targets, idempotency, missing-megatron errors, TP>1 guard, dispatch behavior, and end-to-end correctness through the patched symbols.

You should see all tests pass (currently 105 tests in the megatron suite).

Run a single test file

python -m pytest test/megatron/test_cross_entropy.py -v --no-cov

What the tests do NOT need

Megatron-LM is not a test dependency. The monkey-patch tests use stub modules in sys.modules, so they run on CPU without a real Megatron install. You only need PyTorch + Triton + Liger itself.

The correctness tests (RMSNorm, CE) need a GPU because they run the actual Triton kernel.

Benchmarking

Benchmarked against megatron-core==0.19.0+96a894ef6

Step 1: Run the benchmark

python benchmark/scripts/benchmark_megatron_cross_entropy.py --overwrite

This compares four providers across vocab sizes 4K → 128K (with S=2048, B=4, BF16):

  • ligerLigerMegatronCrossEntropy
  • torch — vanilla F.cross_entropy
  • megatron — Megatron's fused CE (@jit_fuser via TorchScript)
  • megatron-unfused — Megatron's unfused CE (eager Python)

The script needs a CUDA/ROCm GPU. If Megatron-LM is not installed, the two megatron providers are silently skipped and you only get liger vs torch.

Results land in benchmark/data/all_benchmark_data.csv. Each row is tagged with kernel_name="megatron_cross_entropy".

Step 2: Render plots

python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_cross_entropy --metric-name speed

python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_cross_entropy --metric-name memory

PNGs land in benchmark/visualizations/.

Memory savings

megatron_cross_entropy_memory_full_token_length

Backward speedup

megatron_cross_entropy_speed_backward_token_length

Forward speedup

megatron_cross_entropy_speed_forward_token_length

Backward+Forward speedup

megatron_cross_entropy_speed_full_token_length

Example training scripts

Two end-to-end examples in examples/megatron/. Both train a tiny GPT model with mock data for 5 iterations. They print which classes filled which norm slots and which function backs each CE entry point, so you can see Liger took over.

Prerequisites

  • A working Megatron-Core install (pip install megatron-core).
  • liger-kernel installed.
  • psutil (Megatron's async checkpoint worker pool uses it).
  • At least 2 GPUs.

Example 1: Mode 1 (monkey patch)

torchrun --nproc_per_node=2 \
    --master_addr=127.0.0.1 --master_port=29500 \
    examples/megatron/run_mode1_monkey_patch.py

One apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) call at the top. Everything else is stock Megatron code.

Example 2: Mode 2 (hand-built spec)

torchrun --nproc_per_node=2 \
    --master_addr=127.0.0.1 --master_port=29500 \
    examples/megatron/run_mode2_hand_spec.py

No monkey patch. The script builds a TransformerBlockSubmodules with LigerMegatronRMSNorm slots and subclasses GPTModel to use LigerMegatronCrossEntropy in compute_language_model_loss.

What you should see

For both:

  • 5 lines of [modeN] iter <i> loss=<float> with the loss going down.
  • A printed module tree with LigerMegatronRMSNorm in 5 of 5 norm slots (4 per-layer + 1 block-level final).
  • The CE symbols print as liger_fused_vocab_parallel_cross_entropy and liger_vocab_parallel_cross_entropy (Mode 1) or LigerMegatronCrossEntropy (Mode 2).
  • Successfully loaded the model after a distributed checkpoint round-trip.

If Apex or TransformerEngine isn't installed, you'll see harmless warnings. Megatron falls back to the local backend, which is where Liger plugs in.

Testing Done

  • Hardware Type: H100
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@PrathyushaPolepalli PrathyushaPolepalli marked this pull request as draft April 28, 2026 05:59
@PrathyushaPolepalli PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch 5 times, most recently from ed3c27e to b1fa5bc Compare April 29, 2026 23:35
@PrathyushaPolepalli PrathyushaPolepalli marked this pull request as ready for review April 30, 2026 16:27
@PrathyushaPolepalli PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch 2 times, most recently from 41362ee to e4b2ff2 Compare April 30, 2026 17:11

@Mecoli1219 Mecoli1219 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks great! Excited to support Megatron with Liger. Left some comments to address.

Comment thread src/liger_kernel/megatron/cross_entropy.py Outdated
Comment thread src/liger_kernel/megatron/cross_entropy.py Outdated
Comment thread src/liger_kernel/megatron/cross_entropy.py Outdated
@PrathyushaPolepalli PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch from e4b2ff2 to bce8664 Compare May 22, 2026 06:47
@PrathyushaPolepalli PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch from bce8664 to 4ffd93b Compare May 22, 2026 06:52
PR linkedin#1254 (Megatron-Core RMSNorm integration) landed on main while this PR
was open. Both PRs introduce src/liger_kernel/megatron/{__init__,monkey_patch}.py
and both define apply_liger_kernel_to_megatron(...), so there were add/add
conflicts on those two files.

Resolution:

  src/liger_kernel/megatron/__init__.py
      Union the exports. Now re-exports both LigerMegatronRMSNorm (from
      the RMSNorm PR) and apply_liger_kernel_to_megatron.

  src/liger_kernel/megatron/monkey_patch.py
      Unify the two entry points into one function:

          apply_liger_kernel_to_megatron(
              rms_norm: bool = True,
              cross_entropy: bool = False,
              *,
              ignore_index: int = -100,
              label_smoothing: float = 0.0,
              reduction: str = "none",
          )

      rms_norm defaults to True (matches linkedin#1254 behaviour). cross_entropy
      defaults to False because this PR's TP=1 limitation means we
      shouldn't auto-enable a feature that crashes at TP>1. The
      tensor-parallel-size guard moves inside the cross_entropy=True
      branch so RMSNorm-only users don't see it.

  test/megatron/test_monkey_patch.py
      Existing CE tests called apply_liger_kernel_to_megatron() with no
      args. With the unified signature that now means "apply RMSNorm".
      Updated each test to pass rms_norm=False, cross_entropy=True so
      they continue to exercise the CE path only.

Everything else (cross_entropy.py, rms_norm.py, the test/benchmark/docs
additions from both PRs) is additive and didn't conflict.

All 34 tests in test/megatron/ pass post-merge.
…sed paths

Builds on the merged-from-main entry point (apply_liger_kernel_to_megatron
with rms_norm/cross_entropy flags) so this PR cleanly extends the CE side
to match the RMSNorm style PR linkedin#1254 established.

src/liger_kernel/megatron/cross_entropy.py
  Strip to just LigerMegatronCrossEntropy — an nn.Module drop-in for
  Megatron's vocab-parallel CE matching the fused signature
  (logits, target, tp_group=None). Single source of truth for the CE
  math; the monkey-patch wrappers in monkey_patch.py instantiate this
  class. Mirrors LigerMegatronRMSNorm. Reads no megatron-core imports.
  3-D shape guard + ValueError for bad reduction.

src/liger_kernel/megatron/monkey_patch.py
  Add _patch_fused_vocab_parallel_cross_entropy AND
  _patch_vocab_parallel_cross_entropy. Both wired into the
  cross_entropy=True branch so a single apply call replaces both
  Megatron CE paths. Both guarded by _PATCH_MARKER for idempotence
  (matching the RMSNorm helpers). Unfused wrapper uses a sentinel
  _LABEL_SMOOTHING_UNSET to distinguish "caller didn't pass" from
  "caller explicitly passed 0.0" — preserves Megatron's contract.

src/liger_kernel/megatron/__init__.py
  Export LigerMegatronCrossEntropy alongside LigerMegatronRMSNorm.

test/megatron/test_cross_entropy.py
  Rewrite tests around the class instead of the old _build_wrapper.
  Same parametrize matrix as test_rms_norm.py. New: 3-D shape guard,
  reduction rejection sweep, full extra_repr check.

test/megatron/test_monkey_patch.py
  Extend _install_fake_megatron to stub the unfused module too. Add:
    - test_patch_replaces_unfused_symbol
    - test_patch_replaces_both_fused_and_unfused_symbols_in_one_call
    - test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched
    - test_patch_is_idempotent_for_both_symbols
    - test_patch_fused_wrapper_passes_tp_group_through
    - test_patch_raises_when_unfused_symbol_missing
    - test_unfused_wrapper_honors_runtime_label_smoothing
    - test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing
    - test_unfused_wrapper_honors_explicit_zero_label_smoothing

examples/megatron/run_mode{1,2}_*.py
  Mode 1 now patches both RMSNorm and CE in one call with
  label_smoothing=0.1 and prints both patched CE symbols. Mode 2 adds
  a _LigerCEGPTModel(GPTModel) subclass that overrides
  compute_language_model_loss to use LigerMegatronCrossEntropy directly
  (CE has no spec slot, so subclassing is the symmetric hand-built
  path). Both run at TP=1 + DP=2 — CE patches are TP=1 only.

benchmark/scripts/benchmark_megatron_cross_entropy.py
  Add "megatron-unfused" provider. Switch CSV output to
  benchmark/data/all_benchmark_data_megatron.csv via a surgical
  patch on update_benchmark_data_csv (LIGER_KERNEL_IMPL would have
  triggered backend selection too). Add optional matplotlib plot
  generation under benchmark/visualizations/megatron_cross_entropy_*.png.

48/48 tests pass. End-to-end mode-1 and mode-2 verified on 2× H100.
Benchmark shows Liger CE 2-5x faster than both Megatron paths and
40% less memory than the fused path at vocab=128K, S=2048, B=4, bf16.
@vaibhavjindal vaibhavjindal changed the title Add Megatron-LM cross-entropy integration [Do not merge] Add Megatron-LM cross-entropy integration Jun 10, 2026
vaibhavjindal and others added 4 commits June 10, 2026 06:45
Trim the framework-level patch API to just (rms_norm=True,
cross_entropy=False). CE-specific knobs (ignore_index, label_smoothing,
reduction) leaked the wrong abstraction at the wrong scope — the public
entry point's job is "patch Megatron with Liger," not "configure CE."
The patch now matches Megatron's native fused-CE behavior exactly by
constructing LigerMegatronCrossEntropy with class defaults, so callers
flip Liger on without touching loss config.

The unfused wrapper keeps its per-call label_smoothing argument — that
matches Megatron's native vocab_parallel_cross_entropy signature and is
the right scope for the knob (per-call, not patch-time). Sentinel
dispatch preserved so an explicit label_smoothing=0.0 from the caller
is honored verbatim.

Mode-2 (hand-built) users who need custom CE config still get it via
LigerMegatronCrossEntropy's __init__ kwargs — that surface is unchanged.

src/liger_kernel/megatron/monkey_patch.py
  Drop ignore_index, label_smoothing, reduction kwargs from
  apply_liger_kernel_to_megatron and from both _patch_*_cross_entropy
  helpers. Update docstrings to point Mode-2 users at
  LigerMegatronCrossEntropy for explicit config.

test/megatron/test_monkey_patch.py
  Remove test_patch_rejects_non_none_reduction (path unreachable from
  trimmed API). Rename test_patch_forwards_… →
  test_patch_constructs_ce_with_class_defaults, asserting the patch
  uses class defaults. Strip now-invalid kwargs from the three runtime
  label_smoothing tests on the unfused wrapper; their per-call
  behavior is unchanged.

docs/High-Level-APIs.md
  Replace usage example with the trimmed signature; add pointer to
  LigerMegatronCrossEntropy for Mode-2 explicit-config callers.

examples/megatron/run_mode1_monkey_patch.py
  Drop label_smoothing=0.1 demo from the apply call and from the
  header comment.

15/15 monkey-patch tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The existing suites covered the happy path; sweeping the transformers
side surfaced two gaps worth closing on the Megatron mirror:

  * test_cross_entropy.py only tested the trivial shape/dtype matrix.
    Missing: scalar-magnitude sweep, broader ignore_index sweep,
    combined label_smoothing × ignore_index, gradient parity when CE
    isn't the last op in the graph, out-of-bounds-target guard,
    forward-only (torch.no_grad) path.
  * test_monkey_patch.py asserted symbol identity ("the patched
    function exists on the fake module") but never invoked the
    wrapper with real tensors. So a wrapper-math regression could
    pass every existing test. Added an end-to-end integration block
    that applies the patch and then calls the resulting symbol
    against F.cross_entropy.

test/megatron/test_cross_entropy.py
  Six new tests adapted from test/transformers/test_cross_entropy.py
  to the [s, b, v] -> [s, b] reshape contract:
    - test_class_correctness_scalar_sweep
        (mirrors `test_correctness` with `scalar` param)
    - test_class_correctness_with_ignore_index_sweep
        (broader sentinel sweep incl. positive/weird negatives)
    - test_class_correctness_with_label_smoothing_and_ignore_index
        (the combined branch where smoothing math historically broke)
    - test_class_correctness_not_last_layer
        (loss × downstream factor, verify in-place grad survives)
    - test_class_rejects_out_of_bounds_target
        (kernel-level OOB assert reachable through the wrapper)
    - test_class_correctness_forward_only
        (torch.no_grad parity + post-hoc .backward() raises)
  Plus an _assign_ignore_index helper matching the transformers-side
  pattern of randomizing the masked-out positions.

test/megatron/test_monkey_patch.py
  Ten new tests:
    Public-API surface (mirrors transformers' test_import_from_root):
      - test_import_from_root
      - test_public_apply_function_has_no_ce_specific_kwargs
          (guards against re-introducing the kwargs just trimmed)
    End-to-end integration (closes the wrapper-math gap):
      - test_patched_fused_symbol_computes_correct_loss
      - test_patched_unfused_symbol_computes_correct_loss
      - test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch
      - test_patched_fused_symbol_preserves_gradients
      - test_patched_unfused_symbol_preserves_gradients
      - test_patched_fused_symbol_default_ignore_index_minus_100
          (pins the "Liger > native fused" behavior — native silently
           garbles target=-100; Liger zeros it)
    Symmetric split:
      - test_rms_norm_only_patch_does_not_touch_ce_symbols

95/95 megatron tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…suite per-kernel + add RMSNorm coverage

Two coverage gaps in the existing megatron suites, closed:

  * test_cross_entropy.py's three base correctness tests (matches_pytorch,
    respects_ignore_index, respects_label_smoothing) were forward-only.
    test_class_preserves_gradients was misleadingly named — it asserted
    grad is not None + shape, never that the gradient value was correct.
    So a CE backward-math regression could slip past every base test.

  * test_monkey_patch.py was scoped to cross-entropy (per its docstring
    and stub installer name) even though apply_liger_kernel_to_megatron
    patches both RMSNorm and CE. RMSNorm patch behavior had zero
    monkey-patch-side coverage — replacement, idempotency, the
    spec-provider dispatch, the "skip non-WrappedTorchNorm fallback"
    contract were all unverified at the patch level.

test/megatron/test_cross_entropy.py
  - test_class_matches_pytorch_cross_entropy: now feeds independent
    clones to ref and Liger paths (Liger writes grad in-place to the
    input; sharing the buffer would corrupt the reference), then asserts
    h_got.grad matches h_ref.grad after backward through both.
  - test_class_respects_ignore_index, test_class_respects_label_smoothing:
    same upgrade — full forward + backward parity vs. F.cross_entropy.
  - test_class_preserves_gradients: tightened body to assert grad value
    parity against PyTorch's reference, not just "grad is not None".
    Name preserved; docstring updated to call out the change.

test/megatron/test_monkey_patch.py
  Restructure:
    - Header docstring rewritten to describe per-kernel structure;
      sections numbered 1-5 so future kernels append cleanly.
    - _install_fake_megatron → _install_fake_megatron_ce (symmetric
      with the new _install_fake_megatron_rms_norm).
    - Fixture fake_megatron → fake_megatron_ce.
    - New _ensure_megatron_roots helper so the two installers can be
      called in either order without clobbering shared root modules.
    - _uninstall_fake_megatron extended to clean both kernel slices.

  New RMSNorm stub:
    - _install_fake_megatron_rms_norm stubs megatron.core.models.backends
      (LocalSpecProvider with a layer_norm method returning a sentinel)
      and megatron.core.transformer.{transformer_block,torch_norm}
      (LayerNormImpl, WrappedTorchNorm with __new__ returning a sentinel
      instance). Toggle layer_norm_is_wrapped_torch_norm to test the
      block-level skip path.
    - fake_megatron_rms_norm fixture.

  New RMSNorm tests (10), grouped 3.1–3.3:
    Replacement & idempotency:
      - test_rms_norm_patch_replaces_local_spec_provider_layer_norm
      - test_rms_norm_patch_replaces_transformer_block_layernorm_impl
      - test_rms_norm_patch_replaces_both_symbols_in_one_call
      - test_rms_norm_patch_with_rms_norm_false_leaves_norm_symbols_untouched
      - test_rms_norm_patch_is_idempotent
    Dispatch through patched targets:
      - test_rms_norm_patch_local_spec_provider_returns_liger_for_rms_norm_true
      - test_rms_norm_patch_local_spec_provider_delegates_for_rms_norm_false
      - test_rms_norm_patch_transformer_block_routes_rmsnorm_through_liger
        (instantiates the wrapping class with a RMSNorm config and
         confirms an actual LigerMegatronRMSNorm instance comes back;
         no kernel call needed, runs on CPU)
      - test_rms_norm_patch_transformer_block_routes_layernorm_through_original
    Block-level skip contract:
      - test_rms_norm_patch_transformer_block_skips_when_layer_norm_is_not_wrapped_torch_norm
        (verifies Liger does not displace TE / Apex norm fallbacks)

105/105 megatron tests pass (was 95: +4 stronger CE grad-parity checks
across existing tests' parametrizations, +10 RMSNorm patch tests).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…script plotter

Two anti-patterns in the benchmark script unwound:

  * Wrote to its own benchmark/data/all_benchmark_data_megatron.csv via a
    runtime monkey-patch of update_benchmark_data_csv. Every other Liger
    benchmark writes to the shared all_benchmark_data.csv and the rows are
    partitioned by the kernel_name column — that's what the visualizer
    keys off. The redirect was self-inflicted special-casing.

  * Embedded a ~100-line _generate_plots function that re-implemented
    benchmarks_visualizer.py's job, only differently enough to be a
    maintenance liability (log-2 x-axis, fill_between band, hard-coded
    title). The reason it existed was because the visualizer couldn't
    find our redirected CSV. Once the CSV moves back to the shared file,
    benchmarks_visualizer.py reads our rows out of the box.

benchmark/scripts/benchmark_megatron_cross_entropy.py
  - Drop _CSV_FILENAME constant + _patched_update_benchmark_data_csv +
    the benchmark_utils.update_benchmark_data_csv reassignment. Default
    write path lands on benchmark/data/all_benchmark_data.csv.
  - Drop the matplotlib import dance and the entire _generate_plots
    function (and its invocation in __main__).
  - Header docstring updated: replaces the "Output → PNG" promise with
    the standard 2-step workflow command pointing at benchmarks_visualizer.

benchmark/data/all_benchmark_data.csv
  Migrate the 96 existing H100 megatron rows in from the orphaned
  per-component CSV (schemas matched verbatim — header line dropped,
  rows appended). Preserves the measurements without requiring a GPU
  re-run.

benchmark/data/all_benchmark_data_megatron.csv
  Deleted — its data lives in the shared CSV now.

Workflow is now the same as every other Liger kernel benchmark:

  python benchmark/scripts/benchmark_megatron_cross_entropy.py --overwrite
  python benchmark/benchmarks_visualizer.py \
      --kernel-name megatron_cross_entropy --metric-name speed
  python benchmark/benchmarks_visualizer.py \
      --kernel-name megatron_cross_entropy --metric-name memory

PNG outputs land in benchmark/visualizations/ (gitignored, as before).

Net: benchmark script shrinks from 348 → 207 lines; one tracked CSV
file consolidated away; visualizer slices megatron rows out of the
shared CSV with no changes on its side.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@vaibhavjindal vaibhavjindal changed the title [Do not merge] Add Megatron-LM cross-entropy integration Add Megatron-LM cross-entropy integration Jun 10, 2026
@vaibhavjindal vaibhavjindal changed the title Add Megatron-LM cross-entropy integration [Megatron] Add cross-entropy integration Jun 10, 2026
@vaibhavjindal vaibhavjindal enabled auto-merge June 10, 2026 20:47
@vaibhavjindal vaibhavjindal dismissed Mecoli1219’s stale review June 10, 2026 20:51

Made changes on top

@vaibhavjindal vaibhavjindal added this pull request to the merge queue Jun 10, 2026
Merged via the queue into linkedin:main with commit 0400d31 Jun 10, 2026
5 of 7 checks passed
sunyi0505 pushed a commit to sunyi0505/Liger-Kernel that referenced this pull request Jun 11, 2026
## Summary

Adds `benchmark/scripts/benchmark_megatron_rms_norm.py` — the RMSNorm
parallel to the Megatron cross-entropy benchmark landed in linkedin#1207.
Provides empirical speed/memory numbers so we can point at concrete data
when claiming RMSNorm wins, instead of just citing the kernel sweep.

Output goes to the shared `benchmark/data/all_benchmark_data.csv`
(tagged `kernel_name="megatron_rms_norm"`), so the standard visualizer
renders the plots:

```bash
python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_rms_norm --metric-name speed
python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_rms_norm --metric-name memory
```

## Providers compared

- **liger** — `LigerMegatronRMSNorm` (Liger's Triton RMSNorm via the
Megatron-shaped wrapper from linkedin#1254)
- **torch** — vanilla `torch.nn.RMSNorm`
- **megatron** — Megatron's `WrappedTorchNorm` — its `__new__` returns
`torch.nn.RMSNorm`, so timings should match `torch` exactly. Included
for explicit parity confirmation since `WrappedTorchNorm` is the symbol
Liger displaces in the local-backend path.

If `megatron-core` is not installed, the `megatron` provider is silently
dropped and the run proceeds with `liger` + `torch`.

## Results on H100 80GB (S=4096, B=1, bf16)

60 rows committed in the CSV (5 hidden sizes × 3 providers × 4
measurements).

### Speed — forward

| H     | liger    | torch    | speedup |
|-------|----------|----------|---------|
| 1024  | 0.012 ms | 0.013 ms | ~flat   |
| 2048  | 0.018 ms | 0.019 ms | ~6%     |
| 4096  | 0.029 ms | 0.033 ms | ~12%    |
| 8192  | 0.051 ms | 0.074 ms | ~31%    |
| 16384 | 0.095 ms | 0.149 ms | **~36%** |

Liger forward wins, and the gap widens with hidden size.

### Speed — full (fwd + bwd)

| H     | liger   | torch    | notes                                  |
|-------|---------|----------|----------------------------------------|
| 1024  | 0.30 ms | 0.075 ms | Liger SLOWER (Triton launch overhead)  |
| 4096  | 0.31 ms | 0.20 ms  | Liger slower                           |
| 8192  | 0.32 ms | 0.40 ms  | crossover                              |
| 16384 | 0.59 ms | 0.77 ms  | **~24% faster**                        |

The flat Liger curve at small H is the giveaway: kernel launch overhead
dominates. `nn.RMSNorm`'s backward is a single fused C++/CUDA kernel;
Liger's backward launches multiple Triton kernels (dx + dw reduction +
element_mul). At small H the actual compute is tiny relative to
per-launch overhead, so the math wins from Triton get drowned. At H ≥
~6K the compute dominates and Liger wins.

### Memory — full

Approximately neutral across all H — Liger uses 0.5–1% more than torch,
within measurement noise.

`nn.RMSNorm` is already a single fused CUDA kernel with minimal
intermediates, so Liger doesn't get the activation-memory win it gets
when replacing eager-PyTorch RMSNorm (which materializes variance +
rsqrt + scale separately). Speed is the win in this comparison, not
memory.

### Parity sanity check

`torch` and `megatron` rows are bit-identical across the entire sweep —
exactly what we'd expect since `WrappedTorchNorm` is a factory that
returns `nn.RMSNorm`. Confirms Liger is replacing the right baseline.

## Honest read

Production LLMs run at H ≥ 4096. At those sizes:
- Forward: Liger wins ~12–36%
- Full: Liger wins from ~H=6K onward
- Memory: neutral

So Liger's RMSNorm is a real speed improvement for typical training
shapes. The tiny-H regression on full is launch overhead, not a
numerical/math issue.

## Plots

### Memory
<img width="1000" height="600"
alt="megatron_rms_norm_memory_full_token_length"
src="https://github.com/user-attachments/assets/168a8783-bbeb-4467-ab65-2e54a28d2882"
/>

### Backward speed
<img width="1000" height="600"
alt="megatron_rms_norm_speed_backward_token_length"
src="https://github.com/user-attachments/assets/31501a5f-082e-47f7-a596-760a5add9de7"
/>

### Forward speed
<img width="1000" height="600"
alt="megatron_rms_norm_speed_forward_token_length"
src="https://github.com/user-attachments/assets/b5084441-8011-473a-b9b4-9ff67bcee7c4"
/>

### Forward + Backward speed
<img width="1000" height="600"
alt="megatron_rms_norm_speed_full_token_length"
src="https://github.com/user-attachments/assets/5504491d-e15a-42b4-9c02-e5a3d75a781f"
/>


## Testing Done

- Hardware Type: H100 80GB HBM3
- [x] Benchmark runs end-to-end on H100 with all 3 providers active
- [x] Standard visualizer renders all 4 PNGs (3 speed modes + memory)
from the shared CSV
- [x] `torch` / `megatron` providers produce bit-identical numbers
(parity confirmation)
- [x] `make checkstyle` passes

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants