Skip to content

Commit bce8664

Browse files
Add Megatron-LM cross-entropy integration
1 parent dfbc4dd commit bce8664

7 files changed

Lines changed: 684 additions & 0 deletions

File tree

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""Benchmark Liger's Megatron-LM cross-entropy wrapper.
2+
3+
Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's
4+
native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is
5+
installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a
6+
third provider to reproduce end-to-end comparisons.
7+
8+
Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not
9+
installed, the "megatron" provider is silently skipped.
10+
"""
11+
12+
import torch
13+
import torch.nn.functional as F
14+
import triton
15+
16+
from utils import QUANTILES
17+
from utils import SingleBenchmarkRunInput
18+
from utils import SingleBenchmarkRunOutput
19+
from utils import _test_memory
20+
from utils import parse_benchmark_script_args
21+
from utils import run_benchmarks
22+
23+
from liger_kernel.megatron.cross_entropy import _build_wrapper
24+
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
25+
from liger_kernel.utils import infer_device
26+
27+
device = infer_device()
28+
29+
try:
30+
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
31+
32+
_MEGATRON_AVAILABLE = True
33+
except ImportError:
34+
fused_vocab_parallel_cross_entropy = None
35+
_MEGATRON_AVAILABLE = False
36+
37+
38+
def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True):
39+
logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
40+
target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)
41+
return logits, target
42+
43+
44+
def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
45+
s, b, v = logits.shape
46+
return F.cross_entropy(
47+
logits.reshape(-1, v).float(),
48+
target.reshape(-1),
49+
reduction="none",
50+
).reshape(s, b)
51+
52+
53+
def _ensure_single_rank_tp_group():
54+
"""Initialize torch.distributed (single-rank) and return a usable TP group.
55+
56+
For a single-process benchmark we use the world group of
57+
size 1, where the internal all-reduce becomes a no-op.
58+
"""
59+
import os
60+
61+
import torch.distributed as dist
62+
63+
if not dist.is_initialized():
64+
os.environ.setdefault("MASTER_ADDR", "localhost")
65+
os.environ.setdefault("MASTER_PORT", "29500")
66+
os.environ.setdefault("WORLD_SIZE", "1")
67+
os.environ.setdefault("RANK", "0")
68+
os.environ.setdefault("LOCAL_RANK", "0")
69+
dist.init_process_group(backend="nccl")
70+
return dist.group.WORLD
71+
72+
73+
def _select_fwd(provider: str):
74+
if provider == "liger":
75+
wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none"))
76+
return wrapper
77+
if provider == "torch":
78+
return _pytorch_cross_entropy
79+
if provider == "megatron":
80+
if not _MEGATRON_AVAILABLE:
81+
raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
82+
tp_group = _ensure_single_rank_tp_group()
83+
84+
def _megatron_call(logits, target):
85+
return fused_vocab_parallel_cross_entropy(logits, target, tp_group)
86+
87+
return _megatron_call
88+
raise ValueError(f"unknown provider: {provider!r}")
89+
90+
91+
def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
92+
v = input.x
93+
provider = input.kernel_provider
94+
mode = input.kernel_operation_mode
95+
s = input.extra_benchmark_config["S"]
96+
b = input.extra_benchmark_config["B"]
97+
98+
logits, target = _make_inputs(s, b, v)
99+
fwd_fn = _select_fwd(provider)
100+
101+
def fwd():
102+
return fwd_fn(logits, target)
103+
104+
if mode == "forward":
105+
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
106+
elif mode == "backward":
107+
# Megatron's fused CE writes gradients in-place into saved tensors during backward,
108+
# which breaks the standard retain_graph=True / repeated-backward pattern do_bench
109+
# uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an
110+
# unmodified autograd graph. Measurement therefore includes forward time —
111+
# subtract the "forward" measurement to derive backward-only timing.
112+
def _fwd_bwd():
113+
if logits.grad is not None:
114+
logits.grad = None
115+
out = fwd_fn(logits, target)
116+
out.sum().backward()
117+
118+
ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES)
119+
elif mode == "full":
120+
121+
def full():
122+
y = fwd()
123+
y.sum().backward()
124+
125+
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
126+
else:
127+
raise ValueError(f"unknown mode: {mode!r}")
128+
129+
return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
130+
131+
132+
def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
133+
v = input.x
134+
provider = input.kernel_provider
135+
s = input.extra_benchmark_config["S"]
136+
b = input.extra_benchmark_config["B"]
137+
138+
logits, target = _make_inputs(s, b, v)
139+
fwd_fn = _select_fwd(provider)
140+
141+
def full():
142+
y = fwd_fn(logits, target)
143+
y.sum().backward()
144+
145+
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
146+
return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
147+
148+
149+
if __name__ == "__main__":
150+
args = parse_benchmark_script_args()
151+
152+
providers = ["liger", "torch"]
153+
if _MEGATRON_AVAILABLE:
154+
providers.append("megatron")
155+
156+
common_configs = {
157+
"kernel_name": "megatron_cross_entropy",
158+
"x_name": "V",
159+
"x_label": "vocab size",
160+
"x_values": [2**i for i in range(12, 18)],
161+
"kernel_providers": providers,
162+
"extra_benchmark_configs": [{"S": 2048, "B": 4}],
163+
"overwrite": args.overwrite,
164+
}
165+
166+
run_benchmarks(
167+
bench_test_fn=bench_speed_megatron_cross_entropy,
168+
kernel_operation_modes=["forward", "backward", "full"],
169+
metric_name="speed",
170+
metric_unit="ms",
171+
**common_configs,
172+
)
173+
run_benchmarks(
174+
bench_test_fn=bench_memory_megatron_cross_entropy,
175+
kernel_operation_modes=["full"],
176+
metric_name="memory",
177+
metric_unit="MB",
178+
**common_configs,
179+
)

docs/High-Level-APIs.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,45 @@ You can also use the Patching APIs to use the kernels for a specific model archi
9191
extra:
9292
show_docstring: true
9393
show_signature: true
94+
95+
---
96+
97+
## Megatron-LM
98+
99+
Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
100+
training framework, replacing Megatron's native
101+
`fused_vocab_parallel_cross_entropy` with Liger's Triton cross-entropy kernel.
102+
103+
| **Framework** | **API** | **Supported Operations** |
104+
|---------------|--------------------------------------------------------|--------------------------|
105+
| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | CrossEntropyLoss |
106+
107+
**Scope**: Initial release supports `tensor_model_parallel_size=1` only.
108+
Vocab-parallel cross-entropy (TP>1) is follow-up work — with TP>1, each rank
109+
holds a sharded `[N, V/tp]` logits slice and cross-entropy requires cross-rank
110+
all-reduces that Liger's kernel does not perform. The patch raises a
111+
`RuntimeError` at patch time or call time if TP>1 is detected.
112+
113+
**Usage**:
114+
115+
```python
116+
from liger_kernel.megatron import apply_liger_kernel_to_megatron
117+
118+
# Call before Megatron's forward pass reaches compute_language_model_loss.
119+
# Match Megatron's config: pass the same ignore_index and label_smoothing
120+
# values used by your training setup (Liger does not auto-detect them).
121+
apply_liger_kernel_to_megatron(
122+
ignore_index=-100,
123+
label_smoothing=cfg.label_smoothing_factor,
124+
)
125+
```
126+
127+
Ensure Megatron's fused-CE code path is enabled in your training config (e.g.
128+
`--cross-entropy-loss-fusion` in the Megatron-LM CLI) — if the unfused path is
129+
selected, the patched symbol is never called.
130+
131+
::: liger_kernel.megatron.apply_liger_kernel_to_megatron
132+
options:
133+
extra:
134+
show_docstring: true
135+
show_signature: true
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from liger_kernel.megatron.cross_entropy import apply_liger_kernel_to_megatron
2+
3+
__all__ = ["apply_liger_kernel_to_megatron"]
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
3+
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
4+
5+
6+
def _check_tensor_parallel_size_at_patch_time() -> None:
7+
"""Raise RuntimeError if Megatron's parallel state already reports TP>1.
8+
9+
If Megatron is importable but the parallel state is not yet initialized
10+
(for example, ``apply_liger_kernel_to_megatron`` is called before
11+
``initialize_megatron``), silently defer; the wrapper checks again at call
12+
time against the ``tp_group`` argument Megatron supplies.
13+
"""
14+
try:
15+
from megatron.core import parallel_state
16+
except ImportError:
17+
return
18+
try:
19+
tp_size = parallel_state.get_tensor_model_parallel_world_size()
20+
except (AssertionError, RuntimeError):
21+
return
22+
if tp_size > 1:
23+
raise RuntimeError(
24+
f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, "
25+
f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work."
26+
)
27+
28+
29+
def _build_wrapper(loss_fn: LigerCrossEntropyLoss):
30+
"""Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``.
31+
32+
The returned callable has exactly the same parameter list Megatron expects
33+
(``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs
34+
will raise ``TypeError`` naturally — this is intentional: if a future
35+
Megatron release adds new parameters to the fused-CE contract, we want to
36+
fail loudly rather than silently drop them.
37+
"""
38+
39+
def liger_fused_vocab_parallel_cross_entropy(
40+
vocab_parallel_logits: torch.Tensor,
41+
target: torch.Tensor,
42+
tp_group=None,
43+
) -> torch.Tensor:
44+
if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1:
45+
raise RuntimeError(
46+
f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, "
47+
f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as "
48+
f"follow-up work."
49+
)
50+
51+
s, b, v = vocab_parallel_logits.shape
52+
logits_2d = vocab_parallel_logits.reshape(-1, v)
53+
target_1d = target.reshape(-1)
54+
loss = loss_fn(logits_2d, target_1d)
55+
return loss.reshape(s, b)
56+
57+
return liger_fused_vocab_parallel_cross_entropy
58+
59+
60+
def apply_liger_kernel_to_megatron(
61+
reduction: str = "none",
62+
ignore_index: int = -100,
63+
label_smoothing: float = 0.0,
64+
) -> None:
65+
"""Replace Megatron-LM's fused_vocab_parallel_cross_entropy with Liger's Triton cross-entropy.
66+
67+
This monkey-patches
68+
``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``
69+
so that Megatron training pipelines use Liger's Triton kernel (online
70+
softmax, in-place gradients, no full-softmax materialization) instead of
71+
Megatron's native fused implementation.
72+
73+
Args:
74+
reduction: Must be ``"none"``; Megatron's fused-CE contract returns
75+
per-token loss shaped ``[seq, batch]`` and handles reduction itself
76+
downstream.
77+
ignore_index: Target index to ignore. Pass the value used in your
78+
Megatron training config.
79+
label_smoothing: Cross-entropy label smoothing factor. Liger does not
80+
auto-detect this — callers should pass
81+
``cfg.label_smoothing_factor`` (or equivalent) from their
82+
Megatron ``TransformerConfig`` if label smoothing is enabled, to
83+
preserve the native behavior.
84+
85+
Scope:
86+
Initial release supports ``tensor_model_parallel_size=1`` only. With
87+
TP>1, each rank holds a vocab-sharded logits slice ``[N, V/tp]`` and
88+
computing cross-entropy requires cross-rank all-reduces that Liger's
89+
kernel does not perform. A ``RuntimeError`` is raised at patch time if
90+
the Megatron parallel state already reports TP>1, and again at call
91+
time if a multi-rank ``tp_group`` is passed.
92+
93+
Raises:
94+
AssertionError: If ``reduction != "none"``.
95+
ImportError: If ``megatron.core.fusions.fused_cross_entropy`` is not
96+
importable, or if the expected
97+
``fused_vocab_parallel_cross_entropy`` symbol is missing from that
98+
module (indicating an incompatible Megatron version).
99+
RuntimeError: If tensor model parallelism > 1 is detected.
100+
101+
Example:
102+
>>> from liger_kernel.megatron import apply_liger_kernel_to_megatron
103+
>>> apply_liger_kernel_to_megatron(
104+
... ignore_index=-100,
105+
... label_smoothing=cfg.label_smoothing_factor,
106+
... )
107+
>>> # call before Megatron's forward pass reaches compute_language_model_loss
108+
"""
109+
assert reduction == "none", (
110+
f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; "
111+
f"reduction must be 'none', got {reduction!r}."
112+
)
113+
114+
try:
115+
import megatron.core.fusions.fused_cross_entropy as fce
116+
except ImportError as exc:
117+
raise ImportError(
118+
"apply_liger_kernel_to_megatron requires megatron-core to be installed. "
119+
"Expected symbol path: "
120+
"megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy."
121+
) from exc
122+
123+
if not hasattr(fce, "fused_vocab_parallel_cross_entropy"):
124+
raise ImportError(
125+
"megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. "
126+
"The symbol path may have changed in your Megatron-LM version. Please file an issue "
127+
"on https://github.com/linkedin/Liger-Kernel with your megatron-core version."
128+
)
129+
130+
_check_tensor_parallel_size_at_patch_time()
131+
132+
loss_fn = LigerCrossEntropyLoss(
133+
ignore_index=ignore_index,
134+
label_smoothing=label_smoothing,
135+
reduction="none",
136+
)
137+
fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn)

test/megatron/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)