feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4148
feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4148DOGEUNNKIM wants to merge 28 commits into
Conversation
Implement Gemma-4 E4B architecture as a Megatron-Bridge layer spec: - Gemma4SelfAttention: GQA with per-head qk_layernorm, v_norm, sliding-window causal attention, and shared-KV cache (last num_kv_shared_layers reuse K/V) - Dual-RoPE: sliding-window layers use theta=10000, full-attention layers use theta=1000000 with partial_factor=0.25 - Per-Layer Embeddings (PLE): per-layer vocab embedding projected and added to hidden states at each transformer layer (norm → linear → add embed lookup × 1/√2) - Gemma4TransformerLayer: 4-norm residual structure matching HF implementation - wire_gemma4_kv_sharing(): post-construction wiring of shared-KV references - gemma4_layer_spec / get_gemma4_layer_spec(): ModuleSpec factory for --spec flag Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add loader plugin for converting HuggingFace Gemma-4 E4B checkpoints to Megatron format, compatible with convert.py --loader gemma4_hf: - QKV weight fusion and layout mapping (HF separate Q/K/V → Megatron fused) - Per-Layer Embedding (PLE) weight mapping (embed_tokens_per_layer) - GEGLU weight interleaved TP split (gate/up interleaved per rank, not contiguous) - Shared-KV layer detection and zero-initialization for non-source layers - geglu_tanh=True metadata to match HF gelu_pytorch_tanh activation Usage (from Megatron-Bridge root): PYTHONPATH=$PWD/src:$PWD/examples/models/gemma/gemma4:$MEGATRON_LM_ROOT/tools/checkpoint \ python $MEGATRON_LM_ROOT/tools/checkpoint/convert.py --loader gemma4_hf --saver core Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add end-to-end verification and training scripts for Gemma-4 E4B: - parity_check_e4b.py: distributed logit parity check between converted Megatron checkpoint (TP=1/2) and HuggingFace reference model; applies final_logit_softcapping=30.0 before comparison; expected max|diff| < 3.0 Fix: explicitly call Bridge's wire_gemma4_kv_sharing() after model construction so shared-KV layers are wired with the correct class - train_gemma4_e4b_parity.sh: launcher for parity check (torchrun, TP=2) - train_gemma4_e4b_pipeline.sh: full pipeline (convert → parity → training) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Document Gemma-4 E4B integration covering checkpoint conversion, parity verification, and training. Includes PYTHONPATH setup for Bridge-based loader discovery and note on GEGLU weight TP splitting fix. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add Gemma4E4BProvider to gemma4_layer_specs.py so the parity check (and future training scripts) work against a clean Megatron-Core that has no Gemma4-specific CLI args or TransformerConfig fields. Provider responsibilities: - Builds TransformerConfig with standard fields only; injects Gemma4 fields (global_kv_channels, num_kv_shared_layers, per_layer_embed_*) via setattr after dataclass construction, guarding against the dual-RoPE ValueError added in clean MCore. - Replaces model.rotary_pos_emb with Gemma4RotaryEmbedding (dual-theta). - Attaches PLE modules (VocabParallelEmbedding + ColumnParallelLinear + Gemma4RMSNorm) and patches model.forward() to compute per_layer_inputs once and inject via extra_block_kwargs -> TransformerBlock threading. - Calls wire_gemma4_kv_sharing() after construction. Update parity_check_e4b.py: - Remove 7 Gemma4-specific CLI flags from _build_megatron_argv(). - Replace gpt_builder / model_provider with Gemma4E4BProvider.build(). - No Megatron-LM source changes required. Verified: TP=1 max|diff|=2.73, TP=2 max|diff|=2.94 (atol=3.0, bf16). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
- Fix parity expected results: fp32 ~0.15 (atol=0.3), bf16 ~2.73 (atol=3.0) - Add Gemma4E4BProvider to What's included table - Remove stale CLI flag references (--num-kv-shared-layers, --geglu-tanh, --per-layer-embed-*) — these are now Gemma4E4BProvider defaults - Update GEGLU fix section: split signaled via md.geglu=True in loader - Add PYTHONPATH to test command - Note clean MCore compatibility (no Gemma4-specific CLI args) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
…ain.sh Align with Bridge examples convention (slurm_pretrain.sh pattern used by deepseek_v4, gpt_oss, stepfun, etc.). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Move shared Gemma4 text modeling, providers, and bridge logic under models/gemma, keep Gemma4 VL files focused on multimodal wrappers, and add text conversion/inference examples and tests. Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
|
@yaoyu-33 Thanks for the review. I updated the PR based on the feedback.
The shared/text Gemma4 implementation now lives under
I also fixed the Gemma4 Dense export-load warning caused by Local validation:
|
yaoyu-33
left a comment
There was a problem hiding this comment.
Structured review comments from local review.
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
|
@yaoyu-33 Thanks for the thorough review! I've addressed all the comments in [cc94737]: Copyright headers:Added NVIDIA 2026 Apache 2.0 headers to all new non-test Python files (gemma4_bridge.py, gemma4_provider.py, modeling_gemma4_vl.py, gemma4_vl_bridge.py, gemma4.py, parity_check_e4b.py). Hardcoded paths:Replaced /mnt/nvme0/kdg6245 defaults in slurm_pretrain.sh with ${VAR:?'Error: set VAR to a writable directory'} guards — the script now fails explicitly if the variable is unset rather than silently using a personal path. ####torchrun → uv: Recipe unit test:Added tests/unit_tests/recipes/test_gemma4_recipe.py that verifies 23 critical model config fields stay in sync between gemma4_e4b_pretrain_config() and Gemma4Bridge._build_dense_provider(), catching silent drift at test time. |
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
|
@DOGEUNNKIM gentle reminder of the merge conflicts |
Signed-off-by: Dogeun Kim <82812668+DOGEUNNKIM@users.noreply.github.com>
@gautham-kollu Thanks for the reminder. I resolved the merge conflicts and pushed the updated changes. |
|
/ok to test 2715f5d |
|
Hi, all CI checks are passing now. Could you please review this again when you have a chance? Thank you. |
|
sure. taking a look now |
|
Hi @DOGEUNNKIM , LGTM. Could you resolve the conflicts? |
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@weijiac0619 Thank you! |
What does this PR do?
consolidating all Gemma4-specific code that was previously scattered in
Megatron-LM core (see MCore PR: [Feature-request] Add Megatron Core support geglu for Gemma 4 dense model (e.g. E4B) NVIDIA/Megatron-LM#5090).
Gemma4DenseProvider— a Bridge-native model provider that worksagainst clean Megatron-Core (no Gemma4-specific CLI args or
TransformerConfigfields required in MCore).Gemma4Bridge/Gemma4ForCausalLM) from theVL path (
Gemma4VLBridge/Gemma4ForConditionalGeneration);gemma_vlimports from
gemma, not vice versa.training launch scripts.
Changelog
src/megatron/bridge/models/gemma/modeling_gemma4.pyCore Gemma-4 architecture:
Gemma4TransformerLayer— 4-norm residual structure, Phase 4 PLE residualblock (
gelu(gate(h)) × per_layer_input → proj → norm → residual), per-layerscalar, optional MoE block.
Gemma4SelfAttention— GQA with per-headq_norm/k_norm,heterogeneous head_dim (sliding=256, global=512), shared-KV forwarding.
Gemma4DenseRotaryEmbedding— Dual-theta RoPE: sliding layers useθ=10 000, global layers use θ=1 000 000 with partial rotation (factor=0.25).
wire_gemma4_kv_sharing()— Post-construction wiring of shared-KV sourcereferences for the last
num_kv_shared_layerslayers.src/megatron/bridge/models/gemma/gemma4_provider.pyGemma4DenseProvider— All-in-one dataclass provider for clean MCore:builds
GPTModel, attaches dual RoPE, PLE modules, and wires KV sharing.activation_funcusesfast_gelu(registered as"gelu_pytorch_tanh"inACTIVATION_FUNC_MAP) somegatron_to_hf_activationcan match it byidentity — avoids
"Unsupported activation function"on HF export.src/megatron/bridge/models/gemma_vl/modeling_gemma4_vl.pyGemma4VLModel— Wraps HF vision/audio towers with Megatron languagemodel.
_compute_attention_maskreturns a(B, 1, S, S)bool tensor(
True = blocked) instead of a dict, matching unit-test expectations anddownstream pipeline usage.
forward()returns logits directly in inference mode(
labels=None, loss_mask=None) to avoidtupleAttributeErrorin thepipeline forward pass.
src/megatron/bridge/models/gemma/gemma4_bridge.py/src/megatron/bridge/models/gemma_vl/gemma4_vl_bridge.pyGemma4BridgehandlesGemma4ForCausalLM; includes a load-state alias hookthat maps
self_attention_sliding/global.*checkpoint keys toself_attention.*model keys, silencing the strict-load warning on export.Gemma4VLBridgehandlesGemma4ForConditionalGeneration; respectsGEMMA4_CONVERSION_MODE=text|vl|audioto select the appropriate providerand parameter mappings.
examples/models/gemma/gemma4/conversion.shHF → Megatron import, Megatron → HF export, and round-trip parity check.
Sets
GEMMA4_CONVERSION_MODE=textto useGemma4DenseProvider(GPTModel)for
gemma-4-E4B-itwhich is otherwise detected asGemma4ForConditionalGeneration.examples/models/gemma/gemma4/inference.shText generation from HF or Megatron checkpoint using
hf_to_megatron_generate_text.py. SetsGEMMA4_CONVERSION_MODE=text; promptsuse the Gemma 4 IT chat template for the instruction-tuned model.
examples/models/gemma/gemma4/parity_check_e4b.pyDistributed logit parity check (TP=1/2). Expected max|diff| < 3.0 (bf16).
examples/models/gemma/gemma4/README.mdUsage guide: requirements, workspace setup, conversion, inference, parity
checks, pretraining, and architecture notes.
Relation to MCore PR
This PR is the Bridge half of the split requested in
Megatron-LM #5090.
The only MCore dependency retained is the generic
per_layer_inputshookalready threaded through
TransformerBlock.forward().Validation
Validated locally with:
Before your PR is "Ready for review"
Pre checks: