Add Megatron-LM cross-entropy integration #1207
Draft
PrathyushaPolepalli wants to merge 1 commit into linkedin:main
Conversation
PrathyushaPolepalli force-pushed from 642c576 to ed3c27e
PrathyushaPolepalli force-pushed from ed3c27e to b1fa5bc
Summary
Adds an `apply_liger_kernel_to_megatron()` monkey-patch that swaps Megatron-LM's native `fused_vocab_parallel_cross_entropy` for Liger's Triton cross-entropy kernel. This enables online softmax, in-place gradients, and no full-softmax materialization inside Megatron training pipelines.
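A minimal sketch of the patching approach, assuming the fused op lives at `megatron.core.fusions.fused_cross_entropy` and using Liger's public `LigerCrossEntropyLoss`; the module path, wrapper shape handling, and reduction choice here are assumptions, not the exact code in this PR:

```python
# Sketch only -- not the PR's exact code. Assumed patch target:
# megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy
from liger_kernel.transformers import LigerCrossEntropyLoss


def apply_liger_kernel_to_megatron():
    import megatron.core.fusions.fused_cross_entropy as fused_ce

    # Megatron consumes a per-token loss, so keep reduction="none".
    liger_ce = LigerCrossEntropyLoss(reduction="none")

    def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None):
        # Liger's kernel expects flat [N, V] logits and [N] targets;
        # Megatron hands over [S, B, V], so flatten and reshape back.
        seq_len, batch, vocab = vocab_parallel_logits.shape
        loss = liger_ce(vocab_parallel_logits.reshape(-1, vocab), target.reshape(-1))
        return loss.reshape(seq_len, batch)

    # Rebind the module-level symbol; any call site that did
    # `from ... import fused_vocab_parallel_cross_entropy` would need its
    # own module attribute patched as well.
    fused_ce.fused_vocab_parallel_cross_entropy = liger_fused_vocab_parallel_cross_entropy
```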
Scope: `tensor_model_parallel_size=1` only. With TP>1, each rank holds a sharded `[N, V/tp]` logits slice, and cross-entropy then requires cross-rank all-reduces that Liger's kernel does not perform. The patch raises `RuntimeError` at patch time (via `megatron.core.parallel_state`) and again at call time (via the `tp_group` argument Megatron passes), so misconfiguration fails loudly. Vocab-parallel support is follow-up work.
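A hedged sketch of those two guards; the helper names are hypothetical, but `parallel_state.get_tensor_model_parallel_world_size()` is Megatron's real query for the TP degree:

```python
# Sketch of the fail-loud guards; helper names are illustrative only.
import torch.distributed as dist
from megatron.core import parallel_state


def _assert_tp1_at_patch_time():
    # Patch time: Megatron's global parallel state must report TP == 1.
    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        raise RuntimeError(
            "apply_liger_kernel_to_megatron() requires tensor_model_parallel_size=1: "
            "Liger's CE kernel does not perform the cross-rank all-reduces that "
            "vocab-parallel cross-entropy needs."
        )


def _assert_tp1_at_call_time(tp_group):
    # Call time: Megatron passes the TP process group into the fused CE op.
    # A group with more than one rank means logits are sharded [N, V/tp].
    if tp_group is not None and dist.get_world_size(group=tp_group) > 1:
        raise RuntimeError("Liger CE invoked with TP>1; logits are vocab-sharded.")
```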
Tested on Qwen3-30B-A3B scaled MoE, 1× H100_8, BF16:
Model config:
Parallelism:
Training config:
Throughput results:

| | Throughput | Iter time |
| --- | --- | --- |
| Megatron native fused CE (baseline) | ~99 TFLOP/s/GPU | ~39,400 ms |
| Liger CE (this PR) | ~108 TFLOP/s/GPU (+9%) | ~35,900 ms |
Numerical correctness: lm_loss ~4.1e-3 in both runs, with no NaNs or skipped iterations.
Variance: Liger CE held 107.7–109.1 TFLOP/s/GPU across runs (consistent).
Test setup: single H100 80GB, sequence length S=2048, batch size B=4, vocab sizes 4K → 131K. Each provider in the microbenchmark runs the same cross-entropy operation, just a different implementation; a sketch of the setup follows.
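For illustration, a minimal standalone timing harness over those shapes (not the project's benchmark suite). It compares `torch.nn.functional.cross_entropy` against `LigerCrossEntropyLoss`; the specific vocab sizes are picked from the 4K → 131K range above, and this measures the single op in isolation rather than the end-to-end training throughput reported in the table:

```python
# Illustrative microbenchmark sketch, not the repo's benchmark harness.
import torch
import torch.nn.functional as F
from liger_kernel.transformers import LigerCrossEntropyLoss

S, B = 2048, 4
liger_ce = LigerCrossEntropyLoss()

for V in (4096, 32768, 131072):
    target = torch.randint(0, V, (S * B,), device="cuda")
    for name, make_loss in (
        ("torch CE", lambda x: F.cross_entropy(x.float(), target)),
        ("liger CE", lambda x: liger_ce(x, target)),
    ):
        # Fresh logits per provider: Liger's kernel writes gradients
        # in place over the logits buffer during its fused pass.
        logits = torch.randn(S * B, V, device="cuda",
                             dtype=torch.bfloat16, requires_grad=True)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        make_loss(logits).backward()  # time forward + backward together
        end.record()
        torch.cuda.synchronize()
        print(f"V={V:>6} {name}: {start.elapsed_time(end):6.2f} ms fwd+bwd")
```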
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence