|
| 1 | +"""Benchmark Liger's Megatron-LM cross-entropy wrapper. |
| 2 | +
|
| 3 | +Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's |
| 4 | +native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is |
| 5 | +installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a |
| 6 | +third provider to reproduce end-to-end comparisons. |
| 7 | +
|
| 8 | +Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not |
| 9 | +installed, the "megatron" provider is silently skipped. |
| 10 | +""" |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn.functional as F |
| 14 | +import triton |
| 15 | + |
| 16 | +from utils import QUANTILES |
| 17 | +from utils import SingleBenchmarkRunInput |
| 18 | +from utils import SingleBenchmarkRunOutput |
| 19 | +from utils import _test_memory |
| 20 | +from utils import parse_benchmark_script_args |
| 21 | +from utils import run_benchmarks |
| 22 | + |
| 23 | +from liger_kernel.megatron.cross_entropy import _build_wrapper |
| 24 | +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss |
| 25 | +from liger_kernel.utils import infer_device |
| 26 | + |
| 27 | +device = infer_device() |
| 28 | + |
| 29 | +try: |
| 30 | + from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy |
| 31 | + |
| 32 | + _MEGATRON_AVAILABLE = True |
| 33 | +except ImportError: |
| 34 | + fused_vocab_parallel_cross_entropy = None |
| 35 | + _MEGATRON_AVAILABLE = False |
| 36 | + |
| 37 | + |
| 38 | +def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True): |
| 39 | + logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad) |
| 40 | + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) |
| 41 | + return logits, target |
| 42 | + |
| 43 | + |
| 44 | +def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
| 45 | + s, b, v = logits.shape |
| 46 | + return F.cross_entropy( |
| 47 | + logits.reshape(-1, v).float(), |
| 48 | + target.reshape(-1), |
| 49 | + reduction="none", |
| 50 | + ).reshape(s, b) |
| 51 | + |
| 52 | + |
| 53 | +def _ensure_single_rank_tp_group(): |
| 54 | + """Initialize torch.distributed (single-rank) and return a usable TP group. |
| 55 | +
|
| 56 | + For a single-process benchmark we use the world group of |
| 57 | + size 1, where the internal all-reduce becomes a no-op. |
| 58 | + """ |
| 59 | + import os |
| 60 | + |
| 61 | + import torch.distributed as dist |
| 62 | + |
| 63 | + if not dist.is_initialized(): |
| 64 | + os.environ.setdefault("MASTER_ADDR", "localhost") |
| 65 | + os.environ.setdefault("MASTER_PORT", "29500") |
| 66 | + os.environ.setdefault("WORLD_SIZE", "1") |
| 67 | + os.environ.setdefault("RANK", "0") |
| 68 | + os.environ.setdefault("LOCAL_RANK", "0") |
| 69 | + dist.init_process_group(backend="nccl") |
| 70 | + return dist.group.WORLD |
| 71 | + |
| 72 | + |
| 73 | +def _select_fwd(provider: str): |
| 74 | + if provider == "liger": |
| 75 | + wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none")) |
| 76 | + return wrapper |
| 77 | + if provider == "torch": |
| 78 | + return _pytorch_cross_entropy |
| 79 | + if provider == "megatron": |
| 80 | + if not _MEGATRON_AVAILABLE: |
| 81 | + raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider") |
| 82 | + tp_group = _ensure_single_rank_tp_group() |
| 83 | + |
| 84 | + def _megatron_call(logits, target): |
| 85 | + return fused_vocab_parallel_cross_entropy(logits, target, tp_group) |
| 86 | + |
| 87 | + return _megatron_call |
| 88 | + raise ValueError(f"unknown provider: {provider!r}") |
| 89 | + |
| 90 | + |
| 91 | +def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 92 | + v = input.x |
| 93 | + provider = input.kernel_provider |
| 94 | + mode = input.kernel_operation_mode |
| 95 | + s = input.extra_benchmark_config["S"] |
| 96 | + b = input.extra_benchmark_config["B"] |
| 97 | + |
| 98 | + logits, target = _make_inputs(s, b, v) |
| 99 | + fwd_fn = _select_fwd(provider) |
| 100 | + |
| 101 | + def fwd(): |
| 102 | + return fwd_fn(logits, target) |
| 103 | + |
| 104 | + if mode == "forward": |
| 105 | + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) |
| 106 | + elif mode == "backward": |
| 107 | + # Megatron's fused CE writes gradients in-place into saved tensors during backward, |
| 108 | + # which breaks the standard retain_graph=True / repeated-backward pattern do_bench |
| 109 | + # uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an |
| 110 | + # unmodified autograd graph. Measurement therefore includes forward time — |
| 111 | + # subtract the "forward" measurement to derive backward-only timing. |
| 112 | + def _fwd_bwd(): |
| 113 | + if logits.grad is not None: |
| 114 | + logits.grad = None |
| 115 | + out = fwd_fn(logits, target) |
| 116 | + out.sum().backward() |
| 117 | + |
| 118 | + ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES) |
| 119 | + elif mode == "full": |
| 120 | + |
| 121 | + def full(): |
| 122 | + y = fwd() |
| 123 | + y.sum().backward() |
| 124 | + |
| 125 | + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) |
| 126 | + else: |
| 127 | + raise ValueError(f"unknown mode: {mode!r}") |
| 128 | + |
| 129 | + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) |
| 130 | + |
| 131 | + |
| 132 | +def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 133 | + v = input.x |
| 134 | + provider = input.kernel_provider |
| 135 | + s = input.extra_benchmark_config["S"] |
| 136 | + b = input.extra_benchmark_config["B"] |
| 137 | + |
| 138 | + logits, target = _make_inputs(s, b, v) |
| 139 | + fwd_fn = _select_fwd(provider) |
| 140 | + |
| 141 | + def full(): |
| 142 | + y = fwd_fn(logits, target) |
| 143 | + y.sum().backward() |
| 144 | + |
| 145 | + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) |
| 146 | + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) |
| 147 | + |
| 148 | + |
| 149 | +if __name__ == "__main__": |
| 150 | + args = parse_benchmark_script_args() |
| 151 | + |
| 152 | + providers = ["liger", "torch"] |
| 153 | + if _MEGATRON_AVAILABLE: |
| 154 | + providers.append("megatron") |
| 155 | + |
| 156 | + common_configs = { |
| 157 | + "kernel_name": "megatron_cross_entropy", |
| 158 | + "x_name": "V", |
| 159 | + "x_label": "vocab size", |
| 160 | + "x_values": [2**i for i in range(12, 18)], |
| 161 | + "kernel_providers": providers, |
| 162 | + "extra_benchmark_configs": [{"S": 2048, "B": 4}], |
| 163 | + "overwrite": args.overwrite, |
| 164 | + } |
| 165 | + |
| 166 | + run_benchmarks( |
| 167 | + bench_test_fn=bench_speed_megatron_cross_entropy, |
| 168 | + kernel_operation_modes=["forward", "backward", "full"], |
| 169 | + metric_name="speed", |
| 170 | + metric_unit="ms", |
| 171 | + **common_configs, |
| 172 | + ) |
| 173 | + run_benchmarks( |
| 174 | + bench_test_fn=bench_memory_megatron_cross_entropy, |
| 175 | + kernel_operation_modes=["full"], |
| 176 | + metric_name="memory", |
| 177 | + metric_unit="MB", |
| 178 | + **common_configs, |
| 179 | + ) |
0 commit comments