Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions benchmark/data/all_benchmark_data.csv

Large diffs are not rendered by default.

207 changes: 207 additions & 0 deletions benchmark/scripts/benchmark_megatron_cross_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""Benchmark Liger's Megatron-LM cross-entropy wrapper.

Compares four providers on the per-token CE call shape ``[seq, batch, vocab]``:

- **torch**: vanilla ``F.cross_entropy``
- **megatron**: Megatron's *fused* ``fused_vocab_parallel_cross_entropy`` path
(``cross_entropy_loss_fusion=True``, JIT-fused via TorchScript)
- **megatron-unfused**: Megatron's *unfused* ``vocab_parallel_cross_entropy``
path (``cross_entropy_loss_fusion=False``, eager Python; the path users on
``label_smoothing`` typically end up on)
- **liger**: ``LigerMegatronCrossEntropy`` — Liger's Triton CE wrapped in the
Megatron fused signature. Same kernel regardless of which Megatron symbol
it was patched onto, so we only benchmark it once.

Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not
installed, both megatron providers are silently skipped.

Output goes to the shared ``benchmark/data/all_benchmark_data.csv`` like every
other Liger benchmark — rows are tagged with ``kernel_name="megatron_cross_entropy"``
and the standard visualizer renders them via:

python benchmark/benchmarks_visualizer.py \\
--kernel-name megatron_cross_entropy --metric-name speed
python benchmark/benchmarks_visualizer.py \\
--kernel-name megatron_cross_entropy --metric-name memory
"""

import os

import torch
import torch.nn.functional as F
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.megatron import LigerMegatronCrossEntropy
from liger_kernel.utils import infer_device

device = infer_device()

try:
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy

_MEGATRON_AVAILABLE = True
except ImportError:
fused_vocab_parallel_cross_entropy = None
vocab_parallel_cross_entropy = None
_MEGATRON_AVAILABLE = False


def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True):
logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)
return logits, target


def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
s, b, v = logits.shape
return F.cross_entropy(
logits.reshape(-1, v).float(),
target.reshape(-1),
reduction="none",
).reshape(s, b)


def _ensure_single_rank_tp_group():
"""Initialize torch.distributed (single-rank) and return a usable TP group.

For a single-process benchmark we use the world group of size 1; the
internal all-reduce becomes a no-op.
"""
import torch.distributed as dist

if not dist.is_initialized():
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
dist.init_process_group(backend="nccl")
return dist.group.WORLD


def _select_fwd(provider: str):
if provider == "liger":
ce = LigerMegatronCrossEntropy(reduction="none")
return lambda logits, target: ce(logits, target)
if provider == "torch":
return _pytorch_cross_entropy
if provider == "megatron":
if not _MEGATRON_AVAILABLE:
raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
tp_group = _ensure_single_rank_tp_group()

def _megatron_fused_call(logits, target):
return fused_vocab_parallel_cross_entropy(logits, target, tp_group)

return _megatron_fused_call
if provider == "megatron-unfused":
if not _MEGATRON_AVAILABLE:
raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron-unfused' provider")
tp_group = _ensure_single_rank_tp_group()

def _megatron_unfused_call(logits, target):
# Unfused signature: (logits, target, label_smoothing=0.0, tp_group=None)
return vocab_parallel_cross_entropy(logits, target, 0.0, tp_group)

return _megatron_unfused_call
raise ValueError(f"unknown provider: {provider!r}")


def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
v = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
s = input.extra_benchmark_config["S"]
b = input.extra_benchmark_config["B"]

logits, target = _make_inputs(s, b, v)
fwd_fn = _select_fwd(provider)

def fwd():
return fwd_fn(logits, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "backward":
# Megatron's fused CE writes gradients in-place into saved tensors during backward,
# which breaks the standard retain_graph=True / repeated-backward pattern do_bench
# uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an
# unmodified autograd graph. Measurement therefore includes forward time —
# subtract the "forward" measurement to derive backward-only timing.
def _fwd_bwd():
if logits.grad is not None:
logits.grad = None
out = fwd_fn(logits, target)
out.sum().backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES)
elif mode == "full":

def full():
y = fwd()
y.sum().backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
else:
raise ValueError(f"unknown mode: {mode!r}")

return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
v = input.x
provider = input.kernel_provider
s = input.extra_benchmark_config["S"]
b = input.extra_benchmark_config["B"]

logits, target = _make_inputs(s, b, v)
fwd_fn = _select_fwd(provider)

def full():
y = fwd_fn(logits, target)
y.sum().backward()

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
args = parse_benchmark_script_args()

providers = ["liger", "torch"]
if _MEGATRON_AVAILABLE:
providers.append("megatron")
providers.append("megatron-unfused")

common_configs = {
"kernel_name": "megatron_cross_entropy",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": providers,
"extra_benchmark_configs": [{"S": 2048, "B": 4}],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_megatron_cross_entropy,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_megatron_cross_entropy,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
44 changes: 44 additions & 0 deletions docs/High-Level-APIs.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,47 @@ You can also use the Patching APIs to use the kernels for a specific model archi
extra:
show_docstring: true
show_signature: true

---

## Megatron-LM

Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
training framework, replacing Megatron's native RMSNorm and both vocab-parallel
cross-entropy paths (fused and unfused) with Liger's Triton kernels.

| **Framework** | **API** | **Supported Operations** |
|---------------|--------------------------------------------------------|--------------------------|
| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | RMSNorm, CrossEntropyLoss |

**Scope**: Initial release supports `tensor_model_parallel_size=1` only for
cross-entropy. Vocab-parallel cross-entropy (TP>1) is follow-up work — with
TP>1, each rank holds a sharded `[N, V/tp]` logits slice and cross-entropy
requires cross-rank all-reduces that Liger's kernel does not perform. The
patch raises a `RuntimeError` at patch time or call time if TP>1 is detected.

**Usage**:

```python
from liger_kernel.megatron import apply_liger_kernel_to_megatron

# Call before Megatron's forward pass reaches compute_language_model_loss.
# Defaults match Megatron's native CE behavior; no CE-specific config needed.
apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)
```

Both the fused (`config.cross_entropy_loss_fusion=True`,
`cross_entropy_fusion_impl='native'`) and unfused
(`config.cross_entropy_loss_fusion=False`) CE paths are patched in a single
call, so Megatron picks up Liger regardless of which path your config selects.

For training setups that need explicit kernel configuration (custom
`ignore_index`, `label_smoothing`, etc.), instantiate
`LigerMegatronCrossEntropy` directly and wire it into your model — see
`examples/megatron/run_mode2_hand_spec.py`.

::: liger_kernel.megatron.apply_liger_kernel_to_megatron
options:
extra:
show_docstring: true
show_signature: true
40 changes: 29 additions & 11 deletions examples/megatron/run_mode1_monkey_patch.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm.
"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm + cross-entropy.

Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The
relevant additions (vs. that file) are:

1. ``apply_liger_kernel_to_megatron(rms_norm=True)`` called once at the
top of ``model_provider()``. This patches both
``LocalSpecProvider.layer_norm`` (per-layer norm slots) and
``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``).
1. ``apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)``
called once at the top of ``model_provider()``. This patches:
- ``LocalSpecProvider.layer_norm`` (per-layer norm slots)
- ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``)
- ``fused_cross_entropy.fused_vocab_parallel_cross_entropy``
(the fused CE path)
- ``tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``
(the unfused CE path)

2. ``normalization="RMSNorm"`` added to ``TransformerConfig`` so the
model actually has RMSNorm slots to patch (Megatron defaults to
``LayerNorm``).

3. ``_print_norm_classes`` after model construction — prints the
resolved class for every norm slot so you can verify Liger took
over.
3. ``_print_norm_classes`` + ``_print_ce_symbols`` after model construction
— print the resolved class/function bindings so you can verify Liger
took over for every slot.

Run with:
torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=29500 \\
Expand Down Expand Up @@ -70,7 +74,7 @@ def initialize_distributed(tp: int = 2, pp: int = 1) -> None:

def model_provider() -> GPTModel:
# ↓↓ Mode 1 — patch once, everything below picks up Liger ↓↓
apply_liger_kernel_to_megatron(rms_norm=True)
apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)
# ↑↑ ------------------------------------------------------ ↑↑

cfg = TransformerConfig(
Expand All @@ -84,7 +88,7 @@ def model_provider() -> GPTModel:
return GPTModel(
config=cfg,
transformer_layer_spec=get_gpt_layer_local_spec(normalization="RMSNorm"),
vocab_size=100,
vocab_size=128,
max_sequence_length=_SEQUENCE_LENGTH,
)

Expand Down Expand Up @@ -142,8 +146,21 @@ def _print_norm_classes(model: torch.nn.Module) -> None:
print()


def _print_ce_symbols() -> None:
"""Show the current bindings of Megatron's two CE entry points."""
import megatron.core.fusions.fused_cross_entropy as fused
import megatron.core.tensor_parallel.cross_entropy as unfused

print("\n=== Resolved CE symbols ===")
print(f" fused.fused_vocab_parallel_cross_entropy → {fused.fused_vocab_parallel_cross_entropy.__name__}")
print(f" unfused.vocab_parallel_cross_entropy → {unfused.vocab_parallel_cross_entropy.__name__}")
print()


def main() -> None:
initialize_distributed(tp=2, pp=1)
# TP=1, DP=2 — CE patch (TP=1 only). Norms are correct under any TP value, so
# demonstrating both Liger features in one script means running data-parallel.
initialize_distributed(tp=1, pp=1)
model_parallel_cuda_manual_seed(123)
torch.manual_seed(123)

Expand All @@ -153,6 +170,7 @@ def main() -> None:
print("\n=== Full model tree (mode 1: monkey-patch) ===")
print(gpt_model)
_print_norm_classes(gpt_model)
_print_ce_symbols()

ddp_cfg = DistributedDataParallelConfig(
grad_reduce_in_fp32=False,
Expand Down
Loading
Loading