From 4ffd93b061fd41fdc542966e10ceb231a0de98d9 Mon Sep 17 00:00:00 2001 From: Prathyusha Polepalli Date: Wed, 29 Apr 2026 16:24:35 -0700 Subject: [PATCH 1/7] Add Megatron-LM cross-entropy integration --- .../benchmark_megatron_cross_entropy.py | 179 ++++++++++++++++++ docs/High-Level-APIs.md | 42 ++++ src/liger_kernel/megatron/__init__.py | 3 + src/liger_kernel/megatron/cross_entropy.py | 81 ++++++++ src/liger_kernel/megatron/monkey_patch.py | 81 ++++++++ test/megatron/__init__.py | 0 test/megatron/test_cross_entropy.py | 157 +++++++++++++++ test/megatron/test_monkey_patch.py | 167 ++++++++++++++++ 8 files changed, 710 insertions(+) create mode 100644 benchmark/scripts/benchmark_megatron_cross_entropy.py create mode 100644 src/liger_kernel/megatron/__init__.py create mode 100644 src/liger_kernel/megatron/cross_entropy.py create mode 100644 src/liger_kernel/megatron/monkey_patch.py create mode 100644 test/megatron/__init__.py create mode 100644 test/megatron/test_cross_entropy.py create mode 100644 test/megatron/test_monkey_patch.py diff --git a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py new file mode 100644 index 000000000..6eca12edc --- /dev/null +++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py @@ -0,0 +1,179 @@ +"""Benchmark Liger's Megatron-LM cross-entropy wrapper. + +Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's +native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is +installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a +third provider to reproduce end-to-end comparisons. + +Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not +installed, the "megatron" provider is silently skipped. 
+""" + +import torch +import torch.nn.functional as F +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.megatron.cross_entropy import _build_wrapper +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device + +device = infer_device() + +try: + from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy + + _MEGATRON_AVAILABLE = True +except ImportError: + fused_vocab_parallel_cross_entropy = None + _MEGATRON_AVAILABLE = False + + +def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True): + logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + return logits, target + + +def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + s, b, v = logits.shape + return F.cross_entropy( + logits.reshape(-1, v).float(), + target.reshape(-1), + reduction="none", + ).reshape(s, b) + + +def _ensure_single_rank_tp_group(): + """Initialize torch.distributed (single-rank) and return a usable TP group. + + For a single-process benchmark we use the world group of + size 1, where the internal all-reduce becomes a no-op. 
+ """ + import os + + import torch.distributed as dist + + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("LOCAL_RANK", "0") + dist.init_process_group(backend="nccl") + return dist.group.WORLD + + +def _select_fwd(provider: str): + if provider == "liger": + wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none")) + return wrapper + if provider == "torch": + return _pytorch_cross_entropy + if provider == "megatron": + if not _MEGATRON_AVAILABLE: + raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider") + tp_group = _ensure_single_rank_tp_group() + + def _megatron_call(logits, target): + return fused_vocab_parallel_cross_entropy(logits, target, tp_group) + + return _megatron_call + raise ValueError(f"unknown provider: {provider!r}") + + +def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + v = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + s = input.extra_benchmark_config["S"] + b = input.extra_benchmark_config["B"] + + logits, target = _make_inputs(s, b, v) + fwd_fn = _select_fwd(provider) + + def fwd(): + return fwd_fn(logits, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": + # Megatron's fused CE writes gradients in-place into saved tensors during backward, + # which breaks the standard retain_graph=True / repeated-backward pattern do_bench + # uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an + # unmodified autograd graph. Measurement therefore includes forward time — + # subtract the "forward" measurement to derive backward-only timing. 
+ def _fwd_bwd(): + if logits.grad is not None: + logits.grad = None + out = fwd_fn(logits, target) + out.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES) + elif mode == "full": + + def full(): + y = fwd() + y.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"unknown mode: {mode!r}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + v = input.x + provider = input.kernel_provider + s = input.extra_benchmark_config["S"] + b = input.extra_benchmark_config["B"] + + logits, target = _make_inputs(s, b, v) + fwd_fn = _select_fwd(provider) + + def full(): + y = fwd_fn(logits, target) + y.sum().backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + providers = ["liger", "torch"] + if _MEGATRON_AVAILABLE: + providers.append("megatron") + + common_configs = { + "kernel_name": "megatron_cross_entropy", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": providers, + "extra_benchmark_configs": [{"S": 2048, "B": 4}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_megatron_cross_entropy, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_megatron_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/docs/High-Level-APIs.md b/docs/High-Level-APIs.md index 5433e03d3..b90fc6e96 100644 --- a/docs/High-Level-APIs.md +++ b/docs/High-Level-APIs.md @@ -91,3 +91,45 
@@ You can also use the Patching APIs to use the kernels for a specific model archi extra: show_docstring: true show_signature: true + +--- + +## Megatron-LM + +Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +training framework, replacing Megatron's native +`fused_vocab_parallel_cross_entropy` with Liger's Triton cross-entropy kernel. + +| **Framework** | **API** | **Supported Operations** | +|---------------|--------------------------------------------------------|--------------------------| +| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | CrossEntropyLoss | + +**Scope**: Initial release supports `tensor_model_parallel_size=1` only. +Vocab-parallel cross-entropy (TP>1) is follow-up work — with TP>1, each rank +holds a sharded `[N, V/tp]` logits slice and cross-entropy requires cross-rank +all-reduces that Liger's kernel does not perform. The patch raises a +`RuntimeError` at patch time or call time if TP>1 is detected. + +**Usage**: + +```python +from liger_kernel.megatron import apply_liger_kernel_to_megatron + +# Call before Megatron's forward pass reaches compute_language_model_loss. +# Match Megatron's config: pass the same ignore_index and label_smoothing +# values used by your training setup (Liger does not auto-detect them). +apply_liger_kernel_to_megatron( + ignore_index=-100, + label_smoothing=cfg.label_smoothing_factor, +) +``` + +Ensure Megatron's fused-CE code path is enabled in your training config (e.g. +`--cross-entropy-loss-fusion` in the Megatron-LM CLI) — if the unfused path is +selected, the patched symbol is never called. 
+ +::: liger_kernel.megatron.apply_liger_kernel_to_megatron + options: + extra: + show_docstring: true + show_signature: true diff --git a/src/liger_kernel/megatron/__init__.py b/src/liger_kernel/megatron/__init__.py new file mode 100644 index 000000000..6b31b091b --- /dev/null +++ b/src/liger_kernel/megatron/__init__.py @@ -0,0 +1,3 @@ +from liger_kernel.megatron.monkey_patch import apply_liger_kernel_to_megatron + +__all__ = ["apply_liger_kernel_to_megatron"] diff --git a/src/liger_kernel/megatron/cross_entropy.py b/src/liger_kernel/megatron/cross_entropy.py new file mode 100644 index 000000000..a4a1ec12f --- /dev/null +++ b/src/liger_kernel/megatron/cross_entropy.py @@ -0,0 +1,81 @@ +import torch + +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + + +def _build_wrapper(loss_fn: LigerCrossEntropyLoss): + """Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``. + + The returned callable has exactly the same parameter list Megatron expects + (``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs + will raise ``TypeError`` naturally — this is intentional: if a future + Megatron release adds new parameters to the fused-CE contract, we want to + fail loudly rather than silently drop them. + """ + + def liger_fused_vocab_parallel_cross_entropy( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group=None, + ) -> torch.Tensor: + if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1: + raise RuntimeError( + f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, " + f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as " + f"follow-up work." 
+ ) + + s, b, v = vocab_parallel_logits.shape + logits_2d = vocab_parallel_logits.reshape(-1, v) + target_1d = target.reshape(-1) + loss = loss_fn(logits_2d, target_1d) + return loss.reshape(s, b) + + return liger_fused_vocab_parallel_cross_entropy + + +def _patch_fused_vocab_parallel_cross_entropy( + reduction: str = "none", + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> None: + """Replace ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``. + + See ``apply_liger_kernel_to_megatron`` in ``monkey_patch.py`` for the public + entry point; this helper holds the cross-entropy-specific patching logic so + that future Megatron kernel integrations can sit alongside it without + polluting the framework-level apply function. + + Args: + reduction: Must be ``"none"``; Megatron's fused-CE contract returns + per-token loss shaped ``[seq, batch]``. + ignore_index: Target index to ignore. + label_smoothing: Cross-entropy label smoothing factor. + """ + assert reduction == "none", ( + f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + + try: + import megatron.core.fusions.fused_cross_entropy as fce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron requires megatron-core to be installed. " + "Expected symbol path: " + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(fce, "fused_vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. " + "The symbol path may have changed in your Megatron-LM version. Please file an issue " + "on https://github.com/linkedin/Liger-Kernel with your megatron-core version." 
+ ) + + loss_fn = LigerCrossEntropyLoss( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction="none", + ) + fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn) diff --git a/src/liger_kernel/megatron/monkey_patch.py b/src/liger_kernel/megatron/monkey_patch.py new file mode 100644 index 000000000..652a2b274 --- /dev/null +++ b/src/liger_kernel/megatron/monkey_patch.py @@ -0,0 +1,81 @@ +from liger_kernel.megatron.cross_entropy import _patch_fused_vocab_parallel_cross_entropy + + +def _check_tensor_parallel_size_at_patch_time() -> None: + """Raise RuntimeError if Megatron's parallel state already reports TP>1. + + If Megatron is importable but the parallel state is not yet initialized + (for example, ``apply_liger_kernel_to_megatron`` is called before + ``initialize_megatron``), silently defer; per-kernel wrappers check again + at call time against the ``tp_group`` argument Megatron supplies. + """ + try: + from megatron.core import parallel_state + except ImportError: + return + try: + tp_size = parallel_state.get_tensor_model_parallel_world_size() + except (AssertionError, RuntimeError): + return + if tp_size > 1: + raise RuntimeError( + f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, " + f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work." + ) + + +def apply_liger_kernel_to_megatron( + reduction: str = "none", + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> None: + """Replace Megatron-LM's fused_vocab_parallel_cross_entropy with Liger's Triton cross-entropy. + + This monkey-patches + ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy`` + so that Megatron training pipelines use Liger's Triton kernel (online + softmax, in-place gradients, no full-softmax materialization) instead of + Megatron's native fused implementation. 
+ + Args: + reduction: Must be ``"none"``; Megatron's fused-CE contract returns + per-token loss shaped ``[seq, batch]`` and handles reduction itself + downstream. + ignore_index: Target index to ignore. Pass the value used in your + Megatron training config. + label_smoothing: Cross-entropy label smoothing factor. Liger does not + auto-detect this — callers should pass + ``cfg.label_smoothing_factor`` (or equivalent) from their + Megatron ``TransformerConfig`` if label smoothing is enabled, to + preserve the native behavior. + + Scope: + Initial release supports ``tensor_model_parallel_size=1`` only. With + TP>1, each rank holds a vocab-sharded logits slice ``[N, V/tp]`` and + computing cross-entropy requires cross-rank all-reduces that Liger's + kernel does not perform. A ``RuntimeError`` is raised at patch time if + the Megatron parallel state already reports TP>1, and again at call + time if a multi-rank ``tp_group`` is passed. + + Raises: + AssertionError: If ``reduction != "none"``. + ImportError: If ``megatron.core.fusions.fused_cross_entropy`` is not + importable, or if the expected + ``fused_vocab_parallel_cross_entropy`` symbol is missing from that + module (indicating an incompatible Megatron version). + RuntimeError: If tensor model parallelism > 1 is detected. + + Example: + >>> from liger_kernel.megatron import apply_liger_kernel_to_megatron + >>> apply_liger_kernel_to_megatron( + ... ignore_index=-100, + ... label_smoothing=cfg.label_smoothing_factor, + ... 
) + >>> # call before Megatron's forward pass reaches compute_language_model_loss + """ + _check_tensor_parallel_size_at_patch_time() + _patch_fused_vocab_parallel_cross_entropy( + reduction=reduction, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) diff --git a/test/megatron/__init__.py b/test/megatron/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py new file mode 100644 index 000000000..f995b051e --- /dev/null +++ b/test/megatron/test_cross_entropy.py @@ -0,0 +1,157 @@ +"""Correctness tests for the Liger Megatron cross-entropy wrapper. + +These tests exercise ``_build_wrapper`` directly without importing +megatron-core — the wrapper is the [s, b, v] -> [s, b] reshape shim around +``LigerCrossEntropyLoss`` and is meaningful to test on its own. + +The wrapper calls the underlying Triton kernel, so these tests require a +Liger-supported accelerator (same as ``test/transformers/test_cross_entropy.py``). 
+""" + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.megatron.cross_entropy import _build_wrapper +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 + +device = infer_device() +set_seed(42) + + +def _make_wrapper(ignore_index: int = -100, label_smoothing: float = 0.0): + loss_fn = LigerCrossEntropyLoss( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction="none", + ) + return _build_wrapper(loss_fn) + + +def _reference_loss( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + label_smoothing: float, +) -> torch.Tensor: + s, b, v = vocab_parallel_logits.shape + loss_flat = F.cross_entropy( + vocab_parallel_logits.reshape(-1, v).float(), + target.reshape(-1), + reduction="none", + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + return loss_flat.reshape(s, b) + + +@pytest.mark.parametrize( + "s, b, v", + [ + (8, 2, 128), + (16, 4, 4096), + (32, 1, 32000), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_wrapper_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): + wrapper = _make_wrapper() + + logits = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=0.0) + got = wrapper(logits, target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("ignore_index", [-100, 0]) +def test_wrapper_respects_ignore_index(ignore_index): + s, b, v 
= 16, 2, 1024 + wrapper = _make_wrapper(ignore_index=ignore_index) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + target.view(-1)[: (s * b) // 4] = ignore_index + + ref = _reference_loss(logits, target, ignore_index=ignore_index, label_smoothing=0.0) + got = wrapper(logits, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_wrapper_respects_label_smoothing(label_smoothing): + s, b, v = 8, 2, 512 + wrapper = _make_wrapper(label_smoothing=label_smoothing) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=label_smoothing) + got = wrapper(logits, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + + +def test_wrapper_rejects_unknown_kwargs(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + with pytest.raises(TypeError): + wrapper(logits, target, unknown_arg=123) + + +def test_wrapper_rejects_multi_rank_tp_group(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 2 + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + wrapper(logits, target, tp_group=_FakeGroup()) + + +def test_wrapper_accepts_single_rank_tp_group(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 1 + + out = wrapper(logits, target, tp_group=_FakeGroup()) + assert out.shape == (4, 1) + + +def 
test_wrapper_preserves_gradients(): + s, b, v = 8, 2, 256 + wrapper = _make_wrapper() + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32, requires_grad=True) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + loss = wrapper(logits, target).sum() + loss.backward() + + assert logits.grad is not None + assert logits.grad.shape == logits.shape diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py new file mode 100644 index 000000000..642ba403e --- /dev/null +++ b/test/megatron/test_monkey_patch.py @@ -0,0 +1,167 @@ +"""Tests for apply_liger_kernel_to_megatron's patch mechanism. + +Megatron-LM is not a test dependency. We inject stub modules into +``sys.modules`` so the patch function can run entirely on CPU without a real +megatron-core install. Tests verify: + +- the patch replaces ``fused_vocab_parallel_cross_entropy`` on the stub module +- ``reduction != "none"`` is rejected +- TP>1 at patch time raises RuntimeError +- missing megatron-core raises a helpful ImportError +- missing symbol path raises a helpful ImportError +- the constructed LigerCrossEntropyLoss receives the user-supplied kwargs +""" + +import sys +import types + +from unittest.mock import patch + +import pytest + + +def _install_fake_megatron(tp_size: int = 1, with_fused_symbol: bool = True): + """Install stub megatron modules into sys.modules; return the fused module.""" + megatron = types.ModuleType("megatron") + megatron_core = types.ModuleType("megatron.core") + fusions = types.ModuleType("megatron.core.fusions") + fused_ce = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + parallel_state = types.ModuleType("megatron.core.parallel_state") + + if with_fused_symbol: + + def original_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): + raise AssertionError("original megatron kernel called — patch failed") + + fused_ce.fused_vocab_parallel_cross_entropy = 
original_fused_vocab_parallel_cross_entropy + + parallel_state.get_tensor_model_parallel_world_size = lambda: tp_size + + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = megatron_core + sys.modules["megatron.core.fusions"] = fusions + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused_ce + sys.modules["megatron.core.parallel_state"] = parallel_state + + megatron.core = megatron_core + megatron_core.fusions = fusions + megatron_core.parallel_state = parallel_state + fusions.fused_cross_entropy = fused_ce + + return fused_ce + + +def _uninstall_fake_megatron(): + for mod in [ + "megatron.core.parallel_state", + "megatron.core.fusions.fused_cross_entropy", + "megatron.core.fusions", + "megatron.core", + "megatron", + ]: + sys.modules.pop(mod, None) + + +@pytest.fixture +def fake_megatron(): + fused_ce = _install_fake_megatron(tp_size=1) + try: + yield fused_ce + finally: + _uninstall_fake_megatron() + + +def test_patch_replaces_fused_symbol(fake_megatron): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = fake_megatron.fused_vocab_parallel_cross_entropy + apply_liger_kernel_to_megatron() + patched = fake_megatron.fused_vocab_parallel_cross_entropy + + assert patched is not original + assert patched.__name__ == "liger_fused_vocab_parallel_cross_entropy" + + +def test_patch_rejects_non_none_reduction(fake_megatron): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(AssertionError, match="reduction must be 'none'"): + apply_liger_kernel_to_megatron(reduction="mean") + + +def test_patch_raises_on_tp_greater_than_one(): + _install_fake_megatron(tp_size=2) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + apply_liger_kernel_to_megatron() + finally: + _uninstall_fake_megatron() + + +def test_patch_defers_tp_check_when_parallel_state_not_initialized(): + """If 
get_tensor_model_parallel_world_size() raises, patch should still succeed.""" + fused_ce = _install_fake_megatron(tp_size=1) + + def raising_tp_size(): + raise AssertionError("parallel_state not initialized") + + sys.modules["megatron.core.parallel_state"].get_tensor_model_parallel_world_size = raising_tp_size + + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron() + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + finally: + _uninstall_fake_megatron() + + +def test_patch_raises_when_megatron_not_installed(): + _uninstall_fake_megatron() + # Block imports of any "megatron*" module to simulate absent install. + real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def blocking_import(name, *args, **kwargs): + if name == "megatron" or name.startswith("megatron."): + raise ImportError(f"No module named {name!r}") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=blocking_import): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="requires megatron-core"): + apply_liger_kernel_to_megatron() + + +def test_patch_raises_when_fused_symbol_missing(): + _install_fake_megatron(tp_size=1, with_fused_symbol=False) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="symbol path may have changed"): + apply_liger_kernel_to_megatron() + finally: + _uninstall_fake_megatron() + + +def test_patch_forwards_ignore_index_and_label_smoothing(fake_megatron): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + captured = {} + + class FakeLoss: + def __init__(self, ignore_index, label_smoothing, reduction): + captured["ignore_index"] = ignore_index + captured["label_smoothing"] = 
label_smoothing + captured["reduction"] = reduction + + def __call__(self, _input, target): + raise AssertionError("not expected to be called in this test") + + with patch.object(ce_mod, "LigerCrossEntropyLoss", FakeLoss): + apply_liger_kernel_to_megatron(ignore_index=42, label_smoothing=0.25) + + assert captured == {"ignore_index": 42, "label_smoothing": 0.25, "reduction": "none"} From 7097ab9c4f1f17176e1bd493ad400e89312df0c0 Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 00:56:03 +0000 Subject: [PATCH 2/7] Megatron CE: ship LigerMegatronCrossEntropy + patch both fused & unfused paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on the merged-from-main entry point (apply_liger_kernel_to_megatron with rms_norm/cross_entropy flags) so this PR cleanly extends the CE side to match the RMSNorm style PR #1254 established. src/liger_kernel/megatron/cross_entropy.py Strip to just LigerMegatronCrossEntropy — an nn.Module drop-in for Megatron's vocab-parallel CE matching the fused signature (logits, target, tp_group=None). Single source of truth for the CE math; the monkey-patch wrappers in monkey_patch.py instantiate this class. Mirrors LigerMegatronRMSNorm. Reads no megatron-core imports. 3-D shape guard + ValueError for bad reduction. src/liger_kernel/megatron/monkey_patch.py Add _patch_fused_vocab_parallel_cross_entropy AND _patch_vocab_parallel_cross_entropy. Both wired into the cross_entropy=True branch so a single apply call replaces both Megatron CE paths. Both guarded by _PATCH_MARKER for idempotence (matching the RMSNorm helpers). Unfused wrapper uses a sentinel _LABEL_SMOOTHING_UNSET to distinguish "caller didn't pass" from "caller explicitly passed 0.0" — preserves Megatron's contract. src/liger_kernel/megatron/__init__.py Export LigerMegatronCrossEntropy alongside LigerMegatronRMSNorm. test/megatron/test_cross_entropy.py Rewrite tests around the class instead of the old _build_wrapper. 
Same parametrize matrix as test_rms_norm.py. New: 3-D shape guard, reduction rejection sweep, full extra_repr check. test/megatron/test_monkey_patch.py Extend _install_fake_megatron to stub the unfused module too. Add: - test_patch_replaces_unfused_symbol - test_patch_replaces_both_fused_and_unfused_symbols_in_one_call - test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched - test_patch_is_idempotent_for_both_symbols - test_patch_fused_wrapper_passes_tp_group_through - test_patch_raises_when_unfused_symbol_missing - test_unfused_wrapper_honors_runtime_label_smoothing - test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing - test_unfused_wrapper_honors_explicit_zero_label_smoothing examples/megatron/run_mode{1,2}_*.py Mode 1 now patches both RMSNorm and CE in one call with label_smoothing=0.1 and prints both patched CE symbols. Mode 2 adds a _LigerCEGPTModel(GPTModel) subclass that overrides compute_language_model_loss to use LigerMegatronCrossEntropy directly (CE has no spec slot, so subclassing is the symmetric hand-built path). Both run at TP=1 + DP=2 — CE patches are TP=1 only. benchmark/scripts/benchmark_megatron_cross_entropy.py Add "megatron-unfused" provider. Switch CSV output to benchmark/data/all_benchmark_data_megatron.csv via a surgical patch on update_benchmark_data_csv (LIGER_KERNEL_IMPL would have triggered backend selection too). Add optional matplotlib plot generation under benchmark/visualizations/megatron_cross_entropy_*.png. 48/48 tests pass. End-to-end mode-1 and mode-2 verified on 2× H100. Benchmark shows Liger CE 2-5x faster than both Megatron paths and 40% less memory than the fused path at vocab=128K, S=2048, B=4, bf16. 
--- .../data/all_benchmark_data_megatron.csv | 97 +++++ .../benchmark_megatron_cross_entropy.py | 200 +++++++++- examples/megatron/run_mode1_monkey_patch.py | 51 ++- examples/megatron/run_mode2_hand_spec.py | 76 +++- src/liger_kernel/megatron/__init__.py | 17 +- src/liger_kernel/megatron/cross_entropy.py | 129 +++---- src/liger_kernel/megatron/monkey_patch.py | 153 +++++++- test/megatron/test_cross_entropy.py | 115 +++--- test/megatron/test_monkey_patch.py | 348 ++++++++++++++++-- 9 files changed, 1003 insertions(+), 183 deletions(-) create mode 100644 benchmark/data/all_benchmark_data_megatron.csv diff --git a/benchmark/data/all_benchmark_data_megatron.csv b/benchmark/data/all_benchmark_data_megatron.csv new file mode 100644 index 000000000..496f0f904 --- /dev/null +++ b/benchmark/data/all_benchmark_data_megatron.csv @@ -0,0 +1,97 @@ +kernel_name,kernel_provider,kernel_operation_mode,metric_name,metric_unit,x_name,x_label,x_value,y_value_50,y_value_20,y_value_80,extra_benchmark_config_str,gpu_name,timestamp,liger_version +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,4096,0.23187200725078583,0.22867199778556824,0.23430399596691132,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,8192,0.26259198784828186,0.26010880470275877,0.26554239392280576,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,16384,0.35471999645233154,0.3516608059406281,0.3578752100467682,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,32768,0.6460640132427216,0.6421120166778564,0.6500800251960754,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,65536,1.0836479663848877,1.0789376258850099,1.0899327754974366,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,131072,2.107487916946411,2.10230393409729,2.1125311851501465,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,4096,0.22684800624847412,0.22604799270629883,0.22790400683879852,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,8192,0.48054400086402893,0.4792320132255554,0.48161279559135434,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,16384,0.8116160035133362,0.8104000091552734,0.8132799863815308,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,32768,2.1355679035186768,2.1311680316925052,2.1407871723175047,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,65536,4.391728162765503,4.389222240447998,4.394495964050293,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,131072,8.81164836883545,8.805439949035645,8.816160202026367,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,4096,0.3126560002565384,0.312063992023468,0.3132735908031463,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,8192,0.6373920142650604,0.6365439891815186,0.6382399797439575,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,16384,1.3829439878463745,1.3810559511184692,1.3850752115249634,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 
00:54:22,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,32768,2.829360008239746,2.8282368659973143,2.831148767471314,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,65536,5.815711975097656,5.807616233825684,5.824639797210693,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,131072,11.804255962371826,11.790649795532227,11.81064338684082,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,4096,0.5651199817657471,0.5638527750968934,0.5661119818687439,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,8192,1.095136046409607,1.0940415859222412,1.0963200330734253,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,16384,2.069024085998535,2.0674240589141846,2.0703680515289307,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,32768,4.026576042175293,4.025190353393555,4.027974319458008,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,65536,7.939584016799927,7.937094402313232,7.94121618270874,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,131072,15.816815853118896,15.814399719238281,15.81926441192627,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,4096,0.3551519960165024,0.3523648023605347,0.3593983888626099,"{""S"": 2048, ""B"": 
4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,8192,0.466608002781868,0.46348801255226135,0.470521605014801,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,16384,0.7164639830589294,0.7130944013595581,0.7210240125656128,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,32768,1.326367974281311,1.3212159872055054,1.3306879997253418,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,65536,2.4017759561538696,2.392825651168823,2.412748908996582,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,131072,4.6971518993377686,4.690547180175781,4.702419185638428,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,4096,0.4853439927101135,0.48447999358177185,0.4862591981887817,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,8192,0.975488007068634,0.9742976188659669,0.9764159917831421,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,16384,1.8761919736862183,1.8743679523468018,1.8787519931793213,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,32768,4.398000001907349,4.3941184997558596,4.408435344696045,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,65536,8.924351692199707,8.919648170471191,8.92563247680664,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,131072,17.839040756225586,17.836934661865236,17.843289947509767,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,4096,0.5374079942703247,0.5366335868835449,0.5382080078125,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,8192,1.0536960363388062,1.0528128385543822,1.0545535564422608,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,16384,2.1704800128936768,2.1692031383514405,2.172160005569458,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,32768,4.446431875228882,4.445299243927002,4.4482752799987795,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,65536,9.035776138305664,9.031341171264648,9.042457962036133,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,131072,17.946975708007812,17.945389556884766,17.95260887145996,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,4096,0.7788800001144409,0.7774208068847656,0.780454421043396,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,8192,1.4757919907569885,1.4740799903869628,1.4768768072128298,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,16384,2.786736011505127,2.7851327419281007,2.7889408111572265,"{""S"": 2048, ""B"": 
4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,32768,5.420752048492432,5.417894268035889,5.422150421142579,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,65536,10.675616264343262,10.674079704284669,10.679340934753418,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,131072,21.251296043395996,21.248454666137697,21.253714752197265,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,4096,0.4208959937095642,0.41683200001716614,0.4256959855556488,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,8192,0.5997599959373474,0.5965440273284912,0.6039680242538452,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,16384,0.9797760248184204,0.9747712016105652,0.9856639981269837,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,32768,1.8492159843444824,1.8442879915237427,1.85481595993042,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,65536,3.455839991569519,3.4454591274261475,3.471328020095825,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,131072,6.783936023712158,6.772691154479981,6.791551971435547,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,4096,0.5539519786834717,0.5527359843254089,0.5549759864807129,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,8192,1.106719970703125,1.1052928209304809,1.1081023931503298,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,16384,2.137727975845337,2.1360192775726317,2.139878463745117,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,32768,4.921311855316162,4.918290996551514,4.929247951507568,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,65536,9.962112426757812,9.959987258911132,9.963628578186036,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,131072,19.926719665527344,19.91397705078125,19.931795120239258,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,4096,0.6056640148162842,0.6047423720359802,0.606719970703125,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,8192,1.1874719858169556,1.1862783908843995,1.188646388053894,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,16384,2.434976100921631,2.4332736015319822,2.4367167949676514,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,32768,4.970719814300537,4.968512058258057,4.9720383644104,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,65536,10.077088356018066,10.070701026916502,10.083423614501953,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 
+megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,131072,20.020895957946777,20.018156433105467,20.022406387329102,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,4096,0.8470079898834229,0.8454336166381836,0.8486016154289245,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,8192,1.6092480421066284,1.6075520515441895,1.6106047868728637,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,16384,3.04915189743042,3.0480000972747803,3.0506880283355713,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,32768,5.9349119663238525,5.934432029724121,5.936927795410156,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,65536,11.71017599105835,11.70949764251709,11.711187362670898,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,131072,23.321871757507324,23.319308853149415,23.324396514892577,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,4096,192.0791015625,192.0791015625,192.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,8192,384.0791015625,384.0791015625,384.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,16384,768.0791015625,768.0791015625,768.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 
+megatron_cross_entropy,liger,full,memory,MB,V,vocab size,32768,1536.0791015625,1536.0791015625,1536.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,65536,3072.0791015625,3072.0791015625,3072.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,131072,6144.0791015625,6144.0791015625,6144.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,4096,512.0947265625,512.0947265625,512.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,8192,1024.0947265625,1024.0947265625,1024.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,16384,2048.0947265625,2048.0947265625,2048.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,32768,4096.0947265625,4096.0947265625,4096.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,65536,8192.0947265625,8192.0947265625,8192.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,131072,16384.09375,16384.09375,16384.09375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,4096,448.1650390625,448.1650390625,448.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,8192,896.1650390625,896.1650390625,896.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,16384,1792.1650390625,1792.1650390625,1792.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,32768,3584.1650390625,3584.1650390625,3584.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,65536,7168.1650390625,7168.1650390625,7168.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,131072,14336.1650390625,14336.1650390625,14336.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,4096,320.1650390625,320.1650390625,320.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,8192,640.1650390625,640.1650390625,640.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1650390625,1280.1650390625,1280.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 diff --git 
a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py index 6eca12edc..afa502c5f 100644 --- a/benchmark/scripts/benchmark_megatron_cross_entropy.py +++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py @@ -1,18 +1,36 @@ """Benchmark Liger's Megatron-LM cross-entropy wrapper. -Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's -native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is -installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a -third provider to reproduce end-to-end comparisons. +Compares four providers on the per-token CE call shape ``[seq, batch, vocab]``: + + - **torch**: vanilla ``F.cross_entropy`` + - **megatron**: Megatron's *fused* ``fused_vocab_parallel_cross_entropy`` path + (``cross_entropy_loss_fusion=True``, JIT-fused via TorchScript) + - **megatron-unfused**: Megatron's *unfused* ``vocab_parallel_cross_entropy`` + path (``cross_entropy_loss_fusion=False``, eager Python; the path users on + ``label_smoothing`` typically end up on) + - **liger**: ``LigerMegatronCrossEntropy`` — Liger's Triton CE wrapped in the + Megatron fused signature. Same kernel regardless of which Megatron symbol + it was patched onto, so we only benchmark it once. Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not -installed, the "megatron" provider is silently skipped. +installed, both megatron providers are silently skipped. + +Output: + - CSV: ``benchmark/data/all_benchmark_data_megatron.csv`` + (separate per-component file, mirroring the recent ``all_benchmark_data_cutile.csv`` + precedent — keeps the PR diff scannable) + - Plots (best-effort): ``benchmark/visualizations/megatron_cross_entropy_*.png`` + rendered when matplotlib is available; skipped silently otherwise. 
""" +import functools +import os + import torch import torch.nn.functional as F import triton +import utils as benchmark_utils from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -20,20 +38,50 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -from liger_kernel.megatron.cross_entropy import _build_wrapper -from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +_CSV_FILENAME = "all_benchmark_data_megatron.csv" + +# Redirect CSV output to a per-component file (parallel to all_benchmark_data_cutile.csv). +# We can't reuse the LIGER_KERNEL_IMPL knob because it also drives kernel-backend +# selection in liger_kernel.ops — overloading it would force us onto a megatron-named +# backend that doesn't exist. Patching update_benchmark_data_csv is surgical and only +# affects this benchmark process. +_original_update_benchmark_data_csv = benchmark_utils.update_benchmark_data_csv + + +@functools.wraps(_original_update_benchmark_data_csv) +def _patched_update_benchmark_data_csv(*args, **kwargs): + kwargs["filename"] = _CSV_FILENAME + return _original_update_benchmark_data_csv(*args, **kwargs) + + +benchmark_utils.update_benchmark_data_csv = _patched_update_benchmark_data_csv + +from liger_kernel.megatron import LigerMegatronCrossEntropy from liger_kernel.utils import infer_device device = infer_device() try: from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy + from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy _MEGATRON_AVAILABLE = True except ImportError: fused_vocab_parallel_cross_entropy = None + vocab_parallel_cross_entropy = None _MEGATRON_AVAILABLE = False +try: + import matplotlib + + matplotlib.use("Agg") # headless + import matplotlib.pyplot as plt # noqa: E402 + + _HAVE_MATPLOTLIB = True +except ImportError: + plt = None + _HAVE_MATPLOTLIB = False + def _make_inputs(s: int, b: int, v: 
int, requires_grad: bool = True): logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad) @@ -53,11 +101,9 @@ def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch. def _ensure_single_rank_tp_group(): """Initialize torch.distributed (single-rank) and return a usable TP group. - For a single-process benchmark we use the world group of - size 1, where the internal all-reduce becomes a no-op. + For a single-process benchmark we use the world group of size 1; the + internal all-reduce becomes a no-op. """ - import os - import torch.distributed as dist if not dist.is_initialized(): @@ -72,8 +118,8 @@ def _ensure_single_rank_tp_group(): def _select_fwd(provider: str): if provider == "liger": - wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none")) - return wrapper + ce = LigerMegatronCrossEntropy(reduction="none") + return lambda logits, target: ce(logits, target) if provider == "torch": return _pytorch_cross_entropy if provider == "megatron": @@ -81,10 +127,22 @@ def _select_fwd(provider: str): raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider") tp_group = _ensure_single_rank_tp_group() - def _megatron_call(logits, target): + def _megatron_fused_call(logits, target): return fused_vocab_parallel_cross_entropy(logits, target, tp_group) - return _megatron_call + return _megatron_fused_call + if provider == "megatron-unfused": + if not _MEGATRON_AVAILABLE: + raise RuntimeError( + "megatron-core not installed; cannot benchmark 'megatron-unfused' provider" + ) + tp_group = _ensure_single_rank_tp_group() + + def _megatron_unfused_call(logits, target): + # Unfused signature: (logits, target, label_smoothing=0.0, tp_group=None) + return vocab_parallel_cross_entropy(logits, target, 0.0, tp_group) + + return _megatron_unfused_call raise ValueError(f"unknown provider: {provider!r}") @@ -146,12 +204,122 @@ def full(): return SingleBenchmarkRunOutput(y_20=mem_20, 
y_50=mem_50, y_80=mem_80) +# --------------------------------------------------------------------------- +# Plot generation (best-effort). +# --------------------------------------------------------------------------- + + +def _generate_plots(out_dir: str) -> None: + """Generate one PNG per (metric, mode) combination from the CSV we just wrote. + + Silently skipped when matplotlib is unavailable. Reads the CSV rather than + re-running benchmarks so the plots use the same numbers that landed on disk. + """ + if not _HAVE_MATPLOTLIB: + print("[plots] matplotlib not available; skipping plot generation.") + return + + import csv + from pathlib import Path + + csv_path = Path(os.path.join(os.path.dirname(__file__), "..", "data", "all_benchmark_data_megatron.csv")) + if not csv_path.exists(): + print(f"[plots] CSV not found at {csv_path}; skipping plot generation.") + return + + out_dir_path = Path(out_dir) + out_dir_path.mkdir(parents=True, exist_ok=True) + + # CSV layout is denormalized: one row per (provider, mode, metric, x_value). + # We need to aggregate rows back into (provider → x_values, y50_list, y20_list, y80_list) + # series before plotting. 
+ rows = [] + with open(csv_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + if row.get("kernel_name") == "megatron_cross_entropy": + rows.append(row) + + if not rows: + print("[plots] no megatron_cross_entropy rows in the CSV; skipping plots.") + return + + series: dict = {} # (metric_name, mode, provider) → {xs, y50, y20, y80, meta} + for row in rows: + key = (row["metric_name"], row["kernel_operation_mode"], row["kernel_provider"]) + entry = series.setdefault( + key, + { + "xs": [], "y50": [], "y20": [], "y80": [], + "x_label": row.get("x_label", "x"), + "metric_unit": row.get("metric_unit", ""), + }, + ) + try: + entry["xs"].append(float(row["x_value"])) + entry["y50"].append(float(row["y_value_50"])) + entry["y20"].append(float(row["y_value_20"])) + entry["y80"].append(float(row["y_value_80"])) + except (KeyError, ValueError): + # If the CSV schema differs from what we expect, surface it loudly but skip + # plotting rather than crashing the whole benchmark. + print(f"[plots] WARNING: skipping malformed row: {row}") + continue + + if not series: + print("[plots] no usable series in the CSV; skipping plots.") + return + + # Group series by (metric_name, mode) so each plot collects all providers for that slice. + plots: dict = {} + for (metric_name, mode, provider), entry in series.items(): + plots.setdefault((metric_name, mode), []).append((provider, entry)) + + plot_paths = [] + for (metric_name, mode), provider_entries in plots.items(): + mode_label = mode if mode not in (None, "", "None") else "full" + fig, ax = plt.subplots(figsize=(8, 5)) + x_label = "x" + metric_unit = "" + for provider, entry in sorted(provider_entries, key=lambda pe: pe[0]): + # Sort points by x so the line plot is monotone. 
+ order = sorted(range(len(entry["xs"])), key=lambda i: entry["xs"][i]) + xs = [entry["xs"][i] for i in order] + y50 = [entry["y50"][i] for i in order] + y20 = [entry["y20"][i] for i in order] + y80 = [entry["y80"][i] for i in order] + # Capture the line's color so fill_between's band uses the same hue across + # matplotlib versions that don't share auto-cycles between the two calls. + (line,) = ax.plot(xs, y50, marker="o", label=provider) + ax.fill_between(xs, y20, y80, alpha=0.2, color=line.get_color()) + x_label = entry["x_label"] + metric_unit = entry["metric_unit"] + + ax.set_xscale("log", base=2) + ax.set_xlabel(x_label) + ax.set_ylabel(f"{metric_name} ({metric_unit})") + ax.set_title(f"Megatron CE — {metric_name}, mode={mode_label}") + ax.legend() + ax.grid(True, alpha=0.3) + + out_path = out_dir_path / f"megatron_cross_entropy_{metric_name}_{mode_label}.png" + fig.tight_layout() + fig.savefig(out_path, dpi=120) + plt.close(fig) + plot_paths.append(str(out_path)) + + print(f"[plots] wrote {len(plot_paths)} plot(s):") + for p in plot_paths: + print(f" - {p}") + + if __name__ == "__main__": args = parse_benchmark_script_args() providers = ["liger", "torch"] if _MEGATRON_AVAILABLE: providers.append("megatron") + providers.append("megatron-unfused") common_configs = { "kernel_name": "megatron_cross_entropy", @@ -177,3 +345,5 @@ def full(): metric_unit="MB", **common_configs, ) + + _generate_plots(out_dir=os.path.join(os.path.dirname(__file__), "..", "visualizations")) diff --git a/examples/megatron/run_mode1_monkey_patch.py b/examples/megatron/run_mode1_monkey_patch.py index 5daff20fa..449e2b042 100644 --- a/examples/megatron/run_mode1_monkey_patch.py +++ b/examples/megatron/run_mode1_monkey_patch.py @@ -1,20 +1,25 @@ -"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm. +"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm + cross-entropy. Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The relevant additions (vs. 
that file) are: - 1. ``apply_liger_kernel_to_megatron(rms_norm=True)`` called once at the - top of ``model_provider()``. This patches both - ``LocalSpecProvider.layer_norm`` (per-layer norm slots) and - ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``). + 1. ``apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True, + label_smoothing=0.1)`` called once at the top of ``model_provider()``. + This patches: + - ``LocalSpecProvider.layer_norm`` (per-layer norm slots) + - ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``) + - ``fused_cross_entropy.fused_vocab_parallel_cross_entropy`` + (the fused CE path) + - ``tensor_parallel.cross_entropy.vocab_parallel_cross_entropy`` + (the unfused CE path) 2. ``normalization="RMSNorm"`` added to ``TransformerConfig`` so the model actually has RMSNorm slots to patch (Megatron defaults to ``LayerNorm``). - 3. ``_print_norm_classes`` after model construction — prints the - resolved class for every norm slot so you can verify Liger took - over. + 3. ``_print_norm_classes`` + ``_print_ce_symbols`` after model construction + — print the resolved class/function bindings so you can verify Liger + took over for every slot. 
Run with: torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=29500 \\ @@ -70,7 +75,11 @@ def initialize_distributed(tp: int = 2, pp: int = 1) -> None: def model_provider() -> GPTModel: # ↓↓ Mode 1 — patch once, everything below picks up Liger ↓↓ - apply_liger_kernel_to_megatron(rms_norm=True) + apply_liger_kernel_to_megatron( + rms_norm=True, + cross_entropy=True, + label_smoothing=0.1, + ) # ↑↑ ------------------------------------------------------ ↑↑ cfg = TransformerConfig( @@ -84,7 +93,7 @@ def model_provider() -> GPTModel: return GPTModel( config=cfg, transformer_layer_spec=get_gpt_layer_local_spec(normalization="RMSNorm"), - vocab_size=100, + vocab_size=128, max_sequence_length=_SEQUENCE_LENGTH, ) @@ -142,8 +151,27 @@ def _print_norm_classes(model: torch.nn.Module) -> None: print() +def _print_ce_symbols() -> None: + """Show the current bindings of Megatron's two CE entry points.""" + import megatron.core.fusions.fused_cross_entropy as fused + import megatron.core.tensor_parallel.cross_entropy as unfused + + print("\n=== Resolved CE symbols ===") + print( + f" fused.fused_vocab_parallel_cross_entropy → " + f"{fused.fused_vocab_parallel_cross_entropy.__name__}" + ) + print( + f" unfused.vocab_parallel_cross_entropy → " + f"{unfused.vocab_parallel_cross_entropy.__name__}" + ) + print() + + def main() -> None: - initialize_distributed(tp=2, pp=1) + # TP=1, DP=2 — CE patch (TP=1 only). Norms are correct under any TP value, so + # demonstrating both Liger features in one script means running data-parallel. 
+ initialize_distributed(tp=1, pp=1) model_parallel_cuda_manual_seed(123) torch.manual_seed(123) @@ -153,6 +181,7 @@ def main() -> None: print("\n=== Full model tree (mode 1: monkey-patch) ===") print(gpt_model) _print_norm_classes(gpt_model) + _print_ce_symbols() ddp_cfg = DistributedDataParallelConfig( grad_reduce_in_fp32=False, diff --git a/examples/megatron/run_mode2_hand_spec.py b/examples/megatron/run_mode2_hand_spec.py index 1e17a6268..c4d66ef9f 100644 --- a/examples/megatron/run_mode2_hand_spec.py +++ b/examples/megatron/run_mode2_hand_spec.py @@ -1,22 +1,28 @@ -"""Mode 2 — hand-assembled TransformerBlockSubmodules using LigerMegatronRMSNorm. +"""Mode 2 — hand-assembled spec + GPTModel subclass using Liger directly. Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The relevant additions (vs. that file) are: - 1. Direct import of ``LigerMegatronRMSNorm`` (no monkey-patch). + 1. Direct imports of ``LigerMegatronRMSNorm`` and ``LigerMegatronCrossEntropy`` + (no monkey-patch). - 2. ``model_provider()`` assembles a ``TransformerBlockSubmodules`` by - hand, placing ``LigerMegatronRMSNorm`` into every norm slot: + 2. ``model_provider()`` assembles a ``TransformerBlockSubmodules`` by hand, + placing ``LigerMegatronRMSNorm`` into every norm slot: - per-layer ``input_layernorm`` and ``pre_mlp_layernorm`` - the block-level ``layer_norm`` field that backs ``decoder.final_layernorm`` - This is the slot-level integration path — verbose but maximally - explicit. It is the only way to control ``final_layernorm`` from - user code without monkey-patching. + This is the slot-level integration path — verbose but maximally explicit. + It is the only way to control ``final_layernorm`` from user code without + monkey-patching. - 3. ``_print_norm_classes`` after model construction — prints the - resolved class for every norm slot so you can verify Liger took - over. + 3. 
``_LigerCEGPTModel(GPTModel)`` overrides + ``LanguageModule.compute_language_model_loss`` to route the loss through + a ``LigerMegatronCrossEntropy`` instance. Cross-entropy has no spec slot + in Megatron, so subclassing is the symmetric "hand-built" path. + + 4. ``_print_norm_classes`` + ``_print_ce_class`` after model construction + — print the resolved class for every norm slot AND the resolved CE + class on the model so you can verify Liger took over. Run with: torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=29500 \\ @@ -65,6 +71,7 @@ from torch.utils.data import DataLoader # --- Liger integration: Mode 2 --------------------------------------------- +from liger_kernel.megatron import LigerMegatronCrossEntropy from liger_kernel.megatron import LigerMegatronRMSNorm # --------------------------------------------------------------------------- @@ -72,6 +79,31 @@ _SEQUENCE_LENGTH = 64 _NUM_ITERS = 5 _NUM_LAYERS = 2 +_LABEL_SMOOTHING = 0.1 + + +class _LigerCEGPTModel(GPTModel): + """``GPTModel`` subclass that routes its loss through ``LigerMegatronCrossEntropy``. + + Megatron's CE is not a spec slot — ``LanguageModule.compute_language_model_loss`` + calls ``fused_vocab_parallel_cross_entropy`` directly. The symmetric "hand-built" + integration is therefore to subclass ``GPTModel`` and override that method. + """ + + def __init__(self, *args, liger_ce_label_smoothing: float = 0.0, **kwargs): + super().__init__(*args, **kwargs) + self.liger_ce = LigerMegatronCrossEntropy( + ignore_index=-100, + label_smoothing=liger_ce_label_smoothing, + reduction="none", + ) + + def compute_language_model_loss(self, labels, logits): + # LanguageModule contract: input labels are [b, s], output loss is [b, s]. + # LigerMegatronCrossEntropy matches the fused signature, which expects [s, b]. 
+ labels_sb = labels.transpose(0, 1).contiguous() # [s, b] + loss_sb = self.liger_ce(logits, labels_sb, self.pg_collection.tp) # [s, b] + return loss_sb.transpose(0, 1).contiguous() # [b, s] def initialize_distributed(tp: int = 2, pp: int = 1) -> None: @@ -133,11 +165,12 @@ def model_provider() -> GPTModel: ) # ↑↑ ----------------------------------------------------------------- ↑↑ - return GPTModel( + return _LigerCEGPTModel( config=cfg, transformer_layer_spec=block_spec, - vocab_size=100, + vocab_size=128, max_sequence_length=_SEQUENCE_LENGTH, + liger_ce_label_smoothing=_LABEL_SMOOTHING, ) @@ -194,8 +227,24 @@ def _print_norm_classes(model: torch.nn.Module) -> None: print() +def _print_ce_class(model: torch.nn.Module) -> None: + """Show that ``compute_language_model_loss`` will route through Liger.""" + ce = getattr(model, "liger_ce", None) + print("=== Resolved CE class ===") + if ce is None: + print(" model.liger_ce → (not set; subclass missing)") + else: + print(f" model.liger_ce → " + f"{type(ce).__module__}.{type(ce).__name__}") + print(f" ce.label_smoothing → {ce.label_smoothing}") + print(f" ce.ignore_index → {ce.ignore_index}") + print() + + def main() -> None: - initialize_distributed(tp=2, pp=1) + # TP=1, DP=2 — CE patch (TP=1 only). Norms are correct under any TP value, so + # demonstrating both Liger features in one script means running data-parallel. 
+ initialize_distributed(tp=1, pp=1) model_parallel_cuda_manual_seed(123) torch.manual_seed(123) @@ -205,6 +254,7 @@ def main() -> None: print("\n=== Full model tree (mode 2: hand-built spec) ===") print(gpt_model) _print_norm_classes(gpt_model) + _print_ce_class(gpt_model) ddp_cfg = DistributedDataParallelConfig( grad_reduce_in_fp32=False, diff --git a/src/liger_kernel/megatron/__init__.py b/src/liger_kernel/megatron/__init__.py index 27441de0d..beb9b00a3 100644 --- a/src/liger_kernel/megatron/__init__.py +++ b/src/liger_kernel/megatron/__init__.py @@ -3,13 +3,20 @@ Public API: LigerMegatronRMSNorm — RMSNorm module conforming to Megatron-Core's LayerNormBuilder protocol. - apply_liger_kernel_to_megatron — patches Megatron-Core so existing - training scripts pick up Liger kernels with one line. Currently - supports RMSNorm (via BackendSpecProvider) and fused vocab-parallel - cross-entropy. + LigerMegatronCrossEntropy — nn.Module drop-in for Megatron's vocab-parallel + cross-entropy (fused signature). + apply_liger_kernel_to_megatron — patches Megatron-Core so existing training + scripts pick up Liger kernels with one line. Currently supports + RMSNorm (via BackendSpecProvider) plus both the fused and unfused + vocab-parallel cross-entropy paths. 
""" +from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy from liger_kernel.megatron.monkey_patch import apply_liger_kernel_to_megatron from liger_kernel.megatron.rms_norm import LigerMegatronRMSNorm -__all__ = ["LigerMegatronRMSNorm", "apply_liger_kernel_to_megatron"] +__all__ = [ + "LigerMegatronCrossEntropy", + "LigerMegatronRMSNorm", + "apply_liger_kernel_to_megatron", +] diff --git a/src/liger_kernel/megatron/cross_entropy.py b/src/liger_kernel/megatron/cross_entropy.py index a4a1ec12f..4ed11e69b 100644 --- a/src/liger_kernel/megatron/cross_entropy.py +++ b/src/liger_kernel/megatron/cross_entropy.py @@ -1,81 +1,84 @@ +"""Megatron-Core compatible cross-entropy backed by Liger's Triton kernel.""" + +from __future__ import annotations + import torch -from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import liger_cross_entropy + +class LigerMegatronCrossEntropy(torch.nn.Module): + """``nn.Module`` drop-in for Megatron's vocab-parallel cross-entropy. -def _build_wrapper(loss_fn: LigerCrossEntropyLoss): - """Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``. + Conforms to ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``'s + signature, ``(vocab_parallel_logits, target, tp_group=None)``. Public Mode-2 (hand-built) + API: instantiate once with the per-training-run config, then call from your overridden + ``LanguageModule.compute_language_model_loss`` (or wherever Megatron's CE would live in your + custom model). + + Mirrors the ``LigerMegatronRMSNorm`` pattern shipped in PR #1254: config-time kwargs on + ``__init__``, data-only ``forward``. Single source of truth for the underlying Liger call; + the monkey-patch wrappers in ``monkey_patch.py`` instantiate this class. + + Args: + ignore_index: Target index to ignore. + label_smoothing: Cross-entropy label smoothing factor. 
+ reduction: Must be ``"none"`` — Megatron's vocab-parallel CE contract returns per-token + loss shaped ``[seq, batch]`` and handles reduction itself downstream. - The returned callable has exactly the same parameter list Megatron expects - (``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs - will raise ``TypeError`` naturally — this is intentional: if a future - Megatron release adds new parameters to the fused-CE contract, we want to - fail loudly rather than silently drop them. + Scope: + TP=1 only. Vocab-parallel cross-entropy (TP>1) requires cross-rank reductions + that Liger's kernel does not perform; tracked as Phase 1.5 follow-up. Raises + ``RuntimeError`` at call time if a multi-rank ``tp_group`` is supplied. """ - def liger_fused_vocab_parallel_cross_entropy( + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + reduction: str = "none", + ): + super().__init__() + if reduction != "none": + raise ValueError( + f"Megatron's vocab-parallel CE contract requires per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward( + self, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, tp_group=None, ) -> torch.Tensor: if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1: raise RuntimeError( - f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, " - f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as " - f"follow-up work." + f"LigerMegatronCrossEntropy requires tensor_model_parallel_size=1, " + f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is " + f"tracked as follow-up work." + ) + if vocab_parallel_logits.dim() != 3: + raise ValueError( + f"vocab_parallel_logits must be 3-D ([seq, batch, vocab]); " + f"got shape {tuple(vocab_parallel_logits.shape)}. 
(HuggingFace's " + f"[batch, seq, vocab] callers must transpose before calling.)" ) - s, b, v = vocab_parallel_logits.shape - logits_2d = vocab_parallel_logits.reshape(-1, v) - target_1d = target.reshape(-1) - loss = loss_fn(logits_2d, target_1d) + loss = liger_cross_entropy( + vocab_parallel_logits.reshape(-1, v), + target.reshape(-1), + ignore_index=self.ignore_index, + label_smoothing=self.label_smoothing, + reduction=self.reduction, + ) return loss.reshape(s, b) - return liger_fused_vocab_parallel_cross_entropy - - -def _patch_fused_vocab_parallel_cross_entropy( - reduction: str = "none", - ignore_index: int = -100, - label_smoothing: float = 0.0, -) -> None: - """Replace ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``. - - See ``apply_liger_kernel_to_megatron`` in ``monkey_patch.py`` for the public - entry point; this helper holds the cross-entropy-specific patching logic so - that future Megatron kernel integrations can sit alongside it without - polluting the framework-level apply function. - - Args: - reduction: Must be ``"none"``; Megatron's fused-CE contract returns - per-token loss shaped ``[seq, batch]``. - ignore_index: Target index to ignore. - label_smoothing: Cross-entropy label smoothing factor. - """ - assert reduction == "none", ( - f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; " - f"reduction must be 'none', got {reduction!r}." - ) - - try: - import megatron.core.fusions.fused_cross_entropy as fce - except ImportError as exc: - raise ImportError( - "apply_liger_kernel_to_megatron requires megatron-core to be installed. " - "Expected symbol path: " - "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy." - ) from exc - - if not hasattr(fce, "fused_vocab_parallel_cross_entropy"): - raise ImportError( - "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. " - "The symbol path may have changed in your Megatron-LM version. 
Please file an issue " - "on https://github.com/linkedin/Liger-Kernel with your megatron-core version." + def extra_repr(self) -> str: + return ( + f"ignore_index={self.ignore_index}, " + f"label_smoothing={self.label_smoothing}, " + f"reduction={self.reduction!r}" ) - - loss_fn = LigerCrossEntropyLoss( - ignore_index=ignore_index, - label_smoothing=label_smoothing, - reduction="none", - ) - fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn) diff --git a/src/liger_kernel/megatron/monkey_patch.py b/src/liger_kernel/megatron/monkey_patch.py index 7d927d5f1..9c1c42438 100644 --- a/src/liger_kernel/megatron/monkey_patch.py +++ b/src/liger_kernel/megatron/monkey_patch.py @@ -4,8 +4,6 @@ import logging -from liger_kernel.megatron.cross_entropy import _patch_fused_vocab_parallel_cross_entropy - logger = logging.getLogger(__name__) _PATCH_MARKER = "__liger_patched__" @@ -73,6 +71,11 @@ def apply_liger_kernel_to_megatron( ignore_index=ignore_index, label_smoothing=label_smoothing, ) + _patch_vocab_parallel_cross_entropy( + reduction=reduction, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) def _check_tensor_parallel_size_at_patch_time() -> None: @@ -181,3 +184,149 @@ def __new__(cls, config, hidden_size, eps=1e-5, **kwargs): "Patched megatron.core.transformer.transformer_block.LayerNormImpl " "to route RMSNorm configs through LigerMegatronRMSNorm." ) + + +# Sentinel for "caller did not pass this kwarg". A plain ``0.0`` default would be +# observationally indistinguishable from "user explicitly asked for 0.0" — and Megatron's +# native vocab_parallel_cross_entropy accepts call-time ``label_smoothing=0.0`` as a real +# request for that value. We must not silently override. 
+_LABEL_SMOOTHING_UNSET = object() + + +def _patch_fused_vocab_parallel_cross_entropy( + reduction: str = "none", + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> None: + """Replace ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``. + + Wraps a single ``LigerMegatronCrossEntropy`` instance (configured at patch time) in + a closure matching Megatron's fused-CE signature ``(logits, target, tp_group)``. + Idempotent: a sentinel attribute on the replacement prevents wrappers from stacking. + """ + if reduction != "none": + raise ValueError( + f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + + try: + import megatron.core.fusions.fused_cross_entropy as fused_ce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron(cross_entropy=True) requires megatron-core to be " + "installed. Expected symbol path: " + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(fused_ce, "fused_vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not " + "found. The symbol path may have changed in your Megatron-LM version. Please file " + "an issue on https://github.com/linkedin/Liger-Kernel with your megatron-core version." 
+ ) + + if getattr(fused_ce.fused_vocab_parallel_cross_entropy, _PATCH_MARKER, False): + return # already patched + + original = fused_ce.fused_vocab_parallel_cross_entropy + + from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy + + ce = LigerMegatronCrossEntropy( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + ) + + def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): + return ce(vocab_parallel_logits, target, tp_group=tp_group) + + setattr(liger_fused_vocab_parallel_cross_entropy, _PATCH_MARKER, True) + setattr(liger_fused_vocab_parallel_cross_entropy, "__wrapped__", original) + fused_ce.fused_vocab_parallel_cross_entropy = liger_fused_vocab_parallel_cross_entropy + + logger.info( + "Patched megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy " + "with Liger cross-entropy." + ) + + +def _patch_vocab_parallel_cross_entropy( + reduction: str = "none", + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> None: + """Replace ``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``. + + This is Megatron's *unfused* eager-Python vocab-parallel CE path, dispatched to when + ``config.cross_entropy_loss_fusion=False``. Its signature accepts ``label_smoothing`` + at call time, so the wrapper honors a runtime value when the caller actually passed + one. A sentinel disambiguates "caller passed 0.0" (use 0.0) from "caller didn't pass" + (use patch-time default). + """ + if reduction != "none": + raise ValueError( + f"vocab_parallel_cross_entropy returns per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + + try: + import megatron.core.tensor_parallel.cross_entropy as unfused_ce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron(cross_entropy=True) requires megatron-core to be " + "installed. 
Expected symbol path: " + "megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(unfused_ce, "vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy not " + "found. The symbol path may have changed in your Megatron-LM version. Please file " + "an issue on https://github.com/linkedin/Liger-Kernel with your megatron-core version." + ) + + if getattr(unfused_ce.vocab_parallel_cross_entropy, _PATCH_MARKER, False): + return # already patched + + original = unfused_ce.vocab_parallel_cross_entropy + + from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy + + # Patch-time default; reused for every call where the caller doesn't pass label_smoothing. + # Avoids allocating a fresh module per CE call in the common case. + default_ce = LigerMegatronCrossEntropy( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + ) + + def liger_vocab_parallel_cross_entropy( + vocab_parallel_logits, + target, + label_smoothing=_LABEL_SMOOTHING_UNSET, + tp_group=None, + ): + # Sentinel-based "did the caller pass this?" check so that an explicit + # label_smoothing=0.0 from the caller is honored verbatim (matching Megatron's + # native vocab_parallel_cross_entropy contract). Construct a fresh + # LigerMegatronCrossEntropy only on the runtime-override path; nn.Module + # construction is microseconds vs. CE-kernel milliseconds. 
+ if label_smoothing is _LABEL_SMOOTHING_UNSET: + return default_ce(vocab_parallel_logits, target, tp_group=tp_group) + ce = LigerMegatronCrossEntropy( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + ) + return ce(vocab_parallel_logits, target, tp_group=tp_group) + + setattr(liger_vocab_parallel_cross_entropy, _PATCH_MARKER, True) + setattr(liger_vocab_parallel_cross_entropy, "__wrapped__", original) + unfused_ce.vocab_parallel_cross_entropy = liger_vocab_parallel_cross_entropy + + logger.info( + "Patched megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy " + "with Liger cross-entropy." + ) diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py index f995b051e..534a9a57a 100644 --- a/test/megatron/test_cross_entropy.py +++ b/test/megatron/test_cross_entropy.py @@ -1,19 +1,19 @@ -"""Correctness tests for the Liger Megatron cross-entropy wrapper. +"""Correctness tests for ``LigerMegatronCrossEntropy``. -These tests exercise ``_build_wrapper`` directly without importing -megatron-core — the wrapper is the [s, b, v] -> [s, b] reshape shim around -``LigerCrossEntropyLoss`` and is meaningful to test on its own. +The class is the public Mode-2 API; the monkey-patch wrappers in +``monkey_patch.py`` are thin closures around an instance of this class. Tests +target the class directly — that's the single source of truth for the CE math. -The wrapper calls the underlying Triton kernel, so these tests require a -Liger-supported accelerator (same as ``test/transformers/test_cross_entropy.py``). +Mirrors ``test/megatron/test_rms_norm.py``'s parametrize style for the +fp32/bf16 sweep so the visual symmetry across the two megatron-side files is +preserved. 
""" import pytest import torch import torch.nn.functional as F -from liger_kernel.megatron.cross_entropy import _build_wrapper -from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.megatron import LigerMegatronCrossEntropy from liger_kernel.utils import infer_device from test.utils import assert_verbose_allclose from test.utils import set_seed @@ -23,15 +23,6 @@ set_seed(42) -def _make_wrapper(ignore_index: int = -100, label_smoothing: float = 0.0): - loss_fn = LigerCrossEntropyLoss( - ignore_index=ignore_index, - label_smoothing=label_smoothing, - reduction="none", - ) - return _build_wrapper(loss_fn) - - def _reference_loss( vocab_parallel_logits: torch.Tensor, target: torch.Tensor, @@ -49,6 +40,11 @@ def _reference_loss( return loss_flat.reshape(s, b) +# --------------------------------------------------------------------------- +# Forward correctness vs. F.cross_entropy reference. +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize( "s, b, v", [ @@ -69,56 +65,79 @@ def _reference_loss( ), ], ) -def test_wrapper_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): - wrapper = _make_wrapper() +def test_class_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): + ce = LigerMegatronCrossEntropy() logits = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=0.0) - got = wrapper(logits, target) + got = ce(logits, target) assert got.shape == (s, b) assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) +# --------------------------------------------------------------------------- +# Configuration plumbing — wrapper-specific contracts. 
+# --------------------------------------------------------------------------- + + @pytest.mark.parametrize("ignore_index", [-100, 0]) -def test_wrapper_respects_ignore_index(ignore_index): +def test_class_respects_ignore_index(ignore_index): s, b, v = 16, 2, 1024 - wrapper = _make_wrapper(ignore_index=ignore_index) + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) logits = torch.randn(s, b, v, device=device, dtype=torch.float32) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) target.view(-1)[: (s * b) // 4] = ignore_index ref = _reference_loss(logits, target, ignore_index=ignore_index, label_smoothing=0.0) - got = wrapper(logits, target) + got = ce(logits, target) assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) @pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) -def test_wrapper_respects_label_smoothing(label_smoothing): +def test_class_respects_label_smoothing(label_smoothing): s, b, v = 8, 2, 512 - wrapper = _make_wrapper(label_smoothing=label_smoothing) + ce = LigerMegatronCrossEntropy(label_smoothing=label_smoothing) logits = torch.randn(s, b, v, device=device, dtype=torch.float32) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=label_smoothing) - got = wrapper(logits, target) + got = ce(logits, target) assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) -def test_wrapper_rejects_unknown_kwargs(): - wrapper = _make_wrapper() - logits = torch.randn(4, 1, 32, device=device) - target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) - with pytest.raises(TypeError): - wrapper(logits, target, unknown_arg=123) +@pytest.mark.parametrize("bad_reduction", ["mean", "sum", "MEAN", "garbage"]) +def test_class_rejects_non_none_reduction(bad_reduction): + """Megatron's contract is per-token loss; mean/sum break the [s, b] return shape.""" + with pytest.raises(ValueError, match="reduction 
must be 'none'"): + LigerMegatronCrossEntropy(reduction=bad_reduction) + + +def test_class_rejects_non_3d_logits(): + """The class explicitly guards against HuggingFace-shape [b, s, v] callers etc.""" + ce = LigerMegatronCrossEntropy() + bad = torch.randn(8, 16, device=device) # 2-D + target = torch.randint(0, 16, (8,), device=device, dtype=torch.long) + with pytest.raises(ValueError, match="3-D"): + ce(bad, target) + + too_many = torch.randn(2, 2, 4, 16, device=device) # 4-D + target2 = torch.randint(0, 16, (2, 2, 4), device=device, dtype=torch.long) + with pytest.raises(ValueError, match="3-D"): + ce(too_many, target2) + +# --------------------------------------------------------------------------- +# TP guard — the only safety net the class itself enforces. +# --------------------------------------------------------------------------- -def test_wrapper_rejects_multi_rank_tp_group(): - wrapper = _make_wrapper() + +def test_class_raises_on_tp_group_size_greater_than_one(): + ce = LigerMegatronCrossEntropy() logits = torch.randn(4, 1, 32, device=device) target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) @@ -127,11 +146,11 @@ def size(self): return 2 with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): - wrapper(logits, target, tp_group=_FakeGroup()) + ce(logits, target, tp_group=_FakeGroup()) -def test_wrapper_accepts_single_rank_tp_group(): - wrapper = _make_wrapper() +def test_class_accepts_single_rank_tp_group(): + ce = LigerMegatronCrossEntropy() logits = torch.randn(4, 1, 32, device=device) target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) @@ -139,19 +158,33 @@ class _FakeGroup: def size(self): return 1 - out = wrapper(logits, target, tp_group=_FakeGroup()) + out = ce(logits, target, tp_group=_FakeGroup()) assert out.shape == (4, 1) -def test_wrapper_preserves_gradients(): +# --------------------------------------------------------------------------- +# Gradient sanity — Liger's CE writes gradients in 
place; verify the class +# preserves them through Megatron's [s, b, v] reshape contract. +# --------------------------------------------------------------------------- + + +def test_class_preserves_gradients(): s, b, v = 8, 2, 256 - wrapper = _make_wrapper() + ce = LigerMegatronCrossEntropy() logits = torch.randn(s, b, v, device=device, dtype=torch.float32, requires_grad=True) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) - loss = wrapper(logits, target).sum() + loss = ce(logits, target).sum() loss.backward() assert logits.grad is not None assert logits.grad.shape == logits.shape + + +def test_class_extra_repr(): + ce = LigerMegatronCrossEntropy(ignore_index=42, label_smoothing=0.07) + rep = ce.extra_repr() + assert "ignore_index=42" in rep + assert "label_smoothing=0.07" in rep + assert "reduction='none'" in rep diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py index ac42234a5..e49110ab2 100644 --- a/test/megatron/test_monkey_patch.py +++ b/test/megatron/test_monkey_patch.py @@ -1,15 +1,18 @@ -"""Tests for apply_liger_kernel_to_megatron's patch mechanism. +"""Tests for ``apply_liger_kernel_to_megatron``'s cross-entropy patch mechanism. -Megatron-LM is not a test dependency. We inject stub modules into -``sys.modules`` so the patch function can run entirely on CPU without a real -megatron-core install. Tests verify: +Megatron-LM is not a test dependency. We inject stub modules into ``sys.modules`` so the +patch helpers can run entirely on CPU without a real megatron-core install. 
Tests verify: -- the patch replaces ``fused_vocab_parallel_cross_entropy`` on the stub module +- the patch replaces both the fused symbol + (``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``) + AND the unfused symbol + (``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``) - ``reduction != "none"`` is rejected -- TP>1 at patch time raises RuntimeError -- missing megatron-core raises a helpful ImportError -- missing symbol path raises a helpful ImportError -- the constructed LigerCrossEntropyLoss receives the user-supplied kwargs +- TP>1 at patch time raises ``RuntimeError`` +- missing megatron-core / missing symbol path raise helpful ``ImportError``\\s +- the constructed ``LigerMegatronCrossEntropy`` receives the user-supplied kwargs +- the unfused wrapper honors a runtime ``label_smoothing`` override (Megatron's unfused + signature is ``(logits, target, label_smoothing=0.0, tp_group=None)``) """ import sys @@ -20,35 +23,58 @@ import pytest -def _install_fake_megatron(tp_size: int = 1, with_fused_symbol: bool = True): - """Install stub megatron modules into sys.modules; return the fused module.""" +def _install_fake_megatron( + tp_size: int = 1, + with_fused_symbol: bool = True, + with_unfused_symbol: bool = True, +): + """Install stub megatron modules into ``sys.modules``. + + Returns a tuple ``(fused_ce_module, unfused_ce_module)`` so tests can inspect what + the patch helpers wrote onto them. 
+ """ megatron = types.ModuleType("megatron") megatron_core = types.ModuleType("megatron.core") fusions = types.ModuleType("megatron.core.fusions") fused_ce = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = types.ModuleType("megatron.core.tensor_parallel") + unfused_ce = types.ModuleType("megatron.core.tensor_parallel.cross_entropy") parallel_state = types.ModuleType("megatron.core.parallel_state") if with_fused_symbol: def original_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): - raise AssertionError("original megatron kernel called — patch failed") + raise AssertionError("original megatron fused kernel called — patch failed") fused_ce.fused_vocab_parallel_cross_entropy = original_fused_vocab_parallel_cross_entropy + if with_unfused_symbol: + + def original_vocab_parallel_cross_entropy( + vocab_parallel_logits, target, label_smoothing=0.0, tp_group=None, + ): + raise AssertionError("original megatron unfused kernel called — patch failed") + + unfused_ce.vocab_parallel_cross_entropy = original_vocab_parallel_cross_entropy + parallel_state.get_tensor_model_parallel_world_size = lambda: tp_size sys.modules["megatron"] = megatron sys.modules["megatron.core"] = megatron_core sys.modules["megatron.core.fusions"] = fusions sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused_ce + sys.modules["megatron.core.tensor_parallel"] = tensor_parallel + sys.modules["megatron.core.tensor_parallel.cross_entropy"] = unfused_ce sys.modules["megatron.core.parallel_state"] = parallel_state megatron.core = megatron_core megatron_core.fusions = fusions + megatron_core.tensor_parallel = tensor_parallel megatron_core.parallel_state = parallel_state fusions.fused_cross_entropy = fused_ce + tensor_parallel.cross_entropy = unfused_ce - return fused_ce + return fused_ce, unfused_ce def _uninstall_fake_megatron(): @@ -56,6 +82,8 @@ def _uninstall_fake_megatron(): "megatron.core.parallel_state", 
"megatron.core.fusions.fused_cross_entropy", "megatron.core.fusions", + "megatron.core.tensor_parallel.cross_entropy", + "megatron.core.tensor_parallel", "megatron.core", "megatron", ]: @@ -64,28 +92,130 @@ def _uninstall_fake_megatron(): @pytest.fixture def fake_megatron(): - fused_ce = _install_fake_megatron(tp_size=1) + fused_ce, unfused_ce = _install_fake_megatron(tp_size=1) try: - yield fused_ce + yield fused_ce, unfused_ce finally: _uninstall_fake_megatron() +# --------------------------------------------------------------------------- +# Both symbols get replaced. +# --------------------------------------------------------------------------- + + def test_patch_replaces_fused_symbol(fake_megatron): + fused_ce, _ = fake_megatron from liger_kernel.megatron import apply_liger_kernel_to_megatron - original = fake_megatron.fused_vocab_parallel_cross_entropy + original = fused_ce.fused_vocab_parallel_cross_entropy apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) - patched = fake_megatron.fused_vocab_parallel_cross_entropy - assert patched is not original - assert patched.__name__ == "liger_fused_vocab_parallel_cross_entropy" + assert fused_ce.fused_vocab_parallel_cross_entropy is not original + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + + +def test_patch_replaces_unfused_symbol(fake_megatron): + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = unfused_ce.vocab_parallel_cross_entropy + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + assert unfused_ce.vocab_parallel_cross_entropy is not original + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" + + +def test_patch_replaces_both_fused_and_unfused_symbols_in_one_call(fake_megatron): + """A single ``cross_entropy=True`` call must replace both Megatron CE paths.""" + fused_ce, unfused_ce = fake_megatron + 
from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" + + +def test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched(fake_megatron): + """Default ``cross_entropy=False`` must not touch the CE symbols even if the call runs.""" + fused_ce, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + fused_before = fused_ce.fused_vocab_parallel_cross_entropy + unfused_before = unfused_ce.vocab_parallel_cross_entropy + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=False) + + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_before + assert unfused_ce.vocab_parallel_cross_entropy is unfused_before + + +def test_patch_is_idempotent_for_both_symbols(fake_megatron): + """Calling ``apply_liger_kernel_to_megatron(cross_entropy=True)`` twice must not stack + wrappers — the sentinel attribute guards against double-patching.""" + fused_ce, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + fused_first = fused_ce.fused_vocab_parallel_cross_entropy + unfused_first = unfused_ce.vocab_parallel_cross_entropy + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + # Same identity → no stacked wrapping. + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_first + assert unfused_ce.vocab_parallel_cross_entropy is unfused_first + # __wrapped__ still references the original Megatron symbol, not the first Liger wrapper. 
+ assert fused_first.__wrapped__.__name__ == "original_fused_vocab_parallel_cross_entropy" + assert unfused_first.__wrapped__.__name__ == "original_vocab_parallel_cross_entropy" + + +def test_patch_fused_wrapper_passes_tp_group_through(fake_megatron): + """The fused wrapper closure must forward ``tp_group`` to the underlying class. + + We swap the CE class for a recording fake so the call doesn't need CUDA — just + confirms ``tp_group`` reaches the class's ``__call__``.""" + import torch + + fused_ce, _ = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + captured = {} + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + pass + + def __call__(self, logits, target, tp_group=None): + captured["tp_group"] = tp_group + captured["shape"] = tuple(logits.shape) + return torch.zeros(logits.shape[:2]) + + class _FakeGroup: + def size(self): + return 1 + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + group = _FakeGroup() + fused_ce.fused_vocab_parallel_cross_entropy(logits, target, group) + + assert captured["tp_group"] is group + assert captured["shape"] == (2, 1, 4) + + +# --------------------------------------------------------------------------- +# Argument validation + TP-1 guard. 
+# --------------------------------------------------------------------------- def test_patch_rejects_non_none_reduction(fake_megatron): from liger_kernel.megatron import apply_liger_kernel_to_megatron - with pytest.raises(AssertionError, match="reduction must be 'none'"): + with pytest.raises(ValueError, match="reduction must be 'none'"): apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True, reduction="mean") @@ -102,7 +232,7 @@ def test_patch_raises_on_tp_greater_than_one(): def test_patch_defers_tp_check_when_parallel_state_not_initialized(): """If get_tensor_model_parallel_world_size() raises, patch should still succeed.""" - fused_ce = _install_fake_megatron(tp_size=1) + fused_ce, unfused_ce = _install_fake_megatron(tp_size=1) def raising_tp_size(): raise AssertionError("parallel_state not initialized") @@ -114,13 +244,18 @@ def raising_tp_size(): apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" finally: _uninstall_fake_megatron() +# --------------------------------------------------------------------------- +# Missing-megatron / missing-symbol errors. +# --------------------------------------------------------------------------- + + def test_patch_raises_when_megatron_not_installed(): _uninstall_fake_megatron() - # Block imports of any "megatron*" module to simulate absent install. 
real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ def blocking_import(name, *args, **kwargs): @@ -146,24 +281,171 @@ def test_patch_raises_when_fused_symbol_missing(): _uninstall_fake_megatron() +def test_patch_raises_when_unfused_symbol_missing(): + """Symmetric to the fused-missing case; the unfused module exists but its symbol doesn't.""" + _install_fake_megatron(tp_size=1, with_unfused_symbol=False) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="symbol path may have changed"): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + finally: + _uninstall_fake_megatron() + + +# --------------------------------------------------------------------------- +# Kwarg propagation + runtime label_smoothing override on the unfused path. +# --------------------------------------------------------------------------- + + def test_patch_forwards_ignore_index_and_label_smoothing(fake_megatron): + """Patch-time ignore_index + label_smoothing reach the underlying LigerMegatronCrossEntropy. + + Both the fused and unfused wrappers create an instance configured with these kwargs. 
+ """ from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod - captured = {} + captured = [] - class FakeLoss: - def __init__(self, ignore_index, label_smoothing, reduction): - captured["ignore_index"] = ignore_index - captured["label_smoothing"] = label_smoothing - captured["reduction"] = reduction + real_ctor = ce_mod.LigerMegatronCrossEntropy.__init__ - def __call__(self, _input, target): - raise AssertionError("not expected to be called in this test") + def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + captured.append({ + "ignore_index": ignore_index, + "label_smoothing": label_smoothing, + "reduction": reduction, + }) + real_ctor(self, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction) - with patch.object(ce_mod, "LigerCrossEntropyLoss", FakeLoss): + with patch.object(ce_mod.LigerMegatronCrossEntropy, "__init__", recording_init): apply_liger_kernel_to_megatron( - rms_norm=False, cross_entropy=True, ignore_index=42, label_smoothing=0.25 + rms_norm=False, cross_entropy=True, ignore_index=42, label_smoothing=0.25, ) - assert captured == {"ignore_index": 42, "label_smoothing": 0.25, "reduction": "none"} + # The fused wrapper builds 1 instance; the unfused wrapper builds 1 default instance. + # Both should carry the user-supplied kwargs. + assert len(captured) >= 2 + for entry in captured: + assert entry == {"ignore_index": 42, "label_smoothing": 0.25, "reduction": "none"} + + +def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron): + """The unfused signature takes ``label_smoothing`` as a runtime arg; the wrapper must honor it. + + When the caller passes a non-default value, the wrapper constructs a fresh + ``LigerMegatronCrossEntropy`` with that value rather than reusing the patch-time default. 
+ + We verify this by replacing the class with a recording fake **before** calling + ``apply_liger_kernel_to_megatron`` — the patch helper does a fresh + ``from … import LigerMegatronCrossEntropy`` so the closure captures the fake. + """ + import torch + + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def __call__(self, logits, target, tp_group=None): + # Skip Liger kernel — just return a CPU-friendly tensor in the right shape. + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron( + rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, + ) + + # Reset the recorder to focus on calls triggered by the next line. 
+ constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + unfused_ce.vocab_parallel_cross_entropy(logits, target, label_smoothing=0.3) + + assert constructed == [0.3], ( + f"unfused wrapper should construct one fresh instance with the runtime override; " + f"got: {constructed}" + ) + + +def test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing(fake_megatron): + """When the caller doesn't pass ``label_smoothing``, the wrapper reuses the patch-time + ``default_ce`` instance — no fresh allocation per call.""" + import torch + + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + + def __call__(self, logits, target, tp_group=None): + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron( + rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, + ) + constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + # No label_smoothing arg — wrapper reuses the default_ce instance. + unfused_ce.vocab_parallel_cross_entropy(logits, target) + # Second positional call also without label_smoothing — still no new construction. + unfused_ce.vocab_parallel_cross_entropy(logits, target) + + assert constructed == [], ( + f"default-path calls must reuse default_ce — no fresh instances; got: {constructed}" + ) + + +def test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron): + """Explicit ``label_smoothing=0.0`` at call time must be honored verbatim, not silently + replaced by the patch-time default. 
+ + This guards against the bug where the wrapper used ``if label_smoothing == 0.0:`` to + detect "caller passed nothing" — that conflated "caller didn't pass" with "caller + explicitly asked for 0.0" and corrupted loss math for Megatron callers that pass 0.0 + positionally.""" + import torch + + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + + def __call__(self, logits, target, tp_group=None): + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron( + rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, + ) + constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + + # Explicit positional 0.0 — must construct a fresh instance with 0.0. + unfused_ce.vocab_parallel_cross_entropy(logits, target, 0.0) + + assert constructed == [0.0], ( + f"explicit label_smoothing=0.0 at call time must be honored verbatim; " + f"got: {constructed}" + ) From a3f70a01430774c48edf52a878979fb390dbe2ef Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 06:45:01 +0000 Subject: [PATCH 3/7] Megatron CE: drop CE-specific kwargs from apply_liger_kernel_to_megatron MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim the framework-level patch API to just (rms_norm=True, cross_entropy=False). CE-specific knobs (ignore_index, label_smoothing, reduction) leaked the wrong abstraction at the wrong scope — the public entry point's job is "patch Megatron with Liger," not "configure CE." 
The patch now matches Megatron's native fused-CE behavior exactly by constructing LigerMegatronCrossEntropy with class defaults, so callers flip Liger on without touching loss config. The unfused wrapper keeps its per-call label_smoothing argument — that matches Megatron's native vocab_parallel_cross_entropy signature and is the right scope for the knob (per-call, not patch-time). Sentinel dispatch preserved so an explicit label_smoothing=0.0 from the caller is honored verbatim. Mode-2 (hand-built) users who need custom CE config still get it via LigerMegatronCrossEntropy's __init__ kwargs — that surface is unchanged. src/liger_kernel/megatron/monkey_patch.py Drop ignore_index, label_smoothing, reduction kwargs from apply_liger_kernel_to_megatron and from both _patch_*_cross_entropy helpers. Update docstrings to point Mode-2 users at LigerMegatronCrossEntropy for explicit config. test/megatron/test_monkey_patch.py Remove test_patch_rejects_non_none_reduction (path unreachable from trimmed API). Rename test_patch_forwards_… → test_patch_constructs_ce_with_class_defaults, asserting the patch uses class defaults. Strip now-invalid kwargs from the three runtime label_smoothing tests on the unfused wrapper; their per-call behavior is unchanged. docs/High-Level-APIs.md Replace usage example with the trimmed signature; add pointer to LigerMegatronCrossEntropy for Mode-2 explicit-config callers. examples/megatron/run_mode1_monkey_patch.py Drop label_smoothing=0.1 demo from the apply call and from the header comment. 15/15 monkey-patch tests pass. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/High-Level-APIs.md | 36 ++++--- examples/megatron/run_mode1_monkey_patch.py | 11 +- src/liger_kernel/megatron/monkey_patch.py | 110 +++++++------------- test/megatron/test_monkey_patch.py | 48 +++------ 4 files changed, 74 insertions(+), 131 deletions(-) diff --git a/docs/High-Level-APIs.md b/docs/High-Level-APIs.md index b90fc6e96..6bbe008a9 100644 --- a/docs/High-Level-APIs.md +++ b/docs/High-Level-APIs.md @@ -97,18 +97,18 @@ You can also use the Patching APIs to use the kernels for a specific model archi ## Megatron-LM Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) -training framework, replacing Megatron's native -`fused_vocab_parallel_cross_entropy` with Liger's Triton cross-entropy kernel. +training framework, replacing Megatron's native RMSNorm and both vocab-parallel +cross-entropy paths (fused and unfused) with Liger's Triton kernels. | **Framework** | **API** | **Supported Operations** | |---------------|--------------------------------------------------------|--------------------------| -| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | CrossEntropyLoss | +| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | RMSNorm, CrossEntropyLoss | -**Scope**: Initial release supports `tensor_model_parallel_size=1` only. -Vocab-parallel cross-entropy (TP>1) is follow-up work — with TP>1, each rank -holds a sharded `[N, V/tp]` logits slice and cross-entropy requires cross-rank -all-reduces that Liger's kernel does not perform. The patch raises a -`RuntimeError` at patch time or call time if TP>1 is detected. +**Scope**: Initial release supports `tensor_model_parallel_size=1` only for +cross-entropy. Vocab-parallel cross-entropy (TP>1) is follow-up work — with +TP>1, each rank holds a sharded `[N, V/tp]` logits slice and cross-entropy +requires cross-rank all-reduces that Liger's kernel does not perform. 
The +patch raises a `RuntimeError` at patch time or call time if TP>1 is detected. **Usage**: @@ -116,17 +116,19 @@ all-reduces that Liger's kernel does not perform. The patch raises a from liger_kernel.megatron import apply_liger_kernel_to_megatron # Call before Megatron's forward pass reaches compute_language_model_loss. -# Match Megatron's config: pass the same ignore_index and label_smoothing -# values used by your training setup (Liger does not auto-detect them). -apply_liger_kernel_to_megatron( - ignore_index=-100, - label_smoothing=cfg.label_smoothing_factor, -) +# Defaults match Megatron's native CE behavior; no CE-specific config needed. +apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) ``` -Ensure Megatron's fused-CE code path is enabled in your training config (e.g. -`--cross-entropy-loss-fusion` in the Megatron-LM CLI) — if the unfused path is -selected, the patched symbol is never called. +Both the fused (`config.cross_entropy_loss_fusion=True`, +`cross_entropy_fusion_impl='native'`) and unfused +(`config.cross_entropy_loss_fusion=False`) CE paths are patched in a single +call, so Megatron picks up Liger regardless of which path your config selects. + +For training setups that need explicit kernel configuration (custom +`ignore_index`, `label_smoothing`, etc.), instantiate +`LigerMegatronCrossEntropy` directly and wire it into your model — see +`examples/megatron/run_mode2_hand_spec.py`. ::: liger_kernel.megatron.apply_liger_kernel_to_megatron options: diff --git a/examples/megatron/run_mode1_monkey_patch.py b/examples/megatron/run_mode1_monkey_patch.py index 449e2b042..31638a06a 100644 --- a/examples/megatron/run_mode1_monkey_patch.py +++ b/examples/megatron/run_mode1_monkey_patch.py @@ -3,9 +3,8 @@ Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The relevant additions (vs. that file) are: - 1. 
``apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True, - label_smoothing=0.1)`` called once at the top of ``model_provider()``. - This patches: + 1. ``apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)`` + called once at the top of ``model_provider()``. This patches: - ``LocalSpecProvider.layer_norm`` (per-layer norm slots) - ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``) - ``fused_cross_entropy.fused_vocab_parallel_cross_entropy`` @@ -75,11 +74,7 @@ def initialize_distributed(tp: int = 2, pp: int = 1) -> None: def model_provider() -> GPTModel: # ↓↓ Mode 1 — patch once, everything below picks up Liger ↓↓ - apply_liger_kernel_to_megatron( - rms_norm=True, - cross_entropy=True, - label_smoothing=0.1, - ) + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) # ↑↑ ------------------------------------------------------ ↑↑ cfg = TransformerConfig( diff --git a/src/liger_kernel/megatron/monkey_patch.py b/src/liger_kernel/megatron/monkey_patch.py index 9c1c42438..4ed19f948 100644 --- a/src/liger_kernel/megatron/monkey_patch.py +++ b/src/liger_kernel/megatron/monkey_patch.py @@ -12,17 +12,13 @@ def apply_liger_kernel_to_megatron( rms_norm: bool = True, cross_entropy: bool = False, - *, - ignore_index: int = -100, - label_smoothing: float = 0.0, - reduction: str = "none", ) -> None: """Patch Megatron-Core to use Liger Triton kernels. Idempotent. Targets Megatron's ``BackendSpecProvider``, - ``transformer_block.LayerNormImpl``, and (optionally) - ``fused_vocab_parallel_cross_entropy`` so models that route through the - standard spec system pick up Liger without per-model code. + ``transformer_block.LayerNormImpl``, and (optionally) both of Megatron's + vocab-parallel cross-entropy entry points so models that route through + the standard spec system pick up Liger without per-model code. 
Args: rms_norm: When ``True`` (default) replace both @@ -30,22 +26,17 @@ def apply_liger_kernel_to_megatron( ``transformer_block.LayerNormImpl`` (the block-level ``final_layernorm`` slot) so all RMSNorm modules in the model become ``LigerMegatronRMSNorm``. - cross_entropy: When ``True`` replace + cross_entropy: When ``True`` replace both ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy`` - with Liger's Triton cross-entropy (online softmax, in-place - gradients, no full-softmax materialization). Default ``False`` - because this path currently supports - ``tensor_model_parallel_size=1`` only. - ignore_index: Cross-entropy ignore index. Only used when - ``cross_entropy=True``. - label_smoothing: Cross-entropy label smoothing factor. Liger does not - auto-detect this — pass ``cfg.label_smoothing_factor`` (or - equivalent) from your Megatron ``TransformerConfig`` if label - smoothing is enabled, to preserve native behavior. Only used when - ``cross_entropy=True``. - reduction: Must be ``"none"`` when ``cross_entropy=True``; Megatron's - fused-CE contract returns per-token loss shaped ``[seq, batch]`` - and handles reduction itself downstream. + (fused path) and + ``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy`` + (unfused path) with Liger's Triton cross-entropy. Default + ``False`` because this path currently supports + ``tensor_model_parallel_size=1`` only. The fused wrapper matches + native's ``(logits, target, tp_group)`` signature exactly; the + unfused wrapper additionally honors a runtime ``label_smoothing`` + argument, matching native's + ``(logits, target, label_smoothing=0.0, tp_group=None)``. Notes: Call this BEFORE building your model. Patching after instantiation @@ -57,6 +48,12 @@ def apply_liger_kernel_to_megatron( ``TELayerNormColumnParallelLinear`` folds the norm into the QKV linear; naive substitution would either double-norm or skip the norm. 
+ For explicit kernel configuration (custom ``ignore_index``, + ``label_smoothing``, etc.) instantiate ``LigerMegatronCrossEntropy`` + directly and wire it into your model (Mode 2). The monkey-patch path + is intentionally a transparent drop-in: it matches Megatron's native + defaults so callers can flip Liger on without touching loss config. + Raises: RuntimeError: When ``cross_entropy=True`` and Megatron's parallel state already reports ``tensor_model_parallel_size > 1``. @@ -66,16 +63,8 @@ def apply_liger_kernel_to_megatron( _patch_transformer_block_layernorm_impl() if cross_entropy: _check_tensor_parallel_size_at_patch_time() - _patch_fused_vocab_parallel_cross_entropy( - reduction=reduction, - ignore_index=ignore_index, - label_smoothing=label_smoothing, - ) - _patch_vocab_parallel_cross_entropy( - reduction=reduction, - ignore_index=ignore_index, - label_smoothing=label_smoothing, - ) + _patch_fused_vocab_parallel_cross_entropy() + _patch_vocab_parallel_cross_entropy() def _check_tensor_parallel_size_at_patch_time() -> None: @@ -193,23 +182,14 @@ def __new__(cls, config, hidden_size, eps=1e-5, **kwargs): _LABEL_SMOOTHING_UNSET = object() -def _patch_fused_vocab_parallel_cross_entropy( - reduction: str = "none", - ignore_index: int = -100, - label_smoothing: float = 0.0, -) -> None: +def _patch_fused_vocab_parallel_cross_entropy() -> None: """Replace ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``. - Wraps a single ``LigerMegatronCrossEntropy`` instance (configured at patch time) in - a closure matching Megatron's fused-CE signature ``(logits, target, tp_group)``. - Idempotent: a sentinel attribute on the replacement prevents wrappers from stacking. + Wraps a single ``LigerMegatronCrossEntropy`` instance (constructed with class defaults + that match Megatron's native fused-CE behavior) in a closure matching Megatron's fused-CE + signature ``(logits, target, tp_group)``. 
Idempotent: a sentinel attribute on the + replacement prevents wrappers from stacking. """ - if reduction != "none": - raise ValueError( - f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; " - f"reduction must be 'none', got {reduction!r}." - ) - try: import megatron.core.fusions.fused_cross_entropy as fused_ce except ImportError as exc: @@ -233,11 +213,7 @@ def _patch_fused_vocab_parallel_cross_entropy( from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy - ce = LigerMegatronCrossEntropy( - ignore_index=ignore_index, - label_smoothing=label_smoothing, - reduction=reduction, - ) + ce = LigerMegatronCrossEntropy() def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): return ce(vocab_parallel_logits, target, tp_group=tp_group) @@ -252,25 +228,15 @@ def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_g ) -def _patch_vocab_parallel_cross_entropy( - reduction: str = "none", - ignore_index: int = -100, - label_smoothing: float = 0.0, -) -> None: +def _patch_vocab_parallel_cross_entropy() -> None: """Replace ``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``. This is Megatron's *unfused* eager-Python vocab-parallel CE path, dispatched to when ``config.cross_entropy_loss_fusion=False``. Its signature accepts ``label_smoothing`` at call time, so the wrapper honors a runtime value when the caller actually passed one. A sentinel disambiguates "caller passed 0.0" (use 0.0) from "caller didn't pass" - (use patch-time default). + (use class default). """ - if reduction != "none": - raise ValueError( - f"vocab_parallel_cross_entropy returns per-token loss; " - f"reduction must be 'none', got {reduction!r}." 
- ) - try: import megatron.core.tensor_parallel.cross_entropy as unfused_ce except ImportError as exc: @@ -294,13 +260,11 @@ def _patch_vocab_parallel_cross_entropy( from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy - # Patch-time default; reused for every call where the caller doesn't pass label_smoothing. - # Avoids allocating a fresh module per CE call in the common case. - default_ce = LigerMegatronCrossEntropy( - ignore_index=ignore_index, - label_smoothing=label_smoothing, - reduction=reduction, - ) + # Class-default instance; reused for every call where the caller doesn't pass + # label_smoothing. Avoids allocating a fresh module per CE call in the common case + # (Megatron's own LanguageModule.compute_language_model_loss dispatch does not pass + # label_smoothing — it always lands here). + default_ce = LigerMegatronCrossEntropy() def liger_vocab_parallel_cross_entropy( vocab_parallel_logits, @@ -315,11 +279,7 @@ def liger_vocab_parallel_cross_entropy( # construction is microseconds vs. CE-kernel milliseconds. 
if label_smoothing is _LABEL_SMOOTHING_UNSET: return default_ce(vocab_parallel_logits, target, tp_group=tp_group) - ce = LigerMegatronCrossEntropy( - ignore_index=ignore_index, - label_smoothing=label_smoothing, - reduction=reduction, - ) + ce = LigerMegatronCrossEntropy(label_smoothing=label_smoothing) return ce(vocab_parallel_logits, target, tp_group=tp_group) setattr(liger_vocab_parallel_cross_entropy, _PATCH_MARKER, True) diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py index e49110ab2..f94215e8c 100644 --- a/test/megatron/test_monkey_patch.py +++ b/test/megatron/test_monkey_patch.py @@ -7,10 +7,11 @@ (``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``) AND the unfused symbol (``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``) -- ``reduction != "none"`` is rejected - TP>1 at patch time raises ``RuntimeError`` - missing megatron-core / missing symbol path raise helpful ``ImportError``\\s -- the constructed ``LigerMegatronCrossEntropy`` receives the user-supplied kwargs +- the patch constructs ``LigerMegatronCrossEntropy`` with class defaults (matches Megatron + native behavior — no CE-specific kwargs on the public ``apply_liger_kernel_to_megatron`` + API) - the unfused wrapper honors a runtime ``label_smoothing`` override (Megatron's unfused signature is ``(logits, target, label_smoothing=0.0, tp_group=None)``) """ @@ -208,17 +209,10 @@ def size(self): # --------------------------------------------------------------------------- -# Argument validation + TP-1 guard. +# TP-1 guard. 
# --------------------------------------------------------------------------- -def test_patch_rejects_non_none_reduction(fake_megatron): - from liger_kernel.megatron import apply_liger_kernel_to_megatron - - with pytest.raises(ValueError, match="reduction must be 'none'"): - apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True, reduction="mean") - - def test_patch_raises_on_tp_greater_than_one(): _install_fake_megatron(tp_size=2) try: @@ -294,20 +288,21 @@ def test_patch_raises_when_unfused_symbol_missing(): # --------------------------------------------------------------------------- -# Kwarg propagation + runtime label_smoothing override on the unfused path. +# Class-default construction + runtime label_smoothing override on the unfused path. # --------------------------------------------------------------------------- -def test_patch_forwards_ignore_index_and_label_smoothing(fake_megatron): - """Patch-time ignore_index + label_smoothing reach the underlying LigerMegatronCrossEntropy. +def test_patch_constructs_ce_with_class_defaults(fake_megatron): + """The public ``apply_liger_kernel_to_megatron`` API exposes no CE-specific kwargs; + the patch must therefore construct ``LigerMegatronCrossEntropy`` with class defaults. - Both the fused and unfused wrappers create an instance configured with these kwargs. - """ + This intentionally matches Megatron's native fused-CE behavior (no ignore_index, no + label_smoothing). 
Callers needing custom config use ``LigerMegatronCrossEntropy`` + directly (Mode 2).""" from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod captured = [] - real_ctor = ce_mod.LigerMegatronCrossEntropy.__init__ def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): @@ -319,15 +314,12 @@ def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none real_ctor(self, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction) with patch.object(ce_mod.LigerMegatronCrossEntropy, "__init__", recording_init): - apply_liger_kernel_to_megatron( - rms_norm=False, cross_entropy=True, ignore_index=42, label_smoothing=0.25, - ) + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) - # The fused wrapper builds 1 instance; the unfused wrapper builds 1 default instance. - # Both should carry the user-supplied kwargs. + # Fused wrapper builds 1 instance; unfused wrapper builds 1 default instance. assert len(captured) >= 2 for entry in captured: - assert entry == {"ignore_index": 42, "label_smoothing": 0.25, "reduction": "none"} + assert entry == {"ignore_index": -100, "label_smoothing": 0.0, "reduction": "none"} def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron): @@ -360,9 +352,7 @@ def __call__(self, logits, target, tp_group=None): return torch.zeros(logits.shape[:2]) with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): - apply_liger_kernel_to_megatron( - rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, - ) + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) # Reset the recorder to focus on calls triggered by the next line. 
constructed.clear() @@ -395,9 +385,7 @@ def __call__(self, logits, target, tp_group=None): return torch.zeros(logits.shape[:2]) with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): - apply_liger_kernel_to_megatron( - rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, - ) + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) constructed.clear() logits = torch.zeros(2, 1, 4) target = torch.zeros(2, 1, dtype=torch.long) @@ -435,9 +423,7 @@ def __call__(self, logits, target, tp_group=None): return torch.zeros(logits.shape[:2]) with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): - apply_liger_kernel_to_megatron( - rms_norm=False, cross_entropy=True, ignore_index=-100, label_smoothing=0.05, - ) + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) constructed.clear() logits = torch.zeros(2, 1, 4) target = torch.zeros(2, 1, dtype=torch.long) From 6b2cdae36b759fae265b4fc5e512ab4fdff22551 Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 06:55:04 +0000 Subject: [PATCH 4/7] Megatron CE: expand test suites mirroring test/transformers patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The existing suites covered the happy path; sweeping the transformers side surfaced two gaps worth closing on the Megatron mirror: * test_cross_entropy.py only tested the trivial shape/dtype matrix. Missing: scalar-magnitude sweep, broader ignore_index sweep, combined label_smoothing × ignore_index, gradient parity when CE isn't the last op in the graph, out-of-bounds-target guard, forward-only (torch.no_grad) path. * test_monkey_patch.py asserted symbol identity ("the patched function exists on the fake module") but never invoked the wrapper with real tensors. So a wrapper-math regression could pass every existing test. Added an end-to-end integration block that applies the patch and then calls the resulting symbol against F.cross_entropy. 
test/megatron/test_cross_entropy.py Six new tests adapted from test/transformers/test_cross_entropy.py to the [s, b, v] -> [s, b] reshape contract: - test_class_correctness_scalar_sweep (mirrors `test_correctness` with `scalar` param) - test_class_correctness_with_ignore_index_sweep (broader sentinel sweep incl. positive/weird negatives) - test_class_correctness_with_label_smoothing_and_ignore_index (the combined branch where smoothing math historically broke) - test_class_correctness_not_last_layer (loss × downstream factor, verify in-place grad survives) - test_class_rejects_out_of_bounds_target (kernel-level OOB assert reachable through the wrapper) - test_class_correctness_forward_only (torch.no_grad parity + post-hoc .backward() raises) Plus an _assign_ignore_index helper matching the transformers-side pattern of randomizing the masked-out positions. test/megatron/test_monkey_patch.py Nine new test functions (ten collected cases — test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch is parametrized over two label_smoothing values): Public-API surface (mirrors transformers' test_import_from_root): - test_import_from_root - test_public_apply_function_has_no_ce_specific_kwargs (guards against re-introducing the kwargs just trimmed) End-to-end integration (closes the wrapper-math gap): - test_patched_fused_symbol_computes_correct_loss - test_patched_unfused_symbol_computes_correct_loss - test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch - test_patched_fused_symbol_preserves_gradients - test_patched_unfused_symbol_preserves_gradients - test_patched_fused_symbol_default_ignore_index_minus_100 (pins the "Liger > native fused" behavior — native silently garbles target=-100; Liger zeros it) Symmetric split: - test_rms_norm_only_patch_does_not_touch_ce_symbols 95/95 megatron tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) --- test/megatron/test_cross_entropy.py | 250 ++++++++++++++++++++++++++++ test/megatron/test_monkey_patch.py | 241 +++++++++++++++++++++++++++ 2 files changed, 491 insertions(+) diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py index 534a9a57a..eabf2ccec 100644 --- a/test/megatron/test_cross_entropy.py +++ b/test/megatron/test_cross_entropy.py @@ -188,3 +188,253 @@ def test_class_extra_repr(): assert "ignore_index=42" in rep assert "label_smoothing=0.07" in rep assert "reduction='none'" in rep + + +# --------------------------------------------------------------------------- +# Beefier sweeps adapted from test/transformers/test_cross_entropy.py. +# +# LigerMegatronCrossEntropy is a [s, b, v] -> [s, b] reshape around Liger's +# CE op (reduction='none'). Numerical behavior should match the kernel itself, +# so the same parametrization patterns the kernel suite uses are the right +# coverage shape here — just adapted to the 3-D contract. +# --------------------------------------------------------------------------- + + +def _assign_ignore_index(target: torch.Tensor, ignore_index: int, frac: float = 0.25) -> None: + """In-place: replace ~frac of target positions with ignore_index. + + Matches the transformers-side helpers that randomize the masked-out indices + so the test isn't degenerate on a particular row layout. 
+ """ + flat = target.view(-1) + n = max(1, int(flat.numel() * frac)) + idx = torch.randperm(flat.numel(), device=flat.device)[:n] + flat[idx] = ignore_index + + +@pytest.mark.parametrize( + "s, b, v", + [ + (16, 1, 4096), + (32, 2, 32000), # llama-ish vocab + (5, 3, 123), # weird shape + ], +) +@pytest.mark.parametrize("scalar", [0.5, 1.0, 5.0]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_scalar_sweep(s, b, v, scalar, dtype, atol, rtol): + """Vary input magnitude — guards against numerical drift at large logit scales + (mirrors the ``scalar`` parametrize in ``test/transformers/test_cross_entropy.py``).""" + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=dtype) * scalar + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + # Backward parity: feed the same starting tensor through both paths. 
+ h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v, ignore_index", + [ + (16, 1, 4096, -100), # standard hf sentinel + (32, 2, 32000, 2), # positive id (valid vocab slot used as ignore) + (5, 3, 123, -123), # weird negative + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_with_ignore_index_sweep(s, b, v, ignore_index, dtype, atol, rtol): + """Broader ignore_index sweep including positive/negative sentinels and forward+backward + correctness vs. PyTorch's reference. 
Mirrors transformers-side ``test_correctness_with_ignore_index``.""" + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + _assign_ignore_index(target, ignore_index, frac=0.3) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=0.0) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v, ignore_index, label_smoothing", + [ + (16, 1, 4096, 1, 0.1), + (32, 2, 32000, -100, 0.2), + (5, 3, 123, -300, 0.05), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-6, 1e-5), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_with_label_smoothing_and_ignore_index( + s, b, v, ignore_index, label_smoothing, dtype, atol, rtol, +): + """Combined ignore_index × label_smoothing sweep — the two are independent in Liger's CE + kernel but mixing them historically surfaced bugs in the smoothing math. 
Mirrors + ``test_correctness_with_label_smoothing_with_ignore_index_once`` from the kernel suite.""" + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index, label_smoothing=label_smoothing) + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + _assign_ignore_index(target, ignore_index, frac=0.25) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=label_smoothing) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v", + [ + (16, 1, 4096), + (5, 3, 123), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-6, 1e-5), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_not_last_layer(s, b, v, dtype, atol, rtol): + """Loss is multiplied by a downstream factor before ``.backward(grad_output)`` — verifies + that Liger's in-place gradient write through the wrapper survives non-trivial chained + autograd (i.e. CE isn't the last op in the graph). 
Mirrors transformers-side + ``test_correctness_not_last_layer``.""" + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + # Chain: loss = ref * 3 then backward with arbitrary grad_output. + loss_ref = ref * 3.0 + loss_got = got * 3.0 + grad_out = torch.rand_like(ref) + loss_ref.backward(gradient=grad_out) + loss_got.backward(gradient=grad_out) + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("ignore_index", [-100, 2]) +def test_class_rejects_out_of_bounds_target(ignore_index): + """Liger's CE kernel asserts target ∈ [0, V); a stray out-of-bounds target should + raise rather than silently produce garbage. Mirrors transformers-side + ``test_correctness_with_out_of_bounds_target_once``.""" + s, b, v = 8, 2, 64 + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + # Plant a couple of out-of-bounds values; ignore_index is permitted but the + # >=V poisoned slots are not. + flat = target.view(-1) + poison = torch.randperm(flat.numel(), device=flat.device)[:2] + flat[poison] = v + 5 # >= V; the kernel-level assert should fire. 
+ + with pytest.raises(AssertionError, match="out of bounds"): + ce(logits, target) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_forward_only(dtype, atol, rtol): + """Forward-only path (under ``torch.no_grad()``) — verifies the wrapper still returns the + right loss when autograd is disabled, AND that a subsequent ``.backward()`` raises the + expected "does not require grad" error. Mirrors transformers-side ``test_correctness_with_forward_only``.""" + s, b, v = 16, 2, 1024 + ce = LigerMegatronCrossEntropy() + + logits_input = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + with torch.no_grad(): + # Clone the input separately for each path because Liger writes gradient + # state in-place; sharing a buffer would corrupt the reference. + ref = _reference_loss(logits_input.clone(), target, ignore_index=-100, label_smoothing=0.0) + got = ce(logits_input.clone(), target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + # Attempting backward on a forward-only output should raise. 
+ with pytest.raises(RuntimeError, match="does not require grad"): + got.sum().backward() diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py index f94215e8c..5f57fb8a6 100644 --- a/test/megatron/test_monkey_patch.py +++ b/test/megatron/test_monkey_patch.py @@ -435,3 +435,244 @@ def __call__(self, logits, target, tp_group=None): f"explicit label_smoothing=0.0 at call time must be honored verbatim; " f"got: {constructed}" ) + + +# --------------------------------------------------------------------------- +# Public-API surface checks (mirrors transformers-side ``test_import_from_root`` +# and ``test_apply_liger_kernel_only_passes_valid_kwargs`` patterns). +# --------------------------------------------------------------------------- + + +def test_import_from_root(): + """All public Megatron symbols must be reachable from ``liger_kernel.megatron``. + + Mirrors the import-smoke pattern from ``test/transformers/test_monkey_patch.py``: catches + accidental __init__.py removals so the docs' import snippets keep working.""" + try: + from liger_kernel.megatron import LigerMegatronCrossEntropy # noqa: F401 + from liger_kernel.megatron import LigerMegatronRMSNorm # noqa: F401 + from liger_kernel.megatron import apply_liger_kernel_to_megatron # noqa: F401 + except Exception: + pytest.fail("Importing public Megatron symbols from liger_kernel.megatron failed.") + + +def test_public_apply_function_has_no_ce_specific_kwargs(): + """The framework-level patch entry point intentionally hides CE-specific knobs + (ignore_index, label_smoothing, reduction). 
Catch accidental re-introduction — + Mode-2 callers use ``LigerMegatronCrossEntropy`` directly for that config surface.""" + import inspect + + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + sig = inspect.signature(apply_liger_kernel_to_megatron) + leaked = {"ignore_index", "label_smoothing", "reduction"} & set(sig.parameters) + assert not leaked, ( + f"apply_liger_kernel_to_megatron has re-grown CE-specific kwargs: {sorted(leaked)}. " + f"Those belong on LigerMegatronCrossEntropy, not on the framework patch entry point." + ) + + +# --------------------------------------------------------------------------- +# End-to-end integration through the patched symbols. +# +# Earlier tests verify symbol identity + stub plumbing; the suite was missing +# the "patch + call with real tensors + check the numbers" coverage. These +# tests install the fake megatron, apply the patch, then invoke the resulting +# wrapper with live torch tensors and compare against ``F.cross_entropy``. +# That's the only way to catch wrapper-math bugs that pass the identity tests. 
+# --------------------------------------------------------------------------- + + +import torch # noqa: E402 (deferred so the no-torch import-smoke tests above are unaffected) +import torch.nn.functional as F # noqa: E402 + +from liger_kernel.utils import infer_device # noqa: E402 +from test.utils import assert_verbose_allclose # noqa: E402 + +_device = infer_device() + + +def _ref_loss_sbv(logits_sbv: torch.Tensor, target_sb: torch.Tensor, + ignore_index: int = -100, label_smoothing: float = 0.0) -> torch.Tensor: + """Reference CE for [s, b, v] logits / [s, b] target, returning [s, b].""" + s, b, v = logits_sbv.shape + loss_flat = F.cross_entropy( + logits_sbv.reshape(-1, v).float(), + target_sb.reshape(-1), + reduction="none", + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + return loss_flat.reshape(s, b) + + +def test_patched_fused_symbol_computes_correct_loss(fake_megatron): + """End-to-end: install stub megatron, patch, invoke the resulting fused symbol with real + [s, b, v] logits, verify the loss matches ``F.cross_entropy``. Closes the gap between + "patch wired correctly" (existing tests) and "patched function does the right math".""" + fused_ce, _ = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 16, 2, 1024 + torch.manual_seed(0) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target) + # Call through the patched symbol. tp_group=None is what Megatron's + # LanguageModule passes when TP is uninitialized. 
+ got = fused_ce.fused_vocab_parallel_cross_entropy(logits.clone(), target, None) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_unfused_symbol_computes_correct_loss(fake_megatron): + """Same as the fused case, but through the unfused symbol — verifies both wrappers + are exercised and exercises the no-label_smoothing default branch (caller doesn't pass).""" + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 4, 512 + torch.manual_seed(1) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target) + # Native unfused signature: (logits, target, label_smoothing=0.0, tp_group=None). + # Pass only positional args the caller normally would. + got = unfused_ce.vocab_parallel_cross_entropy(logits.clone(), target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_megatron, label_smoothing): + """The unfused wrapper's main feature beyond the fused path is honoring a runtime + label_smoothing arg. 
Verify the resulting loss actually matches + ``F.cross_entropy(..., label_smoothing=...)``, not just that a fresh CE instance is built.""" + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(2) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target, label_smoothing=label_smoothing) + got = unfused_ce.vocab_parallel_cross_entropy( + logits.clone(), target, label_smoothing=label_smoothing, + ) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + + +def test_patched_fused_symbol_preserves_gradients(fake_megatron): + """Backward through the patched fused symbol: gradient shape + parity vs. + PyTorch's reference. Liger writes the gradient back into the input buffer, + so verifying ``.grad`` after backward exercises both the reshape contract + and the in-place write.""" + fused_ce, _ = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(3) + base = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + ref = _ref_loss_sbv(h_ref, target) + got = fused_ce.fused_vocab_parallel_cross_entropy(h_got, target, None) + + ref.sum().backward() + got.sum().backward() + + assert h_got.grad is not None + assert h_got.grad.shape == h_got.shape + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_unfused_symbol_preserves_gradients(fake_megatron): + """Symmetric to the fused-gradient test; 
ensures the closure in + ``_patch_vocab_parallel_cross_entropy`` doesn't break autograd.""" + _, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(4) + base = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + ref = _ref_loss_sbv(h_ref, target) + got = unfused_ce.vocab_parallel_cross_entropy(h_got, target) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_fused_symbol_default_ignore_index_minus_100(fake_megatron): + """Patch-time defaults: targets containing -100 should be treated as ignored — Liger's + kernel zeros those loss positions, matching ``F.cross_entropy(ignore_index=-100)``. + + Native Megatron's fused CE has no ignore_index concept and would silently produce + garbage on -100; this is one place where Liger is strictly better than the symbol + it replaces, and the test pins that behavior.""" + fused_ce, _ = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 128 + torch.manual_seed(5) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + # Plant some -100 sentinel positions. + flat = target.view(-1) + flat[: flat.numel() // 4] = -100 + + ref = _ref_loss_sbv(logits.clone(), target, ignore_index=-100) + got = fused_ce.fused_vocab_parallel_cross_entropy(logits.clone(), target, None) + + # Per-token loss at masked positions should be exactly 0. 
+ mask = (target != -100).float() + assert torch.all(got * (1 - mask) == 0) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +def test_rms_norm_only_patch_does_not_touch_ce_symbols(fake_megatron): + """Symmetric to ``test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched``, + but for the opposite split. With ``rms_norm=True, cross_entropy=False`` (RMSNorm + helpers require real megatron and will ImportError on the stub — that's fine, we + only need to confirm the CE symbols are not pre-emptively touched before the RMSNorm + helpers run). Documenting this protects against future apply_… reorderings that would + silently couple the two.""" + fused_ce, unfused_ce = fake_megatron + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + fused_before = fused_ce.fused_vocab_parallel_cross_entropy + unfused_before = unfused_ce.vocab_parallel_cross_entropy + + # RMSNorm helpers do their own megatron import; on the stub they'll raise. Catch + # any exception so the assertion at the end runs regardless — we only care that the + # CE symbols weren't touched. + try: + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + except Exception: + pass + + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_before + assert unfused_ce.vocab_parallel_cross_entropy is unfused_before From 57e1ec98a167f05a19d6ddad26c3c2102c53a184 Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 07:14:17 +0000 Subject: [PATCH 5/7] Megatron tests: enforce grad parity on CE basics; split monkey-patch suite per-kernel + add RMSNorm coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two coverage gaps in the existing megatron suites, closed: * test_cross_entropy.py's three base correctness tests (matches_pytorch, respects_ignore_index, respects_label_smoothing) were forward-only. 
test_class_preserves_gradients was misleadingly named — it asserted grad is not None + shape, never that the gradient value was correct. So a CE backward-math regression could slip past every base test. * test_monkey_patch.py was scoped to cross-entropy (per its docstring and stub installer name) even though apply_liger_kernel_to_megatron patches both RMSNorm and CE. RMSNorm patch behavior had zero monkey-patch-side coverage — replacement, idempotency, the spec-provider dispatch, the "skip non-WrappedTorchNorm fallback" contract were all unverified at the patch level. test/megatron/test_cross_entropy.py - test_class_matches_pytorch_cross_entropy: now feeds independent clones to ref and Liger paths (Liger writes grad in-place to the input; sharing the buffer would corrupt the reference), then asserts h_got.grad matches h_ref.grad after backward through both. - test_class_respects_ignore_index, test_class_respects_label_smoothing: same upgrade — full forward + backward parity vs. F.cross_entropy. - test_class_preserves_gradients: tightened body to assert grad value parity against PyTorch's reference, not just "grad is not None". Name preserved; docstring updated to call out the change. test/megatron/test_monkey_patch.py Restructure: - Header docstring rewritten to describe per-kernel structure; sections numbered 1-5 so future kernels append cleanly. - _install_fake_megatron → _install_fake_megatron_ce (symmetric with the new _install_fake_megatron_rms_norm). - Fixture fake_megatron → fake_megatron_ce. - New _ensure_megatron_roots helper so the two installers can be called in either order without clobbering shared root modules. - _uninstall_fake_megatron extended to clean both kernel slices. 
New RMSNorm stub: - _install_fake_megatron_rms_norm stubs megatron.core.models.backends (LocalSpecProvider with a layer_norm method returning a sentinel) and megatron.core.transformer.{transformer_block,torch_norm} (LayerNormImpl, WrappedTorchNorm with __new__ returning a sentinel instance). Toggle layer_norm_is_wrapped_torch_norm to test the block-level skip path. - fake_megatron_rms_norm fixture. New RMSNorm tests (10), grouped 3.1–3.3: Replacement & idempotency: - test_rms_norm_patch_replaces_local_spec_provider_layer_norm - test_rms_norm_patch_replaces_transformer_block_layernorm_impl - test_rms_norm_patch_replaces_both_symbols_in_one_call - test_rms_norm_patch_with_rms_norm_false_leaves_norm_symbols_untouched - test_rms_norm_patch_is_idempotent Dispatch through patched targets: - test_rms_norm_patch_local_spec_provider_returns_liger_for_rms_norm_true - test_rms_norm_patch_local_spec_provider_delegates_for_rms_norm_false - test_rms_norm_patch_transformer_block_routes_rmsnorm_through_liger (instantiates the wrapping class with a RMSNorm config and confirms an actual LigerMegatronRMSNorm instance comes back; no kernel call needed, runs on CPU) - test_rms_norm_patch_transformer_block_routes_layernorm_through_original Block-level skip contract: - test_rms_norm_patch_transformer_block_skips_when_layer_norm_is_not_wrapped_torch_norm (verifies Liger does not displace TE / Apex norm fallbacks) 105/105 megatron tests pass (was 95: +10 RMSNorm patch tests; the 4 existing CE tests gained grad-parity assertions in place and add no new test cases).
Co-Authored-By: Claude Opus 4.7 (1M context) --- test/megatron/test_cross_entropy.py | 70 ++++- test/megatron/test_monkey_patch.py | 464 ++++++++++++++++++++++++---- 2 files changed, 452 insertions(+), 82 deletions(-) diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py index eabf2ccec..589a5df4d 100644 --- a/test/megatron/test_cross_entropy.py +++ b/test/megatron/test_cross_entropy.py @@ -66,17 +66,29 @@ def _reference_loss( ], ) def test_class_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): + """Headline correctness — forward AND backward parity vs. PyTorch's F.cross_entropy. + + Liger writes the gradient back into the input tensor in-place during forward; both paths + are fed independent clones of the same starting tensor so the in-place write on the Liger + side can't corrupt the reference path.""" ce = LigerMegatronCrossEntropy() - logits = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 + base = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) - ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=0.0) - got = ce(logits, target) + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) assert got.shape == (s, b) assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + # --------------------------------------------------------------------------- # Configuration plumbing — wrapper-specific contracts. 
@@ -85,30 +97,49 @@ def test_class_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): @pytest.mark.parametrize("ignore_index", [-100, 0]) def test_class_respects_ignore_index(ignore_index): + """ignore_index plumbing — forward AND backward parity. Ignored positions must + contribute zero loss AND zero gradient on the Liger side.""" s, b, v = 16, 2, 1024 ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) - logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + base = torch.randn(s, b, v, device=device, dtype=torch.float32) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) target.view(-1)[: (s * b) // 4] = ignore_index - ref = _reference_loss(logits, target, ignore_index=ignore_index, label_smoothing=0.0) - got = ce(logits, target) + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=0.0) + got = ce(h_got, target) assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + @pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_class_respects_label_smoothing(label_smoothing): + """label_smoothing plumbing — forward AND backward parity. 
Liger and PyTorch share the + same smoothing formula but with different intermediate kernels; gradient check guards + against algebraic-equivalence-but-numerical-divergence bugs.""" s, b, v = 8, 2, 512 ce = LigerMegatronCrossEntropy(label_smoothing=label_smoothing) - logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + base = torch.randn(s, b, v, device=device, dtype=torch.float32) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) - ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=label_smoothing) - got = ce(logits, target) + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=label_smoothing) + got = ce(h_got, target) assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-5, rtol=1e-4) + @pytest.mark.parametrize("bad_reduction", ["mean", "sum", "MEAN", "garbage"]) def test_class_rejects_non_none_reduction(bad_reduction): @@ -169,17 +200,28 @@ def size(self): def test_class_preserves_gradients(): + """Backward smoke test — grad exists with correct shape AND matches PyTorch's reference. + + Previously asserted only ``grad is not None`` + shape; that gave a misleadingly green test + when grad values were wrong. 
Now compares against ``F.cross_entropy(...).sum().backward()``.""" s, b, v = 8, 2, 256 ce = LigerMegatronCrossEntropy() - logits = torch.randn(s, b, v, device=device, dtype=torch.float32, requires_grad=True) + base = torch.randn(s, b, v, device=device, dtype=torch.float32) target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) - loss = ce(logits, target).sum() - loss.backward() + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + + ref.sum().backward() + got.sum().backward() - assert logits.grad is not None - assert logits.grad.shape == logits.shape + assert h_got.grad is not None + assert h_got.grad.shape == h_got.shape + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) def test_class_extra_repr(): diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py index 5f57fb8a6..e761bfc28 100644 --- a/test/megatron/test_monkey_patch.py +++ b/test/megatron/test_monkey_patch.py @@ -1,19 +1,25 @@ -"""Tests for ``apply_liger_kernel_to_megatron``'s cross-entropy patch mechanism. +"""Tests for ``apply_liger_kernel_to_megatron``'s patch mechanism. Megatron-LM is not a test dependency. We inject stub modules into ``sys.modules`` so the -patch helpers can run entirely on CPU without a real megatron-core install. Tests verify: +patch helpers can run entirely on CPU without a real megatron-core install. 
For each kernel +Liger patches into Megatron, this file verifies: -- the patch replaces both the fused symbol - (``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``) - AND the unfused symbol - (``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``) -- TP>1 at patch time raises ``RuntimeError`` +- the patched Megatron symbol(s) are actually replaced +- patching is idempotent (calling apply twice doesn't stack wrappers) +- the patch is a no-op when the kernel flag is False - missing megatron-core / missing symbol path raise helpful ``ImportError``\\s -- the patch constructs ``LigerMegatronCrossEntropy`` with class defaults (matches Megatron - native behavior — no CE-specific kwargs on the public ``apply_liger_kernel_to_megatron`` - API) -- the unfused wrapper honors a runtime ``label_smoothing`` override (Megatron's unfused - signature is ``(logits, target, label_smoothing=0.0, tp_group=None)``) +- kernel-specific dispatch contracts (e.g. CE TP>1 raises; RMSNorm only displaces the + ``WrappedTorchNorm`` fallback, not TE / Apex) +- end-to-end: the patched symbol invoked with real tensors produces correct output + +File layout — extend by appending a new ``-- patch --`` section per kernel +Liger learns to patch: + + 1. Stub-megatron installers + fixtures + 2. Cross-entropy patch tests + 3. RMSNorm patch tests + 4. Cross-kernel public-API surface tests + 5. End-to-end integration through patched CE symbols """ import sys @@ -24,18 +30,37 @@ import pytest -def _install_fake_megatron( +# =========================================================================== +# 1. Stub-megatron installers + fixtures +# =========================================================================== + + +def _ensure_megatron_roots(): + """Create / fetch ``megatron`` and ``megatron.core`` module stubs. 
+ + Idempotent: a second installer (RMSNorm + CE called in either order) reuses + the roots rather than clobbering them, so a single test can stand up both + kernel surfaces at once. + """ + megatron = sys.modules.get("megatron") or types.ModuleType("megatron") + megatron_core = sys.modules.get("megatron.core") or types.ModuleType("megatron.core") + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = megatron_core + megatron.core = megatron_core + return megatron, megatron_core + + +def _install_fake_megatron_ce( tp_size: int = 1, with_fused_symbol: bool = True, with_unfused_symbol: bool = True, ): - """Install stub megatron modules into ``sys.modules``. + """Install the cross-entropy slice of the Megatron stub. Returns a tuple ``(fused_ce_module, unfused_ce_module)`` so tests can inspect what the patch helpers wrote onto them. """ - megatron = types.ModuleType("megatron") - megatron_core = types.ModuleType("megatron.core") + _, megatron_core = _ensure_megatron_roots() fusions = types.ModuleType("megatron.core.fusions") fused_ce = types.ModuleType("megatron.core.fusions.fused_cross_entropy") tensor_parallel = types.ModuleType("megatron.core.tensor_parallel") @@ -60,15 +85,12 @@ def original_vocab_parallel_cross_entropy( parallel_state.get_tensor_model_parallel_world_size = lambda: tp_size - sys.modules["megatron"] = megatron - sys.modules["megatron.core"] = megatron_core sys.modules["megatron.core.fusions"] = fusions sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused_ce sys.modules["megatron.core.tensor_parallel"] = tensor_parallel sys.modules["megatron.core.tensor_parallel.cross_entropy"] = unfused_ce sys.modules["megatron.core.parallel_state"] = parallel_state - megatron.core = megatron_core megatron_core.fusions = fusions megatron_core.tensor_parallel = tensor_parallel megatron_core.parallel_state = parallel_state @@ -78,13 +100,108 @@ def original_vocab_parallel_cross_entropy( return fused_ce, unfused_ce +def 
_install_fake_megatron_rms_norm( + layer_norm_is_wrapped_torch_norm: bool = True, + with_backends_module: bool = True, + with_transformer_block_module: bool = True, +): + """Install the RMSNorm slice of the Megatron stub. + + Returns ``(backends_module, transformer_block_module)`` so tests can inspect what the + patch helpers wrote onto them. Mirrors the CE installer's shape so future kernels can + grow alongside via the same pattern. + + Args: + layer_norm_is_wrapped_torch_norm: When True (default), seeds + ``transformer_block.LayerNormImpl`` as the stub ``WrappedTorchNorm``. The block- + level patch only displaces that fallback; set False to verify the no-op path + taken under TE / Apex. + """ + _, megatron_core = _ensure_megatron_roots() + + backends = None + if with_backends_module: + models = types.ModuleType("megatron.core.models") + backends = types.ModuleType("megatron.core.models.backends") + + class _OriginalNormSentinel: + """Sentinel class returned by the stub's ``layer_norm`` when ``rms_norm=False`` — + tests assert identity against this to confirm the patch delegated correctly.""" + + class _LocalSpecProvider: + def layer_norm(self, rms_norm=False, for_qk=False, has_residual=False): + # Echo the kwargs so tests can verify pass-through; the value returned is + # the sentinel class so identity checks work regardless. 
+ return _OriginalNormSentinel + + backends.LocalSpecProvider = _LocalSpecProvider + backends._OriginalNormSentinel = _OriginalNormSentinel # exposed for tests + sys.modules["megatron.core.models"] = models + sys.modules["megatron.core.models.backends"] = backends + megatron_core.models = models + models.backends = backends + + transformer_block = None + if with_transformer_block_module: + transformer = types.ModuleType("megatron.core.transformer") + transformer_block = types.ModuleType("megatron.core.transformer.transformer_block") + torch_norm_mod = types.ModuleType("megatron.core.transformer.torch_norm") + + class _OriginalTorchNormInstance: + """Sentinel marker for "the original WrappedTorchNorm was instantiated." The stub + ``WrappedTorchNorm.__new__`` returns one of these; tests assert ``isinstance(...)``.""" + + def __init__(self, hidden_size, eps): + self.hidden_size = hidden_size + self.eps = eps + + class _WrappedTorchNorm: + """Stub of ``megatron.core.transformer.torch_norm.WrappedTorchNorm``.""" + + def __new__(cls, config=None, hidden_size=None, eps=1e-5, **kwargs): + return _OriginalTorchNormInstance(hidden_size=hidden_size, eps=eps) + + torch_norm_mod.WrappedTorchNorm = _WrappedTorchNorm + + if layer_norm_is_wrapped_torch_norm: + transformer_block.LayerNormImpl = _WrappedTorchNorm + else: + + class _SomeOtherNorm: + """Stand-in for TE / Apex LN — block-level patch should leave this alone.""" + + transformer_block.LayerNormImpl = _SomeOtherNorm + + # Expose sentinels for test assertions. 
+ transformer_block._OriginalTorchNormInstance = _OriginalTorchNormInstance + transformer_block._WrappedTorchNorm = _WrappedTorchNorm + + sys.modules["megatron.core.transformer"] = transformer + sys.modules["megatron.core.transformer.transformer_block"] = transformer_block + sys.modules["megatron.core.transformer.torch_norm"] = torch_norm_mod + megatron_core.transformer = transformer + transformer.transformer_block = transformer_block + transformer.torch_norm = torch_norm_mod + + return backends, transformer_block + + def _uninstall_fake_megatron(): + """Tear down every stub module installed by either installer.""" for mod in [ + # CE side "megatron.core.parallel_state", "megatron.core.fusions.fused_cross_entropy", "megatron.core.fusions", "megatron.core.tensor_parallel.cross_entropy", "megatron.core.tensor_parallel", + # RMSNorm side + "megatron.core.models.backends", + "megatron.core.models", + "megatron.core.transformer.transformer_block", + "megatron.core.transformer.torch_norm", + "megatron.core.transformer", + # Shared roots "megatron.core", "megatron", ]: @@ -92,21 +209,35 @@ def _uninstall_fake_megatron(): @pytest.fixture -def fake_megatron(): - fused_ce, unfused_ce = _install_fake_megatron(tp_size=1) +def fake_megatron_ce(): + fused_ce, unfused_ce = _install_fake_megatron_ce(tp_size=1) try: yield fused_ce, unfused_ce finally: _uninstall_fake_megatron() +@pytest.fixture +def fake_megatron_rms_norm(): + backends, transformer_block = _install_fake_megatron_rms_norm() + try: + yield backends, transformer_block + finally: + _uninstall_fake_megatron() + + +# =========================================================================== +# 2. Cross-entropy patch tests +# =========================================================================== + + # --------------------------------------------------------------------------- -# Both symbols get replaced. +# 2.1 Both CE symbols get replaced. 
# --------------------------------------------------------------------------- -def test_patch_replaces_fused_symbol(fake_megatron): - fused_ce, _ = fake_megatron +def test_patch_replaces_fused_symbol(fake_megatron_ce): + fused_ce, _ = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron original = fused_ce.fused_vocab_parallel_cross_entropy @@ -116,8 +247,8 @@ def test_patch_replaces_fused_symbol(fake_megatron): assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" -def test_patch_replaces_unfused_symbol(fake_megatron): - _, unfused_ce = fake_megatron +def test_patch_replaces_unfused_symbol(fake_megatron_ce): + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron original = unfused_ce.vocab_parallel_cross_entropy @@ -127,9 +258,9 @@ def test_patch_replaces_unfused_symbol(fake_megatron): assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" -def test_patch_replaces_both_fused_and_unfused_symbols_in_one_call(fake_megatron): +def test_patch_replaces_both_fused_and_unfused_symbols_in_one_call(fake_megatron_ce): """A single ``cross_entropy=True`` call must replace both Megatron CE paths.""" - fused_ce, unfused_ce = fake_megatron + fused_ce, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -138,9 +269,9 @@ def test_patch_replaces_both_fused_and_unfused_symbols_in_one_call(fake_megatron assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" -def test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched(fake_megatron): +def test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched(fake_megatron_ce): """Default ``cross_entropy=False`` must not touch the CE symbols even if the call runs.""" - fused_ce, unfused_ce = fake_megatron + fused_ce, 
unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron fused_before = fused_ce.fused_vocab_parallel_cross_entropy @@ -152,10 +283,10 @@ def test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched(fake_megatro assert unfused_ce.vocab_parallel_cross_entropy is unfused_before -def test_patch_is_idempotent_for_both_symbols(fake_megatron): +def test_patch_is_idempotent_for_both_symbols(fake_megatron_ce): """Calling ``apply_liger_kernel_to_megatron(cross_entropy=True)`` twice must not stack wrappers — the sentinel attribute guards against double-patching.""" - fused_ce, unfused_ce = fake_megatron + fused_ce, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -171,14 +302,14 @@ def test_patch_is_idempotent_for_both_symbols(fake_megatron): assert unfused_first.__wrapped__.__name__ == "original_vocab_parallel_cross_entropy" -def test_patch_fused_wrapper_passes_tp_group_through(fake_megatron): +def test_patch_fused_wrapper_passes_tp_group_through(fake_megatron_ce): """The fused wrapper closure must forward ``tp_group`` to the underlying class. We swap the CE class for a recording fake so the call doesn't need CUDA — just confirms ``tp_group`` reaches the class's ``__call__``.""" import torch - fused_ce, _ = fake_megatron + fused_ce, _ = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod @@ -209,12 +340,12 @@ def size(self): # --------------------------------------------------------------------------- -# TP-1 guard. +# 2.2 CE TP-1 guard. 
# --------------------------------------------------------------------------- def test_patch_raises_on_tp_greater_than_one(): - _install_fake_megatron(tp_size=2) + _install_fake_megatron_ce(tp_size=2) try: from liger_kernel.megatron import apply_liger_kernel_to_megatron @@ -226,7 +357,7 @@ def test_patch_raises_on_tp_greater_than_one(): def test_patch_defers_tp_check_when_parallel_state_not_initialized(): """If get_tensor_model_parallel_world_size() raises, patch should still succeed.""" - fused_ce, unfused_ce = _install_fake_megatron(tp_size=1) + fused_ce, unfused_ce = _install_fake_megatron_ce(tp_size=1) def raising_tp_size(): raise AssertionError("parallel_state not initialized") @@ -244,7 +375,7 @@ def raising_tp_size(): # --------------------------------------------------------------------------- -# Missing-megatron / missing-symbol errors. +# 2.3 CE missing-megatron / missing-symbol errors. # --------------------------------------------------------------------------- @@ -265,7 +396,7 @@ def blocking_import(name, *args, **kwargs): def test_patch_raises_when_fused_symbol_missing(): - _install_fake_megatron(tp_size=1, with_fused_symbol=False) + _install_fake_megatron_ce(tp_size=1, with_fused_symbol=False) try: from liger_kernel.megatron import apply_liger_kernel_to_megatron @@ -277,7 +408,7 @@ def test_patch_raises_when_fused_symbol_missing(): def test_patch_raises_when_unfused_symbol_missing(): """Symmetric to the fused-missing case; the unfused module exists but its symbol doesn't.""" - _install_fake_megatron(tp_size=1, with_unfused_symbol=False) + _install_fake_megatron_ce(tp_size=1, with_unfused_symbol=False) try: from liger_kernel.megatron import apply_liger_kernel_to_megatron @@ -288,11 +419,11 @@ def test_patch_raises_when_unfused_symbol_missing(): # --------------------------------------------------------------------------- -# Class-default construction + runtime label_smoothing override on the unfused path. 
+# 2.4 CE class-default construction + runtime label_smoothing override on the unfused path. # --------------------------------------------------------------------------- -def test_patch_constructs_ce_with_class_defaults(fake_megatron): +def test_patch_constructs_ce_with_class_defaults(fake_megatron_ce): """The public ``apply_liger_kernel_to_megatron`` API exposes no CE-specific kwargs; the patch must therefore construct ``LigerMegatronCrossEntropy`` with class defaults. @@ -322,7 +453,7 @@ def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none assert entry == {"ignore_index": -100, "label_smoothing": 0.0, "reduction": "none"} -def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron): +def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron_ce): """The unfused signature takes ``label_smoothing`` as a runtime arg; the wrapper must honor it. When the caller passes a non-default value, the wrapper constructs a fresh @@ -334,7 +465,7 @@ def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron): """ import torch - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod @@ -366,12 +497,12 @@ def __call__(self, logits, target, tp_group=None): ) -def test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing(fake_megatron): +def test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing(fake_megatron_ce): """When the caller doesn't pass ``label_smoothing``, the wrapper reuses the patch-time ``default_ce`` instance — no fresh allocation per call.""" import torch - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod @@ -399,7 +530,7 @@ def __call__(self, logits, target, tp_group=None): ) -def 
test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron): +def test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron_ce): """Explicit ``label_smoothing=0.0`` at call time must be honored verbatim, not silently replaced by the patch-time default. @@ -409,7 +540,7 @@ def test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron): positionally.""" import torch - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron from liger_kernel.megatron import cross_entropy as ce_mod @@ -437,12 +568,210 @@ def __call__(self, logits, target, tp_group=None): ) +# =========================================================================== +# 3. RMSNorm patch tests +# =========================================================================== +# Liger's RMSNorm patch displaces two Megatron symbols: +# +# - ``megatron.core.models.backends.LocalSpecProvider.layer_norm`` (a method) +# fills the per-layer norm slots inside ``TransformerLayerSubmodules``. +# - ``megatron.core.transformer.transformer_block.LayerNormImpl`` (a class) +# fills the block-level ``final_layernorm`` slot when the caller passes a +# per-layer spec rather than a ``TransformerBlockSubmodules``. +# +# The block-level patch only displaces the pure-torch ``WrappedTorchNorm`` +# fallback — users on TE / Apex chose those deliberately and Liger should not +# undo their fusions. Tests below cover both replacement targets, dispatch +# behavior, idempotency, and the "skip on non-WrappedTorchNorm" contract. + + # --------------------------------------------------------------------------- -# Public-API surface checks (mirrors transformers-side ``test_import_from_root`` -# and ``test_apply_liger_kernel_only_passes_valid_kwargs`` patterns). +# 3.1 Both RMSNorm symbols get replaced. 
# --------------------------------------------------------------------------- +def test_rms_norm_patch_replaces_local_spec_provider_layer_norm(fake_megatron_rms_norm): + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = backends.LocalSpecProvider.layer_norm + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm is not original + assert backends.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + # Marker + __wrapped__ chain back to the original method so future un-patching + # or introspection can find it. + assert getattr(backends.LocalSpecProvider.layer_norm, "__liger_patched__", False) is True + assert backends.LocalSpecProvider.layer_norm.__wrapped__ is original + + +def test_rms_norm_patch_replaces_transformer_block_layernorm_impl(fake_megatron_rms_norm): + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = transformer_block.LayerNormImpl + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert transformer_block.LayerNormImpl is not original + assert transformer_block.LayerNormImpl.__name__ == "_LigerOrTorchNorm" + assert getattr(transformer_block.LayerNormImpl, "__liger_patched__", False) is True + assert transformer_block.LayerNormImpl.__wrapped__ is original + + +def test_rms_norm_patch_replaces_both_symbols_in_one_call(fake_megatron_rms_norm): + """A single ``rms_norm=True`` call must replace both Megatron RMSNorm paths.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + assert transformer_block.LayerNormImpl.__name__ == "_LigerOrTorchNorm" + + +def 
test_rms_norm_patch_with_rms_norm_false_leaves_norm_symbols_untouched(fake_megatron_rms_norm): + """Default ``rms_norm=False`` must not touch the norm symbols even if the call runs.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + lsp_before = backends.LocalSpecProvider.layer_norm + impl_before = transformer_block.LayerNormImpl + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm is lsp_before + assert transformer_block.LayerNormImpl is impl_before + + +def test_rms_norm_patch_is_idempotent(fake_megatron_rms_norm): + """Calling ``apply_liger_kernel_to_megatron(rms_norm=True)`` twice must not stack + wrappers — the sentinel attribute on each patched target guards against double-wrap.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + lsp_first = backends.LocalSpecProvider.layer_norm + impl_first = transformer_block.LayerNormImpl + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + assert backends.LocalSpecProvider.layer_norm is lsp_first + assert transformer_block.LayerNormImpl is impl_first + # __wrapped__ still references the original Megatron symbols, not the first Liger + # wrapper (otherwise the second apply would have chained over the first). + assert lsp_first.__wrapped__.__qualname__.endswith("LocalSpecProvider.layer_norm") + assert impl_first.__wrapped__ is transformer_block._WrappedTorchNorm + + +# --------------------------------------------------------------------------- +# 3.2 RMSNorm dispatch behavior through the patched targets. 
+# --------------------------------------------------------------------------- + + +def test_rms_norm_patch_local_spec_provider_returns_liger_for_rms_norm_true(fake_megatron_rms_norm): + """When the patched ``layer_norm`` method is called with ``rms_norm=True``, it returns + ``LigerMegatronRMSNorm`` — the actual class binding callers use to construct the per-layer + norms inside ``TransformerLayerSubmodules``.""" + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import LigerMegatronRMSNorm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + provider = backends.LocalSpecProvider() + result = provider.layer_norm(rms_norm=True) + assert result is LigerMegatronRMSNorm + + +def test_rms_norm_patch_local_spec_provider_delegates_for_rms_norm_false(fake_megatron_rms_norm): + """When ``rms_norm=False``, the patched method must delegate to the original method — + Liger never touches LayerNorm users.""" + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + provider = backends.LocalSpecProvider() + result = provider.layer_norm(rms_norm=False) + # The stub's original layer_norm returns its _OriginalNormSentinel. + assert result is backends._OriginalNormSentinel + + +def test_rms_norm_patch_transformer_block_routes_rmsnorm_through_liger(fake_megatron_rms_norm): + """The ``_LigerOrTorchNorm`` wrapping class dispatches on ``config.normalization`` — + when it's ``"RMSNorm"``, instantiation returns a ``LigerMegatronRMSNorm`` instance. 
+ + Construction is enough — no kernel is actually invoked, so this runs on CPU.""" + from types import SimpleNamespace + + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import LigerMegatronRMSNorm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + config = SimpleNamespace( + normalization="RMSNorm", + sequence_parallel=False, + layernorm_zero_centered_gamma=False, + ) + instance = transformer_block.LayerNormImpl(config=config, hidden_size=64, eps=1e-5) + assert isinstance(instance, LigerMegatronRMSNorm) + + +def test_rms_norm_patch_transformer_block_routes_layernorm_through_original(fake_megatron_rms_norm): + """When ``config.normalization`` is not ``"RMSNorm"`` (e.g. ``"LayerNorm"``), the + wrapping class falls back to the original ``WrappedTorchNorm`` so LayerNorm users + keep their existing behavior.""" + from types import SimpleNamespace + + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + config = SimpleNamespace(normalization="LayerNorm") + result = transformer_block.LayerNormImpl(config=config, hidden_size=64, eps=1e-5) + # The stub's WrappedTorchNorm.__new__ returns a sentinel instance. + assert isinstance(result, transformer_block._OriginalTorchNormInstance) + assert result.hidden_size == 64 + + +# --------------------------------------------------------------------------- +# 3.3 Block-level "skip when not WrappedTorchNorm" contract. +# --------------------------------------------------------------------------- + + +def test_rms_norm_patch_transformer_block_skips_when_layer_norm_is_not_wrapped_torch_norm(): + """If ``transformer_block.LayerNormImpl`` is TE / Apex / anything other than the pure-torch + ``WrappedTorchNorm`` fallback, the block-level patch must be a no-op. 
Replacing TE's + fused LN+Linear with Liger's standalone RMSNorm would double-norm or skip the norm + entirely; replacing Apex would surprise users who chose it deliberately.""" + _, transformer_block = _install_fake_megatron_rms_norm( + layer_norm_is_wrapped_torch_norm=False, + ) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + impl_before = transformer_block.LayerNormImpl + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + # Unchanged — patch detected a non-WrappedTorchNorm value and bailed. + assert transformer_block.LayerNormImpl is impl_before + # And the spec-provider patch above DID still apply (it's independent of the block + # path), so we'd see the provider patched even though the block was skipped. + # The installer's ``backends`` return value was discarded above — re-fetch from sys.modules. + backends_mod = sys.modules["megatron.core.models.backends"] + assert backends_mod.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + finally: + _uninstall_fake_megatron() + + +# =========================================================================== +# 4. Cross-kernel public-API surface checks +# =========================================================================== +# Mirrors transformers-side ``test_import_from_root`` and +# ``test_apply_liger_kernel_only_passes_valid_kwargs`` patterns. + + def test_import_from_root(): """All public Megatron symbols must be reachable from ``liger_kernel.megatron``. @@ -472,15 +801,14 @@ def test_public_apply_function_has_no_ce_specific_kwargs(): ) -# --------------------------------------------------------------------------- -# End-to-end integration through the patched symbols. -# +# =========================================================================== +# 5. 
End-to-end integration through the patched CE symbols +# =========================================================================== # Earlier tests verify symbol identity + stub plumbing; the suite was missing # the "patch + call with real tensors + check the numbers" coverage. These # tests install the fake megatron, apply the patch, then invoke the resulting # wrapper with live torch tensors and compare against ``F.cross_entropy``. # That's the only way to catch wrapper-math bugs that pass the identity tests. -# --------------------------------------------------------------------------- import torch # noqa: E402 (deferred so the no-torch import-smoke tests above are unaffected) @@ -506,11 +834,11 @@ def _ref_loss_sbv(logits_sbv: torch.Tensor, target_sb: torch.Tensor, return loss_flat.reshape(s, b) -def test_patched_fused_symbol_computes_correct_loss(fake_megatron): +def test_patched_fused_symbol_computes_correct_loss(fake_megatron_ce): """End-to-end: install stub megatron, patch, invoke the resulting fused symbol with real [s, b, v] logits, verify the loss matches ``F.cross_entropy``. 
Closes the gap between "patch wired correctly" (existing tests) and "patched function does the right math".""" - fused_ce, _ = fake_megatron + fused_ce, _ = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -529,10 +857,10 @@ def test_patched_fused_symbol_computes_correct_loss(fake_megatron): assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) -def test_patched_unfused_symbol_computes_correct_loss(fake_megatron): +def test_patched_unfused_symbol_computes_correct_loss(fake_megatron_ce): """Same as the fused case, but through the unfused symbol — verifies both wrappers are exercised and exercises the no-label_smoothing default branch (caller doesn't pass).""" - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -552,11 +880,11 @@ def test_patched_unfused_symbol_computes_correct_loss(fake_megatron): @pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) -def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_megatron, label_smoothing): +def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_megatron_ce, label_smoothing): """The unfused wrapper's main feature beyond the fused path is honoring a runtime label_smoothing arg. 
Verify the resulting loss actually matches ``F.cross_entropy(..., label_smoothing=...)``, not just that a fresh CE instance is built.""" - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -573,12 +901,12 @@ def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_meg assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) -def test_patched_fused_symbol_preserves_gradients(fake_megatron): +def test_patched_fused_symbol_preserves_gradients(fake_megatron_ce): """Backward through the patched fused symbol: gradient shape + parity vs. PyTorch's reference. Liger writes the gradient back into the input buffer, so verifying ``.grad`` after backward exercises both the reshape contract and the in-place write.""" - fused_ce, _ = fake_megatron + fused_ce, _ = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -601,10 +929,10 @@ def test_patched_fused_symbol_preserves_gradients(fake_megatron): assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) -def test_patched_unfused_symbol_preserves_gradients(fake_megatron): +def test_patched_unfused_symbol_preserves_gradients(fake_megatron_ce): """Symmetric to the fused-gradient test; ensures the closure in ``_patch_vocab_parallel_cross_entropy`` doesn't break autograd.""" - _, unfused_ce = fake_megatron + _, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -624,14 +952,14 @@ def test_patched_unfused_symbol_preserves_gradients(fake_megatron): assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) -def test_patched_fused_symbol_default_ignore_index_minus_100(fake_megatron): +def 
test_patched_fused_symbol_default_ignore_index_minus_100(fake_megatron_ce): """Patch-time defaults: targets containing -100 should be treated as ignored — Liger's kernel zeros those loss positions, matching ``F.cross_entropy(ignore_index=-100)``. Native Megatron's fused CE has no ignore_index concept and would silently produce garbage on -100; this is one place where Liger is strictly better than the symbol it replaces, and the test pins that behavior.""" - fused_ce, _ = fake_megatron + fused_ce, _ = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) @@ -653,14 +981,14 @@ def test_patched_fused_symbol_default_ignore_index_minus_100(fake_megatron): assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) -def test_rms_norm_only_patch_does_not_touch_ce_symbols(fake_megatron): +def test_rms_norm_only_patch_does_not_touch_ce_symbols(fake_megatron_ce): """Symmetric to ``test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched``, but for the opposite split. With ``rms_norm=True, cross_entropy=False`` (RMSNorm helpers require real megatron and will ImportError on the stub — that's fine, we only need to confirm the CE symbols are not pre-emptively touched before the RMSNorm helpers run). 
Documenting this protects against future apply_… reorderings that would silently couple the two.""" - fused_ce, unfused_ce = fake_megatron + fused_ce, unfused_ce = fake_megatron_ce from liger_kernel.megatron import apply_liger_kernel_to_megatron fused_before = fused_ce.fused_vocab_parallel_cross_entropy From 8c74f5c4df6116f5ad8c3eeb75a109b84e672698 Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 07:46:32 +0000 Subject: [PATCH 6/7] Megatron CE benchmark: use shared CSV + standard visualizer; drop in-script plotter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two anti-patterns in the benchmark script unwound: * Wrote to its own benchmark/data/all_benchmark_data_megatron.csv via a runtime monkey-patch of update_benchmark_data_csv. Every other Liger benchmark writes to the shared all_benchmark_data.csv and the rows are partitioned by the kernel_name column — that's what the visualizer keys off. The redirect was self-inflicted special-casing. * Embedded a ~100-line _generate_plots function that re-implemented benchmarks_visualizer.py's job, only differently enough to be a maintenance liability (log-2 x-axis, fill_between band, hard-coded title). It existed because the visualizer couldn't find our redirected CSV. Once the CSV moves back to the shared file, benchmarks_visualizer.py reads our rows out of the box. benchmark/scripts/benchmark_megatron_cross_entropy.py - Drop _CSV_FILENAME constant + _patched_update_benchmark_data_csv + the benchmark_utils.update_benchmark_data_csv reassignment. Default write path lands on benchmark/data/all_benchmark_data.csv. - Drop the matplotlib import dance and the entire _generate_plots function (and its invocation in __main__). - Header docstring updated: replaces the "Output → PNG" promise with the standard 2-step workflow command pointing at benchmarks_visualizer.
benchmark/data/all_benchmark_data.csv Migrate the 96 existing H100 megatron rows in from the orphaned per-component CSV (schemas matched verbatim — header line dropped, rows appended). Preserves the measurements without requiring a GPU re-run. benchmark/data/all_benchmark_data_megatron.csv Deleted — its data lives in the shared CSV now. Workflow is now the same as every other Liger kernel benchmark: python benchmark/scripts/benchmark_megatron_cross_entropy.py --overwrite python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_cross_entropy --metric-name speed python benchmark/benchmarks_visualizer.py \ --kernel-name megatron_cross_entropy --metric-name memory PNG outputs land in benchmark/visualizations/ (gitignored, as before). Net: benchmark script shrinks from 348 → 207 lines; one tracked CSV file consolidated away; visualizer slices megatron rows out of the shared CSV with no changes on its side. Co-Authored-By: Claude Opus 4.7 (1M context) --- benchmark/data/all_benchmark_data.csv | 96 +++++++++++ .../data/all_benchmark_data_megatron.csv | 97 ----------- .../benchmark_megatron_cross_entropy.py | 156 +----------------- 3 files changed, 104 insertions(+), 245 deletions(-) delete mode 100644 benchmark/data/all_benchmark_data_megatron.csv diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 7b9b0c1b0..a578d2622 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -2219,3 +2219,99 @@ jsd,liger,full,memory,MB,BT,total tokens,1024,3514.0009765625,3514.0009765625,35 jsd,liger,full,memory,MB,BT,total tokens,2048,7014.0009765625,7014.0009765625,7014.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 jsd,liger,full,memory,MB,BT,total tokens,4096,14028.0009765625,14028.0009765625,14028.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 
jsd,liger,full,memory,MB,BT,total tokens,8192,28056.0,28056.0,28056.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,4096,0.2385600060224533,0.23590399324893951,0.24048000574111938,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,8192,0.26820799708366394,0.26531198620796204,0.27125120162963867,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,16384,0.3599199950695038,0.3569599986076355,0.36266239881515505,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,32768,0.6502079963684082,0.6452223896980286,0.6563008189201356,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,65536,1.087440013885498,1.0827391862869262,1.0942399978637696,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,131072,2.1112000942230225,2.108524799346924,2.1132928848266603,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,4096,0.2269120067358017,0.225913605093956,0.22775039970874789,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,8192,0.48070400953292847,0.47954559326171875,0.48152959942817686,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,16384,0.8124480247497559,0.8111680150032043,0.813696026802063,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab 
size,32768,2.1381120681762695,2.1316160202026366,2.142361545562744,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,65536,4.391056060791016,4.389933013916016,4.3939519882202145,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,131072,8.81107234954834,8.806303977966309,8.81488037109375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,4096,0.31273600459098816,0.31200000643730164,0.31331199407577515,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,8192,0.6373440027236938,0.636352002620697,0.6381440162658691,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,16384,1.3848639726638794,1.3825664043426515,1.3870784282684325,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,32768,2.829935908317566,2.8288448333740233,2.8316224098205565,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,65536,5.823584079742432,5.8137922286987305,5.826655864715576,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,131072,11.810944080352783,11.80013427734375,11.817446517944337,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,4096,0.5649279952049255,0.5639231920242309,0.5661375880241394,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab 
size,8192,1.094543993473053,1.093395209312439,1.0956799983978271,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,16384,2.0701760053634644,2.068671941757202,2.072096109390259,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,32768,4.027215957641602,4.024915313720703,4.028927993774413,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,65536,7.93833589553833,7.937132930755616,7.940275096893311,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,131072,15.817471981048584,15.814432144165039,15.821215629577637,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,4096,0.3516159951686859,0.3489919900894165,0.35476480722427367,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,8192,0.4639679938554764,0.46081281304359434,0.46724479794502255,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,16384,0.7123200297355652,0.7092544078826905,0.7154880046844483,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,32768,1.322975993156433,1.3184319734573364,1.328320026397705,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,65536,2.3985120058059692,2.393356847763062,2.405248022079468,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 
+megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,131072,4.698495864868164,4.69042558670044,4.70184965133667,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,4096,0.48710399866104126,0.4862975895404816,0.48793599009513855,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,8192,0.9741439819335938,0.9731391787528991,0.9752448081970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,16384,1.8759199976921082,1.8737216234207152,1.8779136180877687,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,32768,4.405440092086792,4.398931121826172,4.407916831970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,65536,8.920191764831543,8.916671752929688,8.923616409301758,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,131072,17.845855712890625,17.843469619750977,17.847046661376954,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,4096,0.5368639826774597,0.536191999912262,0.5374848127365113,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,8192,1.0521279573440552,1.05141122341156,1.0528064489364624,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,16384,2.1722079515457153,2.169472026824951,2.1735936641693114,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 
+megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,32768,4.448447942733765,4.447282981872559,4.449868965148926,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,65536,9.033520221710205,9.030751991271973,9.042854690551758,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,131072,17.958431243896484,17.9557315826416,17.960690689086913,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,4096,0.7786880135536194,0.7767040133476257,0.7803391933441162,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,8192,1.475823998451233,1.4742015838623046,1.4774847745895388,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,16384,2.787328004837036,2.785683250427246,2.7883071899414062,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,32768,5.420928001403809,5.419379329681396,5.423315238952636,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,65536,10.677472114562988,10.674201774597169,10.67763843536377,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,131072,21.254928588867188,21.25368995666504,21.2557315826416,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,4096,0.41488000750541687,0.4122175931930542,0.41767039299011227,"{""S"": 2048, ""B"": 4}",NVIDIA H100 
80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,8192,0.5923520028591156,0.5882560014724731,0.6005120277404785,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,16384,0.9725440144538879,0.9690751910209655,0.9779008150100708,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,32768,1.8396799564361572,1.8349119424819946,1.8454079627990723,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,65536,3.4398880004882812,3.429087924957275,3.4599680423736574,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,131072,6.780799865722656,6.763827323913574,6.796000003814697,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,4096,0.5540960133075714,0.5528640151023865,0.555296003818512,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,8192,1.1072640419006348,1.1059776306152345,1.1086400032043457,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,16384,2.139024019241333,2.1362879276275635,2.142067241668701,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,32768,4.9202880859375,4.91553258895874,4.924480152130127,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,65536,9.961983680725098,9.958450889587402,9.964262008666992,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 
+megatron_cross_entropy,torch,full,speed,ms,V,vocab size,131072,19.926143646240234,19.913933563232423,19.93060417175293,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,4096,0.6057599782943726,0.604960024356842,0.6063359975814819,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,8192,1.1868000030517578,1.1857984066009521,1.1880639791488647,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,16384,2.4351680278778076,2.4330944538116457,2.4368255615234373,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,32768,4.968512058258057,4.967942237854004,4.969702434539795,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,65536,10.074239730834961,10.06367359161377,10.077676963806152,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,131072,20.022079467773438,20.015307998657228,20.02815971374512,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,4096,0.846015989780426,0.8447743892669677,0.8472639918327332,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,8192,1.6103359460830688,1.6088960409164428,1.6117696285247802,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,16384,3.0503358840942383,3.049056053161621,3.0524160861968994,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 
+megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,32768,5.937999963760376,5.936895847320557,5.939583778381348,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,65536,11.711487770080566,11.710764503479005,11.713843536376952,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,131072,23.322416305541992,23.321644973754882,23.32369956970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,4096,192.0791015625,192.0791015625,192.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,8192,384.0791015625,384.0791015625,384.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,16384,768.0791015625,768.0791015625,768.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,32768,1536.0791015625,1536.0791015625,1536.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,65536,3072.0791015625,3072.0791015625,3072.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,131072,6144.0791015625,6144.0791015625,6144.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,4096,512.0947265625,512.0947265625,512.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab 
size,8192,1024.0947265625,1024.0947265625,1024.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,16384,2048.0947265625,2048.0947265625,2048.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,32768,4096.0947265625,4096.0947265625,4096.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,65536,8192.0947265625,8192.0947265625,8192.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,131072,16384.09375,16384.09375,16384.09375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,4096,448.1650390625,448.1650390625,448.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,8192,896.1650390625,896.1650390625,896.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,16384,1792.1650390625,1792.1650390625,1792.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,32768,3584.1650390625,3584.1650390625,3584.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,65536,7168.1650390625,7168.1650390625,7168.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,131072,14336.1650390625,14336.1650390625,14336.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 
+megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,4096,320.1650390625,320.1650390625,320.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,8192,640.1650390625,640.1650390625,640.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1650390625,1280.1650390625,1280.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 diff --git a/benchmark/data/all_benchmark_data_megatron.csv b/benchmark/data/all_benchmark_data_megatron.csv deleted file mode 100644 index 496f0f904..000000000 --- a/benchmark/data/all_benchmark_data_megatron.csv +++ /dev/null @@ -1,97 +0,0 @@ -kernel_name,kernel_provider,kernel_operation_mode,metric_name,metric_unit,x_name,x_label,x_value,y_value_50,y_value_20,y_value_80,extra_benchmark_config_str,gpu_name,timestamp,liger_version -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,4096,0.23187200725078583,0.22867199778556824,0.23430399596691132,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,8192,0.26259198784828186,0.26010880470275877,0.26554239392280576,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 
00:54:19,0.8.0 -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,16384,0.35471999645233154,0.3516608059406281,0.3578752100467682,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,32768,0.6460640132427216,0.6421120166778564,0.6500800251960754,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,65536,1.0836479663848877,1.0789376258850099,1.0899327754974366,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 -megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,131072,2.107487916946411,2.10230393409729,2.1125311851501465,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:19,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,4096,0.22684800624847412,0.22604799270629883,0.22790400683879852,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,8192,0.48054400086402893,0.4792320132255554,0.48161279559135434,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,16384,0.8116160035133362,0.8104000091552734,0.8132799863815308,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,32768,2.1355679035186768,2.1311680316925052,2.1407871723175047,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,65536,4.391728162765503,4.389222240447998,4.394495964050293,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 -megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,131072,8.81164836883545,8.805439949035645,8.816160202026367,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:20,0.8.0 
-megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,4096,0.3126560002565384,0.312063992023468,0.3132735908031463,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,8192,0.6373920142650604,0.6365439891815186,0.6382399797439575,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,16384,1.3829439878463745,1.3810559511184692,1.3850752115249634,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,32768,2.829360008239746,2.8282368659973143,2.831148767471314,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,65536,5.815711975097656,5.807616233825684,5.824639797210693,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,131072,11.804255962371826,11.790649795532227,11.81064338684082,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:22,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,4096,0.5651199817657471,0.5638527750968934,0.5661119818687439,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,8192,1.095136046409607,1.0940415859222412,1.0963200330734253,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,16384,2.069024085998535,2.0674240589141846,2.0703680515289307,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,32768,4.026576042175293,4.025190353393555,4.027974319458008,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,65536,7.939584016799927,7.937094402313232,7.94121618270874,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,131072,15.816815853118896,15.814399719238281,15.81926441192627,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:23,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,4096,0.3551519960165024,0.3523648023605347,0.3593983888626099,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,8192,0.466608002781868,0.46348801255226135,0.470521605014801,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,16384,0.7164639830589294,0.7130944013595581,0.7210240125656128,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,32768,1.326367974281311,1.3212159872055054,1.3306879997253418,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,65536,2.4017759561538696,2.392825651168823,2.412748908996582,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,131072,4.6971518993377686,4.690547180175781,4.702419185638428,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:24,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,4096,0.4853439927101135,0.48447999358177185,0.4862591981887817,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,8192,0.975488007068634,0.9742976188659669,0.9764159917831421,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB 
HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,16384,1.8761919736862183,1.8743679523468018,1.8787519931793213,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,32768,4.398000001907349,4.3941184997558596,4.408435344696045,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,65536,8.924351692199707,8.919648170471191,8.92563247680664,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,131072,17.839040756225586,17.836934661865236,17.843289947509767,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:25,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,4096,0.5374079942703247,0.5366335868835449,0.5382080078125,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,8192,1.0536960363388062,1.0528128385543822,1.0545535564422608,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,16384,2.1704800128936768,2.1692031383514405,2.172160005569458,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,32768,4.446431875228882,4.445299243927002,4.4482752799987795,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,65536,9.035776138305664,9.031341171264648,9.042457962036133,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:26,0.8.0 -megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,131072,17.946975708007812,17.945389556884766,17.95260887145996,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 
00:54:26,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,4096,0.7788800001144409,0.7774208068847656,0.780454421043396,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,8192,1.4757919907569885,1.4740799903869628,1.4768768072128298,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,16384,2.786736011505127,2.7851327419281007,2.7889408111572265,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,32768,5.420752048492432,5.417894268035889,5.422150421142579,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,65536,10.675616264343262,10.674079704284669,10.679340934753418,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,131072,21.251296043395996,21.248454666137697,21.253714752197265,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:27,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,4096,0.4208959937095642,0.41683200001716614,0.4256959855556488,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,8192,0.5997599959373474,0.5965440273284912,0.6039680242538452,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,16384,0.9797760248184204,0.9747712016105652,0.9856639981269837,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,32768,1.8492159843444824,1.8442879915237427,1.85481595993042,"{""S"": 2048, ""B"": 4}",NVIDIA H100 
80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,65536,3.455839991569519,3.4454591274261475,3.471328020095825,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,liger,full,speed,ms,V,vocab size,131072,6.783936023712158,6.772691154479981,6.791551971435547,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,4096,0.5539519786834717,0.5527359843254089,0.5549759864807129,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,8192,1.106719970703125,1.1052928209304809,1.1081023931503298,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,16384,2.137727975845337,2.1360192775726317,2.139878463745117,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,32768,4.921311855316162,4.918290996551514,4.929247951507568,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,65536,9.962112426757812,9.959987258911132,9.963628578186036,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,torch,full,speed,ms,V,vocab size,131072,19.926719665527344,19.91397705078125,19.931795120239258,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:28,0.8.0 -megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,4096,0.6056640148162842,0.6047423720359802,0.606719970703125,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 -megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,8192,1.1874719858169556,1.1862783908843995,1.188646388053894,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 
-megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,16384,2.434976100921631,2.4332736015319822,2.4367167949676514,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 -megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,32768,4.970719814300537,4.968512058258057,4.9720383644104,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 -megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,65536,10.077088356018066,10.070701026916502,10.083423614501953,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 -megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,131072,20.020895957946777,20.018156433105467,20.022406387329102,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:29,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,4096,0.8470079898834229,0.8454336166381836,0.8486016154289245,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,8192,1.6092480421066284,1.6075520515441895,1.6106047868728637,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,16384,3.04915189743042,3.0480000972747803,3.0506880283355713,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,32768,5.9349119663238525,5.934432029724121,5.936927795410156,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,65536,11.71017599105835,11.70949764251709,11.711187362670898,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:30,0.8.0 -megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,131072,23.321871757507324,23.319308853149415,23.324396514892577,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 
00:54:30,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,4096,192.0791015625,192.0791015625,192.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,8192,384.0791015625,384.0791015625,384.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,16384,768.0791015625,768.0791015625,768.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,32768,1536.0791015625,1536.0791015625,1536.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,65536,3072.0791015625,3072.0791015625,3072.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,liger,full,memory,MB,V,vocab size,131072,6144.0791015625,6144.0791015625,6144.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,4096,512.0947265625,512.0947265625,512.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,8192,1024.0947265625,1024.0947265625,1024.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,16384,2048.0947265625,2048.0947265625,2048.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,32768,4096.0947265625,4096.0947265625,4096.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,65536,8192.0947265625,8192.0947265625,8192.0947265625,"{""S"": 2048, ""B"": 
4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,torch,full,memory,MB,V,vocab size,131072,16384.09375,16384.09375,16384.09375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,4096,448.1650390625,448.1650390625,448.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,8192,896.1650390625,896.1650390625,896.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,16384,1792.1650390625,1792.1650390625,1792.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,32768,3584.1650390625,3584.1650390625,3584.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,65536,7168.1650390625,7168.1650390625,7168.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,131072,14336.1650390625,14336.1650390625,14336.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,4096,320.1650390625,320.1650390625,320.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,8192,640.1650390625,640.1650390625,640.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1650390625,1280.1650390625,1280.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 
-megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 -megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 00:54:31,0.8.0 diff --git a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py index afa502c5f..883e1b802 100644 --- a/benchmark/scripts/benchmark_megatron_cross_entropy.py +++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py @@ -15,22 +15,22 @@ Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not installed, both megatron providers are silently skipped. -Output: - - CSV: ``benchmark/data/all_benchmark_data_megatron.csv`` - (separate per-component file, mirroring the recent ``all_benchmark_data_cutile.csv`` - precedent — keeps the PR diff scannable) - - Plots (best-effort): ``benchmark/visualizations/megatron_cross_entropy_*.png`` - rendered when matplotlib is available; skipped silently otherwise. 
+Output goes to the shared ``benchmark/data/all_benchmark_data.csv`` like every +other Liger benchmark — rows are tagged with ``kernel_name="megatron_cross_entropy"`` +and the standard visualizer renders them via: + + python benchmark/benchmarks_visualizer.py \\ + --kernel-name megatron_cross_entropy --metric-name speed + python benchmark/benchmarks_visualizer.py \\ + --kernel-name megatron_cross_entropy --metric-name memory """ -import functools import os import torch import torch.nn.functional as F import triton -import utils as benchmark_utils from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -38,24 +38,6 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -_CSV_FILENAME = "all_benchmark_data_megatron.csv" - -# Redirect CSV output to a per-component file (parallel to all_benchmark_data_cutile.csv). -# We can't reuse the LIGER_KERNEL_IMPL knob because it also drives kernel-backend -# selection in liger_kernel.ops — overloading it would force us onto a megatron-named -# backend that doesn't exist. Patching update_benchmark_data_csv is surgical and only -# affects this benchmark process. 
-_original_update_benchmark_data_csv = benchmark_utils.update_benchmark_data_csv - - -@functools.wraps(_original_update_benchmark_data_csv) -def _patched_update_benchmark_data_csv(*args, **kwargs): - kwargs["filename"] = _CSV_FILENAME - return _original_update_benchmark_data_csv(*args, **kwargs) - - -benchmark_utils.update_benchmark_data_csv = _patched_update_benchmark_data_csv - from liger_kernel.megatron import LigerMegatronCrossEntropy from liger_kernel.utils import infer_device @@ -71,17 +53,6 @@ def _patched_update_benchmark_data_csv(*args, **kwargs): vocab_parallel_cross_entropy = None _MEGATRON_AVAILABLE = False -try: - import matplotlib - - matplotlib.use("Agg") # headless - import matplotlib.pyplot as plt # noqa: E402 - - _HAVE_MATPLOTLIB = True -except ImportError: - plt = None - _HAVE_MATPLOTLIB = False - def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True): logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad) @@ -204,115 +175,6 @@ def full(): return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) -# --------------------------------------------------------------------------- -# Plot generation (best-effort). -# --------------------------------------------------------------------------- - - -def _generate_plots(out_dir: str) -> None: - """Generate one PNG per (metric, mode) combination from the CSV we just wrote. - - Silently skipped when matplotlib is unavailable. Reads the CSV rather than - re-running benchmarks so the plots use the same numbers that landed on disk. 
- """ - if not _HAVE_MATPLOTLIB: - print("[plots] matplotlib not available; skipping plot generation.") - return - - import csv - from pathlib import Path - - csv_path = Path(os.path.join(os.path.dirname(__file__), "..", "data", "all_benchmark_data_megatron.csv")) - if not csv_path.exists(): - print(f"[plots] CSV not found at {csv_path}; skipping plot generation.") - return - - out_dir_path = Path(out_dir) - out_dir_path.mkdir(parents=True, exist_ok=True) - - # CSV layout is denormalized: one row per (provider, mode, metric, x_value). - # We need to aggregate rows back into (provider → x_values, y50_list, y20_list, y80_list) - # series before plotting. - rows = [] - with open(csv_path, "r") as f: - reader = csv.DictReader(f) - for row in reader: - if row.get("kernel_name") == "megatron_cross_entropy": - rows.append(row) - - if not rows: - print("[plots] no megatron_cross_entropy rows in the CSV; skipping plots.") - return - - series: dict = {} # (metric_name, mode, provider) → {xs, y50, y20, y80, meta} - for row in rows: - key = (row["metric_name"], row["kernel_operation_mode"], row["kernel_provider"]) - entry = series.setdefault( - key, - { - "xs": [], "y50": [], "y20": [], "y80": [], - "x_label": row.get("x_label", "x"), - "metric_unit": row.get("metric_unit", ""), - }, - ) - try: - entry["xs"].append(float(row["x_value"])) - entry["y50"].append(float(row["y_value_50"])) - entry["y20"].append(float(row["y_value_20"])) - entry["y80"].append(float(row["y_value_80"])) - except (KeyError, ValueError): - # If the CSV schema differs from what we expect, surface it loudly but skip - # plotting rather than crashing the whole benchmark. - print(f"[plots] WARNING: skipping malformed row: {row}") - continue - - if not series: - print("[plots] no usable series in the CSV; skipping plots.") - return - - # Group series by (metric_name, mode) so each plot collects all providers for that slice. 
- plots: dict = {} - for (metric_name, mode, provider), entry in series.items(): - plots.setdefault((metric_name, mode), []).append((provider, entry)) - - plot_paths = [] - for (metric_name, mode), provider_entries in plots.items(): - mode_label = mode if mode not in (None, "", "None") else "full" - fig, ax = plt.subplots(figsize=(8, 5)) - x_label = "x" - metric_unit = "" - for provider, entry in sorted(provider_entries, key=lambda pe: pe[0]): - # Sort points by x so the line plot is monotone. - order = sorted(range(len(entry["xs"])), key=lambda i: entry["xs"][i]) - xs = [entry["xs"][i] for i in order] - y50 = [entry["y50"][i] for i in order] - y20 = [entry["y20"][i] for i in order] - y80 = [entry["y80"][i] for i in order] - # Capture the line's color so fill_between's band uses the same hue across - # matplotlib versions that don't share auto-cycles between the two calls. - (line,) = ax.plot(xs, y50, marker="o", label=provider) - ax.fill_between(xs, y20, y80, alpha=0.2, color=line.get_color()) - x_label = entry["x_label"] - metric_unit = entry["metric_unit"] - - ax.set_xscale("log", base=2) - ax.set_xlabel(x_label) - ax.set_ylabel(f"{metric_name} ({metric_unit})") - ax.set_title(f"Megatron CE — {metric_name}, mode={mode_label}") - ax.legend() - ax.grid(True, alpha=0.3) - - out_path = out_dir_path / f"megatron_cross_entropy_{metric_name}_{mode_label}.png" - fig.tight_layout() - fig.savefig(out_path, dpi=120) - plt.close(fig) - plot_paths.append(str(out_path)) - - print(f"[plots] wrote {len(plot_paths)} plot(s):") - for p in plot_paths: - print(f" - {p}") - - if __name__ == "__main__": args = parse_benchmark_script_args() @@ -345,5 +207,3 @@ def _generate_plots(out_dir: str) -> None: metric_unit="MB", **common_configs, ) - - _generate_plots(out_dir=os.path.join(os.path.dirname(__file__), "..", "visualizations")) From 2d876768f8b20c5efe4d471ee279f70f8a3d46da Mon Sep 17 00:00:00 2001 From: Vaibhav Jindal Date: Wed, 10 Jun 2026 20:36:03 +0000 Subject: [PATCH 7/7] 
checkstyle fixes --- .../benchmark_megatron_cross_entropy.py | 4 +- examples/megatron/run_mode1_monkey_patch.py | 10 +---- examples/megatron/run_mode2_hand_spec.py | 7 ++-- src/liger_kernel/megatron/cross_entropy.py | 6 +-- src/liger_kernel/megatron/monkey_patch.py | 6 +-- test/megatron/test_cross_entropy.py | 19 +++++++--- test/megatron/test_monkey_patch.py | 37 ++++++++++--------- 7 files changed, 42 insertions(+), 47 deletions(-) diff --git a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py index 883e1b802..d523054b2 100644 --- a/benchmark/scripts/benchmark_megatron_cross_entropy.py +++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py @@ -104,9 +104,7 @@ def _megatron_fused_call(logits, target): return _megatron_fused_call if provider == "megatron-unfused": if not _MEGATRON_AVAILABLE: - raise RuntimeError( - "megatron-core not installed; cannot benchmark 'megatron-unfused' provider" - ) + raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron-unfused' provider") tp_group = _ensure_single_rank_tp_group() def _megatron_unfused_call(logits, target): diff --git a/examples/megatron/run_mode1_monkey_patch.py b/examples/megatron/run_mode1_monkey_patch.py index 31638a06a..5e3f460df 100644 --- a/examples/megatron/run_mode1_monkey_patch.py +++ b/examples/megatron/run_mode1_monkey_patch.py @@ -152,14 +152,8 @@ def _print_ce_symbols() -> None: import megatron.core.tensor_parallel.cross_entropy as unfused print("\n=== Resolved CE symbols ===") - print( - f" fused.fused_vocab_parallel_cross_entropy → " - f"{fused.fused_vocab_parallel_cross_entropy.__name__}" - ) - print( - f" unfused.vocab_parallel_cross_entropy → " - f"{unfused.vocab_parallel_cross_entropy.__name__}" - ) + print(f" fused.fused_vocab_parallel_cross_entropy → {fused.fused_vocab_parallel_cross_entropy.__name__}") + print(f" unfused.vocab_parallel_cross_entropy → {unfused.vocab_parallel_cross_entropy.__name__}") print() diff 
--git a/examples/megatron/run_mode2_hand_spec.py b/examples/megatron/run_mode2_hand_spec.py index c4d66ef9f..4053a7adc 100644 --- a/examples/megatron/run_mode2_hand_spec.py +++ b/examples/megatron/run_mode2_hand_spec.py @@ -101,9 +101,9 @@ def __init__(self, *args, liger_ce_label_smoothing: float = 0.0, **kwargs): def compute_language_model_loss(self, labels, logits): # LanguageModule contract: input labels are [b, s], output loss is [b, s]. # LigerMegatronCrossEntropy matches the fused signature, which expects [s, b]. - labels_sb = labels.transpose(0, 1).contiguous() # [s, b] + labels_sb = labels.transpose(0, 1).contiguous() # [s, b] loss_sb = self.liger_ce(logits, labels_sb, self.pg_collection.tp) # [s, b] - return loss_sb.transpose(0, 1).contiguous() # [b, s] + return loss_sb.transpose(0, 1).contiguous() # [b, s] def initialize_distributed(tp: int = 2, pp: int = 1) -> None: @@ -234,8 +234,7 @@ def _print_ce_class(model: torch.nn.Module) -> None: if ce is None: print(" model.liger_ce → (not set; subclass missing)") else: - print(f" model.liger_ce → " - f"{type(ce).__module__}.{type(ce).__name__}") + print(f" model.liger_ce → {type(ce).__module__}.{type(ce).__name__}") print(f" ce.label_smoothing → {ce.label_smoothing}") print(f" ce.ignore_index → {ce.ignore_index}") print() diff --git a/src/liger_kernel/megatron/cross_entropy.py b/src/liger_kernel/megatron/cross_entropy.py index 4ed11e69b..fb899dd2f 100644 --- a/src/liger_kernel/megatron/cross_entropy.py +++ b/src/liger_kernel/megatron/cross_entropy.py @@ -77,8 +77,4 @@ def forward( return loss.reshape(s, b) def extra_repr(self) -> str: - return ( - f"ignore_index={self.ignore_index}, " - f"label_smoothing={self.label_smoothing}, " - f"reduction={self.reduction!r}" - ) + return f"ignore_index={self.ignore_index}, label_smoothing={self.label_smoothing}, reduction={self.reduction!r}" diff --git a/src/liger_kernel/megatron/monkey_patch.py b/src/liger_kernel/megatron/monkey_patch.py index 4ed19f948..db0bb43d3 100644 
--- a/src/liger_kernel/megatron/monkey_patch.py +++ b/src/liger_kernel/megatron/monkey_patch.py @@ -223,8 +223,7 @@ def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_g fused_ce.fused_vocab_parallel_cross_entropy = liger_fused_vocab_parallel_cross_entropy logger.info( - "Patched megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy " - "with Liger cross-entropy." + "Patched megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy with Liger cross-entropy." ) @@ -287,6 +286,5 @@ def liger_vocab_parallel_cross_entropy( unfused_ce.vocab_parallel_cross_entropy = liger_vocab_parallel_cross_entropy logger.info( - "Patched megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy " - "with Liger cross-entropy." + "Patched megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy with Liger cross-entropy." ) diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py index 589a5df4d..a530e8607 100644 --- a/test/megatron/test_cross_entropy.py +++ b/test/megatron/test_cross_entropy.py @@ -151,12 +151,12 @@ def test_class_rejects_non_none_reduction(bad_reduction): def test_class_rejects_non_3d_logits(): """The class explicitly guards against HuggingFace-shape [b, s, v] callers etc.""" ce = LigerMegatronCrossEntropy() - bad = torch.randn(8, 16, device=device) # 2-D + bad = torch.randn(8, 16, device=device) # 2-D target = torch.randint(0, 16, (8,), device=device, dtype=torch.long) with pytest.raises(ValueError, match="3-D"): ce(bad, target) - too_many = torch.randn(2, 2, 4, 16, device=device) # 4-D + too_many = torch.randn(2, 2, 4, 16, device=device) # 4-D target2 = torch.randint(0, 16, (2, 2, 4), device=device, dtype=torch.long) with pytest.raises(ValueError, match="3-D"): ce(too_many, target2) @@ -259,7 +259,7 @@ def _assign_ignore_index(target: torch.Tensor, ignore_index: int, frac: float = [ (16, 1, 4096), (32, 2, 32000), # llama-ish vocab - (5, 3, 
123), # weird shape + (5, 3, 123), # weird shape ], ) @pytest.mark.parametrize("scalar", [0.5, 1.0, 5.0]) @@ -302,8 +302,8 @@ def test_class_correctness_scalar_sweep(s, b, v, scalar, dtype, atol, rtol): "s, b, v, ignore_index", [ (16, 1, 4096, -100), # standard hf sentinel - (32, 2, 32000, 2), # positive id (valid vocab slot used as ignore) - (5, 3, 123, -123), # weird negative + (32, 2, 32000, 2), # positive id (valid vocab slot used as ignore) + (5, 3, 123, -123), # weird negative ], ) @pytest.mark.parametrize( @@ -360,7 +360,14 @@ def test_class_correctness_with_ignore_index_sweep(s, b, v, ignore_index, dtype, ], ) def test_class_correctness_with_label_smoothing_and_ignore_index( - s, b, v, ignore_index, label_smoothing, dtype, atol, rtol, + s, + b, + v, + ignore_index, + label_smoothing, + dtype, + atol, + rtol, ): """Combined ignore_index × label_smoothing sweep — the two are independent in Liger's CE kernel but mixing them historically surfaced bugs in the smoothing math. Mirrors diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py index e761bfc28..4f1510b0f 100644 --- a/test/megatron/test_monkey_patch.py +++ b/test/megatron/test_monkey_patch.py @@ -29,7 +29,6 @@ import pytest - # =========================================================================== # 1. 
Stub-megatron installers + fixtures # =========================================================================== @@ -77,7 +76,10 @@ def original_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, t if with_unfused_symbol: def original_vocab_parallel_cross_entropy( - vocab_parallel_logits, target, label_smoothing=0.0, tp_group=None, + vocab_parallel_logits, + target, + label_smoothing=0.0, + tp_group=None, ): raise AssertionError("original megatron unfused kernel called — patch failed") @@ -437,11 +439,13 @@ def test_patch_constructs_ce_with_class_defaults(fake_megatron_ce): real_ctor = ce_mod.LigerMegatronCrossEntropy.__init__ def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): - captured.append({ - "ignore_index": ignore_index, - "label_smoothing": label_smoothing, - "reduction": reduction, - }) + captured.append( + { + "ignore_index": ignore_index, + "label_smoothing": label_smoothing, + "reduction": reduction, + } + ) real_ctor(self, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction) with patch.object(ce_mod.LigerMegatronCrossEntropy, "__init__", recording_init): @@ -492,8 +496,7 @@ def __call__(self, logits, target, tp_group=None): unfused_ce.vocab_parallel_cross_entropy(logits, target, label_smoothing=0.3) assert constructed == [0.3], ( - f"unfused wrapper should construct one fresh instance with the runtime override; " - f"got: {constructed}" + f"unfused wrapper should construct one fresh instance with the runtime override; got: {constructed}" ) @@ -525,9 +528,7 @@ def __call__(self, logits, target, tp_group=None): # Second positional call also without label_smoothing — still no new construction. 
unfused_ce.vocab_parallel_cross_entropy(logits, target) - assert constructed == [], ( - f"default-path calls must reuse default_ce — no fresh instances; got: {constructed}" - ) + assert constructed == [], f"default-path calls must reuse default_ce — no fresh instances; got: {constructed}" def test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron_ce): @@ -563,8 +564,7 @@ def __call__(self, logits, target, tp_group=None): unfused_ce.vocab_parallel_cross_entropy(logits, target, 0.0) assert constructed == [0.0], ( - f"explicit label_smoothing=0.0 at call time must be honored verbatim; " - f"got: {constructed}" + f"explicit label_smoothing=0.0 at call time must be honored verbatim; got: {constructed}" ) @@ -820,8 +820,9 @@ def test_public_apply_function_has_no_ce_specific_kwargs(): _device = infer_device() -def _ref_loss_sbv(logits_sbv: torch.Tensor, target_sb: torch.Tensor, - ignore_index: int = -100, label_smoothing: float = 0.0) -> torch.Tensor: +def _ref_loss_sbv( + logits_sbv: torch.Tensor, target_sb: torch.Tensor, ignore_index: int = -100, label_smoothing: float = 0.0 +) -> torch.Tensor: """Reference CE for [s, b, v] logits / [s, b] target, returning [s, b].""" s, b, v = logits_sbv.shape loss_flat = F.cross_entropy( @@ -896,7 +897,9 @@ def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_meg ref = _ref_loss_sbv(logits.clone(), target, label_smoothing=label_smoothing) got = unfused_ce.vocab_parallel_cross_entropy( - logits.clone(), target, label_smoothing=label_smoothing, + logits.clone(), + target, + label_smoothing=label_smoothing, ) assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4)