
Add Megatron-LM cross-entropy integration #1207

Draft
PrathyushaPolepalli wants to merge 1 commit into linkedin:main from PrathyushaPolepalli:megatron-cross-entropy-integration

Conversation


PrathyushaPolepalli commented Apr 28, 2026

Summary

Adds `apply_liger_kernel_to_megatron()` monkey-patch that swaps Megatron-LM's native `fused_vocab_parallel_cross_entropy` for Liger's Triton cross-entropy kernel.
```python
from liger_kernel.megatron import apply_liger_kernel_to_megatron

apply_liger_kernel_to_megatron(
    ignore_index=-100,
    label_smoothing=cfg.label_smoothing_factor,
)
```

This enables online softmax, in-place gradient computation, and avoids materializing the full softmax inside Megatron training pipelines.
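For orientation, the Liger kernel that the patch routes Megatron's loss through can also be exercised standalone; a minimal sketch, assuming the `LigerCrossEntropyLoss` module exported from `liger_kernel.transformers` (shapes are illustrative, not the training config):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss  # assumed import path

# Flattened logits [N, V] and integer targets [N], as in a standard LM loss.
N, V = 4 * 2048, 151_936
logits = torch.randn(N, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
targets = torch.randint(0, V, (N,), device="cuda")

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.0)
loss = loss_fn(logits, targets)  # online softmax; no [N, V] softmax tensor materialized
loss.backward()                  # gradients are written into the logits buffer in place
```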

Scope: tensor_model_parallel_size=1 only. With TP>1, each rank holds a sharded [N, V/TP] logits slice, and cross-entropy then requires cross-rank all-reduces that Liger's kernel does not perform.

The patch raises RuntimeError at patch time (via megatron.core.parallel_state) and again at call time (via the tp_group argument Megatron passes), so misconfiguration fails loudly. Vocab-parallel support is follow-up work.
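A minimal sketch of what the patch-time guard could look like, assuming it queries Megatron's parallel state directly (the helper name is hypothetical; only the `megatron.core.parallel_state` calls are existing Megatron APIs):

```python
from megatron.core import parallel_state


def _assert_tp_disabled() -> None:
    """Hypothetical helper: refuse to patch when tensor parallelism is enabled."""
    if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1:
        raise RuntimeError(
            "Liger cross-entropy does not handle vocab-parallel (sharded) logits; "
            "use tensor_model_parallel_size=1 or skip apply_liger_kernel_to_megatron()."
        )
```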

Tested on a scaled-down Qwen3-30B-A3B-style MoE, one H100_8 node (8 GPUs), BF16:

Model config:

  • 24 layers, hidden=1024, FFN hidden=6144
  • 128 experts, top-8 routing, MoE FFN hidden=768
  • ~7.8B total params, ~0.8B active per token
  • Vocab size: 151,936
  • Sequence length: 4096

Parallelism:

  • Tensor Parallel (TP): 1
  • Pipeline Parallel (PP): 1
  • Expert Parallel (EP): 8 (16 experts per GPU)
  • Data Parallel (DP): 8 (non-expert), 1 (expert)

Training config:

  • Global batch size: 1024, Micro batch size: 2
  • Distributed Adam optimizer
  • Selective activation recompute (core_attn, mlp)
  • --cross-entropy-loss-fusion enabled

Throughput results:
| Implementation | Throughput | Iteration time |
| --- | --- | --- |
| Megatron native fused CE (baseline) | ~99 TFLOP/s/GPU | ~39,400 ms |
| Liger CE (this PR) | ~108 TFLOP/s/GPU (+9%) | ~35,900 ms |

Numerical correctness: lm_loss ~4.1e-3 in both runs, with no NaNs or skipped iterations.
Variance: Liger CE stays within 107.7–109.1 TFLOP/s/GPU across iterations (consistent).

[Benchmark plots: Megatron cross-entropy memory (full) and speed (forward, backward, full) vs. token length]

Test setup: single H100 80GB, sequence length S=2048, batch size B=4, vocab sizes 4K → 131K. Each provider runs the same cross-entropy operation with a different implementation (a numerical equivalence sketch follows the list):

  • liger — apply_liger_kernel_to_megatron() patch (Liger's Triton kernel)
  • torch — standard torch.nn.functional.cross_entropy
  • megatron — Megatron's native fused_vocab_parallel_cross_entropy
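As a quick standalone sanity check outside Megatron, the liger and torch providers can be compared directly; a minimal sketch (the `LigerCrossEntropyLoss` import path and the tolerances are assumptions):

```python
import torch
import torch.nn.functional as F
from liger_kernel.transformers import LigerCrossEntropyLoss  # assumed import path

B, S, V = 4, 2048, 131_072  # mirrors the benchmark shapes above
logits = torch.randn(B * S, V, device="cuda", dtype=torch.float32)
targets = torch.randint(0, V, (B * S,), device="cuda")

ref = F.cross_entropy(logits, targets)                    # torch provider
liger = LigerCrossEntropyLoss()(logits.clone(), targets)  # liger provider (clone defensively)
assert torch.allclose(ref, liger, rtol=1e-4, atol=1e-4)
```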

Testing Done

  • Hardware Type: H100
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

PrathyushaPolepalli marked this pull request as draft on April 28, 2026 at 05:59
PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch 4 times, most recently from 642c576 to ed3c27e on April 29, 2026 at 23:27
PrathyushaPolepalli force-pushed the megatron-cross-entropy-integration branch from ed3c27e to b1fa5bc on April 29, 2026 at 23:35
