[Megatron] Add cross-entropy integration #1207
Merged
vaibhavjindal merged 9 commits on Jun 10, 2026
Mecoli1219 (Collaborator) previously requested changes on May 13, 2026:
Overall looks great! Excited to support Megatron with Liger. Left some comments to address.
PR linkedin#1254 (Megatron-Core RMSNorm integration) landed on main while this PR was open. Both PRs introduce src/liger_kernel/megatron/{__init__,monkey_patch}.py and both define apply_liger_kernel_to_megatron(...), so there were add/add conflicts on those two files. Resolution:
src/liger_kernel/megatron/__init__.py
Union the exports. Now re-exports both LigerMegatronRMSNorm (from the RMSNorm PR) and apply_liger_kernel_to_megatron.
src/liger_kernel/megatron/monkey_patch.py
Unify the two entry points into one function:
    apply_liger_kernel_to_megatron(
        rms_norm: bool = True,
        cross_entropy: bool = False,
        *,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        reduction: str = "none",
    )
rms_norm defaults to True (matches linkedin#1254 behaviour). cross_entropy defaults to False because this PR's TP=1 limitation means we shouldn't auto-enable a feature that crashes at TP>1. The tensor-parallel-size guard moves inside the cross_entropy=True branch so RMSNorm-only users don't see it.
test/megatron/test_monkey_patch.py
Existing CE tests called apply_liger_kernel_to_megatron() with no args. With the unified signature that now means "apply RMSNorm". Updated each test to pass rms_norm=False, cross_entropy=True so they continue to exercise the CE path only.
Everything else (cross_entropy.py, rms_norm.py, the test/benchmark/docs additions from both PRs) is additive and didn't conflict.
All 34 tests in test/megatron/ pass post-merge.
…sed paths
Builds on the merged-from-main entry point (apply_liger_kernel_to_megatron with rms_norm/cross_entropy flags) so this PR cleanly extends the CE side to match the RMSNorm style PR linkedin#1254 established.
src/liger_kernel/megatron/cross_entropy.py
Strip to just LigerMegatronCrossEntropy: an nn.Module drop-in for Megatron's vocab-parallel CE matching the fused signature (logits, target, tp_group=None). Single source of truth for the CE math; the monkey-patch wrappers in monkey_patch.py instantiate this class. Mirrors LigerMegatronRMSNorm. Reads no megatron-core imports. 3-D shape guard + ValueError for bad reduction.
src/liger_kernel/megatron/monkey_patch.py
Add _patch_fused_vocab_parallel_cross_entropy AND _patch_vocab_parallel_cross_entropy. Both wired into the cross_entropy=True branch so a single apply call replaces both Megatron CE paths. Both guarded by _PATCH_MARKER for idempotence (matching the RMSNorm helpers). Unfused wrapper uses a sentinel _LABEL_SMOOTHING_UNSET to distinguish "caller didn't pass" from "caller explicitly passed 0.0", which preserves Megatron's contract.
src/liger_kernel/megatron/__init__.py
Export LigerMegatronCrossEntropy alongside LigerMegatronRMSNorm.
test/megatron/test_cross_entropy.py
Rewrite tests around the class instead of the old _build_wrapper. Same parametrize matrix as test_rms_norm.py. New: 3-D shape guard, reduction rejection sweep, full extra_repr check.
test/megatron/test_monkey_patch.py
Extend _install_fake_megatron to stub the unfused module too. Add:
- test_patch_replaces_unfused_symbol
- test_patch_replaces_both_fused_and_unfused_symbols_in_one_call
- test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched
- test_patch_is_idempotent_for_both_symbols
- test_patch_fused_wrapper_passes_tp_group_through
- test_patch_raises_when_unfused_symbol_missing
- test_unfused_wrapper_honors_runtime_label_smoothing
- test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing
- test_unfused_wrapper_honors_explicit_zero_label_smoothing
examples/megatron/run_mode{1,2}_*.py
Mode 1 now patches both RMSNorm and CE in one call with label_smoothing=0.1 and prints both patched CE symbols. Mode 2 adds a _LigerCEGPTModel(GPTModel) subclass that overrides compute_language_model_loss to use LigerMegatronCrossEntropy directly (CE has no spec slot, so subclassing is the symmetric hand-built path). Both run at TP=1 + DP=2, since CE patches are TP=1 only.
benchmark/scripts/benchmark_megatron_cross_entropy.py
Add "megatron-unfused" provider. Switch CSV output to benchmark/data/all_benchmark_data_megatron.csv via a surgical patch on update_benchmark_data_csv (LIGER_KERNEL_IMPL would have triggered backend selection too). Add optional matplotlib plot generation under benchmark/visualizations/megatron_cross_entropy_*.png.
48/48 tests pass. End-to-end mode-1 and mode-2 verified on 2× H100. Benchmark shows Liger CE 2-5x faster than both Megatron paths and 40% less memory than the fused path at vocab=128K, S=2048, B=4, bf16.
Trim the framework-level patch API to just (rms_norm=True, cross_entropy=False). CE-specific knobs (ignore_index, label_smoothing, reduction) leaked the wrong abstraction at the wrong scope: the public entry point's job is "patch Megatron with Liger," not "configure CE." The patch now matches Megatron's native fused-CE behavior exactly by constructing LigerMegatronCrossEntropy with class defaults, so callers flip Liger on without touching loss config.
The unfused wrapper keeps its per-call label_smoothing argument. That matches Megatron's native vocab_parallel_cross_entropy signature and is the right scope for the knob (per-call, not patch-time). Sentinel dispatch preserved so an explicit label_smoothing=0.0 from the caller is honored verbatim.
Mode-2 (hand-built) users who need custom CE config still get it via LigerMegatronCrossEntropy's __init__ kwargs; that surface is unchanged.
src/liger_kernel/megatron/monkey_patch.py
Drop ignore_index, label_smoothing, reduction kwargs from apply_liger_kernel_to_megatron and from both _patch_*_cross_entropy helpers. Update docstrings to point Mode-2 users at LigerMegatronCrossEntropy for explicit config.
test/megatron/test_monkey_patch.py
Remove test_patch_rejects_non_none_reduction (path unreachable from trimmed API). Rename test_patch_forwards_… → test_patch_constructs_ce_with_class_defaults, asserting the patch uses class defaults. Strip now-invalid kwargs from the three runtime label_smoothing tests on the unfused wrapper; their per-call behavior is unchanged.
docs/High-Level-APIs.md
Replace usage example with the trimmed signature; add pointer to LigerMegatronCrossEntropy for Mode-2 explicit-config callers.
examples/megatron/run_mode1_monkey_patch.py
Drop label_smoothing=0.1 demo from the apply call and from the header comment.
15/15 monkey-patch tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
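The sentinel dispatch mentioned in the commits above can be sketched in a few lines. This is only an illustration of the pattern, not the code in monkey_patch.py: `_LABEL_SMOOTHING_UNSET` is the name the commit message uses, while `make_unfused_wrapper`, `naive_wrapper`, and the idea of returning the resolved value (instead of calling a kernel) are assumptions made so the example runs on its own.

```python
# Sentinel object: distinguishes "argument not passed" from "passed 0.0".
_LABEL_SMOOTHING_UNSET = object()


def make_unfused_wrapper(default_label_smoothing):
    """Build a stand-in wrapper; a real one would call the CE kernel."""

    def wrapper(logits, target, label_smoothing=_LABEL_SMOOTHING_UNSET):
        if label_smoothing is _LABEL_SMOOTHING_UNSET:
            effective = default_label_smoothing  # caller did not pass a value
        else:
            effective = label_smoothing  # honor the caller, even if it is 0.0
        # Hypothetical: return the resolved value so the dispatch is observable.
        return effective

    return wrapper


def naive_wrapper(logits, target, label_smoothing=None, default=0.1):
    """Counter-example: a truthiness check silently discards an explicit 0.0."""
    return label_smoothing or default


wrapper = make_unfused_wrapper(default_label_smoothing=0.1)
print(wrapper(None, None))                       # falls back to the default
print(wrapper(None, None, label_smoothing=0.0))  # explicit zero is kept
print(naive_wrapper(None, None, label_smoothing=0.0))  # explicit zero is lost
```

An identity check against a private object is what lets a caller's explicit 0.0 pass through unchanged.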
The existing suites covered the happy path; sweeping the transformers
side surfaced two gaps worth closing on the Megatron mirror:
* test_cross_entropy.py only tested the trivial shape/dtype matrix.
Missing: scalar-magnitude sweep, broader ignore_index sweep,
combined label_smoothing × ignore_index, gradient parity when CE
isn't the last op in the graph, out-of-bounds-target guard,
forward-only (torch.no_grad) path.
* test_monkey_patch.py asserted symbol identity ("the patched
function exists on the fake module") but never invoked the
wrapper with real tensors. So a wrapper-math regression could
pass every existing test. Added an end-to-end integration block
that applies the patch and then calls the resulting symbol
against F.cross_entropy.
test/megatron/test_cross_entropy.py
Six new tests adapted from test/transformers/test_cross_entropy.py
to the [s, b, v] -> [s, b] reshape contract:
- test_class_correctness_scalar_sweep
(mirrors `test_correctness` with `scalar` param)
- test_class_correctness_with_ignore_index_sweep
(broader sentinel sweep incl. positive/weird negatives)
- test_class_correctness_with_label_smoothing_and_ignore_index
(the combined branch where smoothing math historically broke)
- test_class_correctness_not_last_layer
(loss × downstream factor, verify in-place grad survives)
- test_class_rejects_out_of_bounds_target
(kernel-level OOB assert reachable through the wrapper)
- test_class_correctness_forward_only
(torch.no_grad parity + post-hoc .backward() raises)
Plus an _assign_ignore_index helper matching the transformers-side
pattern of randomizing the masked-out positions.
test/megatron/test_monkey_patch.py
Ten new tests:
Public-API surface (mirrors transformers' test_import_from_root):
- test_import_from_root
- test_public_apply_function_has_no_ce_specific_kwargs
(guards against re-introducing the kwargs just trimmed)
End-to-end integration (closes the wrapper-math gap):
- test_patched_fused_symbol_computes_correct_loss
- test_patched_unfused_symbol_computes_correct_loss
- test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch
- test_patched_fused_symbol_preserves_gradients
- test_patched_unfused_symbol_preserves_gradients
- test_patched_fused_symbol_default_ignore_index_minus_100
(pins the "Liger > native fused" behavior — native silently
garbles target=-100; Liger zeros it)
Symmetric split:
- test_rms_norm_only_patch_does_not_touch_ce_symbols
95/95 megatron tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…suite per-kernel + add RMSNorm coverage
Two coverage gaps in the existing megatron suites, closed:
* test_cross_entropy.py's three base correctness tests (matches_pytorch,
respects_ignore_index, respects_label_smoothing) were forward-only.
test_class_preserves_gradients was misleadingly named — it asserted
grad is not None + shape, never that the gradient value was correct.
So a CE backward-math regression could slip past every base test.
* test_monkey_patch.py was scoped to cross-entropy (per its docstring
and stub installer name) even though apply_liger_kernel_to_megatron
patches both RMSNorm and CE. RMSNorm patch behavior had zero
monkey-patch-side coverage — replacement, idempotency, the
spec-provider dispatch, the "skip non-WrappedTorchNorm fallback"
contract were all unverified at the patch level.
test/megatron/test_cross_entropy.py
- test_class_matches_pytorch_cross_entropy: now feeds independent
clones to ref and Liger paths (Liger writes grad in-place to the
input; sharing the buffer would corrupt the reference), then asserts
h_got.grad matches h_ref.grad after backward through both.
- test_class_respects_ignore_index, test_class_respects_label_smoothing:
same upgrade — full forward + backward parity vs. F.cross_entropy.
- test_class_preserves_gradients: tightened body to assert grad value
parity against PyTorch's reference, not just "grad is not None".
Name preserved; docstring updated to call out the change.
test/megatron/test_monkey_patch.py
Restructure:
- Header docstring rewritten to describe per-kernel structure;
sections numbered 1-5 so future kernels append cleanly.
- _install_fake_megatron → _install_fake_megatron_ce (symmetric
with the new _install_fake_megatron_rms_norm).
- Fixture fake_megatron → fake_megatron_ce.
- New _ensure_megatron_roots helper so the two installers can be
called in either order without clobbering shared root modules.
- _uninstall_fake_megatron extended to clean both kernel slices.
New RMSNorm stub:
- _install_fake_megatron_rms_norm stubs megatron.core.models.backends
(LocalSpecProvider with a layer_norm method returning a sentinel)
and megatron.core.transformer.{transformer_block,torch_norm}
(LayerNormImpl, WrappedTorchNorm with __new__ returning a sentinel
instance). Toggle layer_norm_is_wrapped_torch_norm to test the
block-level skip path.
- fake_megatron_rms_norm fixture.
New RMSNorm tests (10), grouped 3.1–3.3:
Replacement & idempotency:
- test_rms_norm_patch_replaces_local_spec_provider_layer_norm
- test_rms_norm_patch_replaces_transformer_block_layernorm_impl
- test_rms_norm_patch_replaces_both_symbols_in_one_call
- test_rms_norm_patch_with_rms_norm_false_leaves_norm_symbols_untouched
- test_rms_norm_patch_is_idempotent
Dispatch through patched targets:
- test_rms_norm_patch_local_spec_provider_returns_liger_for_rms_norm_true
- test_rms_norm_patch_local_spec_provider_delegates_for_rms_norm_false
- test_rms_norm_patch_transformer_block_routes_rmsnorm_through_liger
(instantiates the wrapping class with a RMSNorm config and
confirms an actual LigerMegatronRMSNorm instance comes back;
no kernel call needed, runs on CPU)
- test_rms_norm_patch_transformer_block_routes_layernorm_through_original
Block-level skip contract:
- test_rms_norm_patch_transformer_block_skips_when_layer_norm_is_not_wrapped_torch_norm
(verifies Liger does not displace TE / Apex norm fallbacks)
105/105 megatron tests pass (was 95: +4 stronger CE grad-parity checks
across existing tests' parametrizations, +10 RMSNorm patch tests).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…script plotter
Two anti-patterns in the benchmark script unwound:
* Wrote to its own benchmark/data/all_benchmark_data_megatron.csv via a
runtime monkey-patch of update_benchmark_data_csv. Every other Liger
benchmark writes to the shared all_benchmark_data.csv and the rows are
partitioned by the kernel_name column — that's what the visualizer
keys off. The redirect was self-inflicted special-casing.
* Embedded a ~100-line _generate_plots function that re-implemented
benchmarks_visualizer.py's job, only differently enough to be a
maintenance liability (log-2 x-axis, fill_between band, hard-coded
title). The reason it existed was because the visualizer couldn't
find our redirected CSV. Once the CSV moves back to the shared file,
benchmarks_visualizer.py reads our rows out of the box.
benchmark/scripts/benchmark_megatron_cross_entropy.py
- Drop _CSV_FILENAME constant + _patched_update_benchmark_data_csv +
the benchmark_utils.update_benchmark_data_csv reassignment. Default
write path lands on benchmark/data/all_benchmark_data.csv.
- Drop the matplotlib import dance and the entire _generate_plots
function (and its invocation in __main__).
- Header docstring updated: replaces the "Output → PNG" promise with
the standard 2-step workflow command pointing at benchmarks_visualizer.
benchmark/data/all_benchmark_data.csv
Migrate the 96 existing H100 megatron rows in from the orphaned
per-component CSV (schemas matched verbatim — header line dropped,
rows appended). Preserves the measurements without requiring a GPU
re-run.
benchmark/data/all_benchmark_data_megatron.csv
Deleted — its data lives in the shared CSV now.
Workflow is now the same as every other Liger kernel benchmark:
python benchmark/scripts/benchmark_megatron_cross_entropy.py --overwrite
python benchmark/benchmarks_visualizer.py \
--kernel-name megatron_cross_entropy --metric-name speed
python benchmark/benchmarks_visualizer.py \
--kernel-name megatron_cross_entropy --metric-name memory
PNG outputs land in benchmark/visualizations/ (gitignored, as before).
Net: benchmark script shrinks from 348 → 207 lines; one tracked CSV
file consolidated away; visualizer slices megatron rows out of the
shared CSV with no changes on its side.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
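The point about the shared CSV being partitioned by the kernel_name column can be illustrated with a small filter. This is a sketch under stated assumptions: only `kernel_name`, the `megatron_cross_entropy` tag, and the speed/memory metric names come from the text above; the other column names, the `some_other_kernel` rows, and all numeric values are placeholders, and `rows_for_kernel` is a hypothetical helper rather than anything in benchmarks_visualizer.py.

```python
import csv
import io

# Placeholder data: column names other than kernel_name are assumed,
# and the numbers are made up purely to exercise the filter.
SHARED_CSV = """kernel_name,kernel_provider,metric_name,x_value,y_value
megatron_cross_entropy,liger,speed,4096,1.0
megatron_cross_entropy,torch,speed,4096,2.0
megatron_cross_entropy,liger,memory,4096,3.0
some_other_kernel,liger,speed,4096,4.0
"""


def rows_for_kernel(csv_text, kernel_name, metric_name):
    """Return only the rows tagged with the given kernel and metric."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [
        row
        for row in reader
        if row["kernel_name"] == kernel_name and row["metric_name"] == metric_name
    ]


speed_rows = rows_for_kernel(SHARED_CSV, "megatron_cross_entropy", "speed")
print([row["kernel_provider"] for row in speed_rows])
```

Because every row carries its own kernel tag, adding a new benchmark only appends rows and needs no separate file.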
kolehma8 approved these changes on Jun 10, 2026
sunyi0505 pushed a commit to sunyi0505/Liger-Kernel that referenced this pull request on Jun 11, 2026
## Summary

Adds `benchmark/scripts/benchmark_megatron_rms_norm.py`, the RMSNorm parallel to the Megatron cross-entropy benchmark landed in linkedin#1207. Provides empirical speed/memory numbers so we can point at concrete data when claiming RMSNorm wins, instead of just citing the kernel sweep.

Output goes to the shared `benchmark/data/all_benchmark_data.csv` (tagged `kernel_name="megatron_rms_norm"`), so the standard visualizer renders the plots:

```bash
python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_rms_norm --metric-name speed
python benchmark/benchmarks_visualizer.py \
    --kernel-name megatron_rms_norm --metric-name memory
```

## Providers compared

- **liger**: `LigerMegatronRMSNorm` (Liger's Triton RMSNorm via the Megatron-shaped wrapper from linkedin#1254)
- **torch**: vanilla `torch.nn.RMSNorm`
- **megatron**: Megatron's `WrappedTorchNorm`. Its `__new__` returns `torch.nn.RMSNorm`, so timings should match `torch` exactly. Included for explicit parity confirmation since `WrappedTorchNorm` is the symbol Liger displaces in the local-backend path.

If `megatron-core` is not installed, the `megatron` provider is silently dropped and the run proceeds with `liger` + `torch`.

## Results on H100 80GB (S=4096, B=1, bf16)

60 rows committed in the CSV (5 hidden sizes × 3 providers × 4 measurements).

### Speed: forward

| H     | liger    | torch    | speedup  |
|-------|----------|----------|----------|
| 1024  | 0.012 ms | 0.013 ms | ~flat    |
| 2048  | 0.018 ms | 0.019 ms | ~6%      |
| 4096  | 0.029 ms | 0.033 ms | ~12%     |
| 8192  | 0.051 ms | 0.074 ms | ~31%     |
| 16384 | 0.095 ms | 0.149 ms | **~36%** |

Liger forward wins, and the gap widens with hidden size.

### Speed: full (fwd + bwd)

| H     | liger   | torch    | notes                                 |
|-------|---------|----------|---------------------------------------|
| 1024  | 0.30 ms | 0.075 ms | Liger SLOWER (Triton launch overhead) |
| 4096  | 0.31 ms | 0.20 ms  | Liger slower                          |
| 8192  | 0.32 ms | 0.40 ms  | crossover                             |
| 16384 | 0.59 ms | 0.77 ms  | **~24% faster**                       |

The flat Liger curve at small H is the giveaway: kernel launch overhead dominates. `nn.RMSNorm`'s backward is a single fused C++/CUDA kernel; Liger's backward launches multiple Triton kernels (dx + dw reduction + element_mul). At small H the actual compute is tiny relative to per-launch overhead, so the math wins from Triton get drowned. At H ≥ ~6K the compute dominates and Liger wins.

### Memory: full

Approximately neutral across all H. Liger uses 0.5–1% more than torch, within measurement noise. `nn.RMSNorm` is already a single fused CUDA kernel with minimal intermediates, so Liger doesn't get the activation-memory win it gets when replacing eager-PyTorch RMSNorm (which materializes variance + rsqrt + scale separately). Speed is the win in this comparison, not memory.

### Parity sanity check

`torch` and `megatron` rows are bit-identical across the entire sweep, which is exactly what we'd expect since `WrappedTorchNorm` is a factory that returns `nn.RMSNorm`. Confirms Liger is replacing the right baseline.

## Honest read

Production LLMs run at H ≥ 4096. At those sizes:

- Forward: Liger wins ~12–36%
- Full: Liger wins from ~H=6K onward
- Memory: neutral

So Liger's RMSNorm is a real speed improvement for typical training shapes. The tiny-H regression on full is launch overhead, not a numerical/math issue.

## Plots

### Memory

<img width="1000" height="600" alt="megatron_rms_norm_memory_full_token_length" src="https://github.com/user-attachments/assets/168a8783-bbeb-4467-ab65-2e54a28d2882" />

### Backward speed

<img width="1000" height="600" alt="megatron_rms_norm_speed_backward_token_length" src="https://github.com/user-attachments/assets/31501a5f-082e-47f7-a596-760a5add9de7" />

### Forward speed

<img width="1000" height="600" alt="megatron_rms_norm_speed_forward_token_length" src="https://github.com/user-attachments/assets/b5084441-8011-473a-b9b4-9ff67bcee7c4" />

### Forward + Backward speed

<img width="1000" height="600" alt="megatron_rms_norm_speed_full_token_length" src="https://github.com/user-attachments/assets/5504491d-e15a-42b4-9c02-e5a3d75a781f" />

## Testing Done

- Hardware Type: H100 80GB HBM3
- [x] Benchmark runs end-to-end on H100 with all 3 providers active
- [x] Standard visualizer renders all 4 PNGs (3 speed modes + memory) from the shared CSV
- [x] `torch` / `megatron` providers produce bit-identical numbers (parity confirmation)
- [x] `make checkstyle` passes

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
This PR adds Liger's Triton cross-entropy kernel to Megatron-LM.
Megatron has two cross-entropy code paths. This PR replaces both with Liger:
- `fused_vocab_parallel_cross_entropy` (used when `cross_entropy_loss_fusion=True`).
- `vocab_parallel_cross_entropy` (used when `cross_entropy_loss_fusion=False`).
You get Liger's kernel no matter which path your Megatron config picks.
Liger's kernel uses online softmax. It writes gradients in place. It never builds the full softmax tensor. The net effect is faster training and lower memory use.
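To make the online-softmax idea concrete, here is a minimal pure-Python sketch for a single row of logits. It is not Liger's Triton kernel: it only shows how the loss can be computed in one streaming pass with a running max and a running sum, without ever building the softmax vector. The function names are hypothetical.

```python
import math


def cross_entropy_two_pass(logits, target):
    """Reference: materialize the full softmax, then take -log p[target]."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]  # full softmax vector is built here
    return -math.log(probs[target])


def cross_entropy_online(logits, target):
    """Single streaming pass: keep only a running max and a running sum."""
    running_max = float("-inf")
    running_sum = 0.0
    for x in logits:
        new_max = max(running_max, x)
        # Rescale the old sum to the new max before adding the new term.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    log_sum_exp = running_max + math.log(running_sum)
    return log_sum_exp - logits[target]


logits = [2.0, -1.0, 0.5, 3.0]
for t in range(len(logits)):
    print(t, cross_entropy_two_pass(logits, t), cross_entropy_online(logits, t))
```

Both functions agree up to floating-point rounding; the online version simply needs O(1) extra state per row instead of a vocab-sized buffer.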
This PR also keeps the RMSNorm patch that was added in PR #1254. So one call patches both kernels at once.
How to use
There are two ways. Pick whichever fits your code better.
Mode 1: One-line monkey patch
Best for users who don't want to touch their model code. Call `apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)` once before you build the model.
That's the whole API. No CE-specific knobs. The defaults match Megatron's native behavior.
Mode 2: Use the classes directly
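Best for users who build their model spec by hand or need custom config. Import `LigerMegatronRMSNorm` and `LigerMegatronCrossEntropy` from `liger_kernel.megatron` and wire them in yourself; `examples/megatron/run_mode2_hand_spec.py` shows one way to do it.
`LigerMegatronCrossEntropy` accepts the same knobs Liger's standalone CE accepts (`ignore_index`, `label_smoothing`, etc.). Use Mode 2 when you need them.
Scope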
This PR supports `tensor_model_parallel_size=1` only.
Why? At TP>1, each GPU holds a slice of the vocabulary, not the full vocab. Cross-entropy then needs cross-rank all-reduces to get a correct loss. Liger's kernel does not do those all-reduces. So the math would be wrong at TP>1.
To stop you from training with a broken loss, the patch raises `RuntimeError` if it sees TP>1. The check runs at patch time and again at call time, so misconfigurations fail loudly.
[TODO]: TP>1 support is planned as follow-up work (a separate "vocab-parallel" Triton kernel).
You can still use this PR with pipeline parallelism, data parallelism, or single-GPU runs. Those are all TP=1 cases.
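The TP>1 limitation above can be seen with a toy example. The sketch below simulates two ranks that each hold half of a four-entry vocabulary, using plain Python lists in place of GPUs and ordinary `max`/`sum` in place of all-reduces. It is an illustration of why the reductions are needed, not Megatron's or Liger's implementation, and all function names are hypothetical.

```python
import math


def log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def full_vocab_loss(logits, target):
    """Cross-entropy when one device sees the whole vocabulary (TP=1)."""
    return log_sum_exp(logits) - logits[target]


def sharded_loss(shards, target):
    """Cross-entropy over vocab shards, with the reductions a TP>1 kernel needs."""
    # Stand-in for all-reduce(max): every rank needs the global max.
    global_max = max(max(s) for s in shards)
    # Stand-in for all-reduce(sum): partial exp-sums from every rank.
    global_sum = sum(sum(math.exp(v - global_max) for v in s) for s in shards)
    # Only the rank that owns the target contributes its logit; others add 0.
    target_logit = 0.0
    offset = 0
    for s in shards:
        if offset <= target < offset + len(s):
            target_logit += s[target - offset]
        offset += len(s)
    return global_max + math.log(global_sum) - target_logit


def local_only_loss(shard, local_target):
    """What a rank gets with no communication: softmax over its slice only."""
    return log_sum_exp(shard) - shard[local_target]


logits = [1.0, 2.0, 0.5, 3.0]
shards = [logits[:2], logits[2:]]
print(full_vocab_loss(logits, 1))     # correct loss
print(sharded_loss(shards, 1))        # matches, thanks to the reductions
print(local_only_loss(shards[0], 1))  # too small: ignores the other shard
```

Without the cross-rank reductions, each rank normalizes over its own slice and underestimates the loss, which is the failure the `RuntimeError` guard is meant to prevent.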
Testing
All unit tests live in `test/megatron/`.
Run all megatron tests
That runs three files:
- `test/megatron/test_rms_norm.py`: RMSNorm class correctness (forward, backward, dtypes).
- `test/megatron/test_cross_entropy.py`: `LigerMegatronCrossEntropy` class correctness. Covers shape sweeps, `ignore_index`, `label_smoothing`, combined config, out-of-bounds targets, `torch.no_grad()` path, and gradient parity vs. PyTorch's `F.cross_entropy`.
- `test/megatron/test_monkey_patch.py`: `apply_liger_kernel_to_megatron()` plumbing. Covers symbol replacement for both CE paths and both RMSNorm targets, idempotency, missing-megatron errors, TP>1 guard, dispatch behavior, and end-to-end correctness through the patched symbols.
You should see all tests pass (currently 105 tests in the megatron suite).
Run a single test file
What the tests do NOT need
Megatron-LM is not a test dependency. The monkey-patch tests use stub modules in `sys.modules`, so they run on CPU without a real Megatron install. You only need PyTorch + Triton + Liger itself.
The correctness tests (RMSNorm, CE) need a GPU because they run the actual Triton kernel.
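The stub-module approach described above, together with the marker-based idempotence mentioned in the commit messages, can be sketched as follows. This is an assumption-laden illustration rather than the repository's test code: the module name `fake_megatron_ce`, the marker value `"_liger_patched"`, and every function here are hypothetical stand-ins.

```python
import importlib
import sys
import types

# Hypothetical names for illustration; the real marker and module paths may differ.
_PATCH_MARKER = "_liger_patched"
FAKE_MODULE = "fake_megatron_ce"


def native_cross_entropy(logits, target):
    return "native"


def liger_cross_entropy(logits, target):
    return "liger"


def install_fake_module():
    """Register a stub module so the patch can be exercised without Megatron."""
    module = types.ModuleType(FAKE_MODULE)
    module.vocab_parallel_cross_entropy = native_cross_entropy
    sys.modules[FAKE_MODULE] = module
    return module


def uninstall_fake_module():
    sys.modules.pop(FAKE_MODULE, None)


def patch_cross_entropy():
    """Swap the symbol once; return False if it was already patched."""
    module = importlib.import_module(FAKE_MODULE)  # resolved from sys.modules
    current = getattr(module, "vocab_parallel_cross_entropy", None)
    if current is None:
        raise RuntimeError("vocab_parallel_cross_entropy not found")
    if getattr(current, _PATCH_MARKER, False):
        return False  # marker present: a second call is a no-op

    def wrapper(logits, target):
        return liger_cross_entropy(logits, target)

    setattr(wrapper, _PATCH_MARKER, True)
    module.vocab_parallel_cross_entropy = wrapper
    return True


fake = install_fake_module()
first = patch_cross_entropy()
patched_once = fake.vocab_parallel_cross_entropy
second = patch_cross_entropy()
print(first, second, fake.vocab_parallel_cross_entropy is patched_once)
uninstall_fake_module()
```

Because the import resolves from `sys.modules`, nothing here needs a GPU or a real Megatron install, which is what keeps this kind of plumbing test runnable on CPU.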
Benchmarking
Benchmarked against `megatron-core==0.19.0+96a894ef6`.
Step 1: Run the benchmark
The benchmark script (`benchmark/scripts/benchmark_megatron_cross_entropy.py`) compares four providers across vocab sizes 4K → 128K (with S=2048, B=4, BF16):
- liger: `LigerMegatronCrossEntropy`
- torch: `F.cross_entropy`
- megatron: Megatron's fused cross-entropy path (`@jit_fuser` via TorchScript)
- megatron-unfused: Megatron's unfused cross-entropy path
The script needs a CUDA/ROCm GPU. If Megatron-LM is not installed, the two megatron providers are silently skipped and you only get `liger` vs `torch`.
Results land in `benchmark/data/all_benchmark_data.csv`. Each row is tagged with `kernel_name="megatron_cross_entropy"`.
Step 2: Render plots
    python benchmark/benchmarks_visualizer.py \
        --kernel-name megatron_cross_entropy --metric-name speed
    python benchmark/benchmarks_visualizer.py \
        --kernel-name megatron_cross_entropy --metric-name memory
PNGs land in `benchmark/visualizations/`.
Memory savings
Backward speedup
Forward speedup
Backward+Forward speedup
Example training scripts
Two end-to-end examples in `examples/megatron/`. Both train a tiny GPT model with mock data for 5 iterations. They print which classes filled which norm slots and which function backs each CE entry point, so you can see Liger took over.
Prerequisites
- Megatron-Core (`pip install megatron-core`).
- `liger-kernel` installed.
- `psutil` (Megatron's async checkpoint worker pool uses it).
Example 1: Mode 1 (monkey patch)
    torchrun --nproc_per_node=2 \
        --master_addr=127.0.0.1 --master_port=29500 \
        examples/megatron/run_mode1_monkey_patch.py
One `apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)` call at the top. Everything else is stock Megatron code.
Example 2: Mode 2 (hand-built spec)
    torchrun --nproc_per_node=2 \
        --master_addr=127.0.0.1 --master_port=29500 \
        examples/megatron/run_mode2_hand_spec.py
No monkey patch. The script builds a `TransformerBlockSubmodules` with `LigerMegatronRMSNorm` slots and subclasses `GPTModel` to use `LigerMegatronCrossEntropy` in `compute_language_model_loss`.
What you should see
For both:
- `[modeN] iter <i> loss=<float>` with the loss going down.
- `LigerMegatronRMSNorm` in 5 of 5 norm slots (4 per-layer + 1 block-level final).
- `liger_fused_vocab_parallel_cross_entropy` and `liger_vocab_parallel_cross_entropy` (Mode 1) or `LigerMegatronCrossEntropy` (Mode 2).
- `Successfully loaded the model` after a distributed checkpoint round-trip.
If Apex or TransformerEngine isn't installed, you'll see harmless warnings. Megatron falls back to the local backend, which is where Liger plugs in.
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence