Skip to content

feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4148

Open
DOGEUNNKIM wants to merge 28 commits into
NVIDIA-NeMo:mainfrom
DOGEUNNKIM:main
Open

feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4148
DOGEUNNKIM wants to merge 28 commits into
NVIDIA-NeMo:mainfrom
DOGEUNNKIM:main

Conversation

@DOGEUNNKIM

@DOGEUNNKIM DOGEUNNKIM commented Jun 4, 2026

Copy link
Copy Markdown

What does this PR do?

  • Adds full Gemma-4 E4B (3.8B dense text) support to Megatron-Bridge,
    consolidating all Gemma4-specific code that was previously scattered in
    Megatron-LM core (see MCore PR: [Feature-request] Add Megatron Core support geglu for Gemma 4 dense model (e.g. E4B) NVIDIA/Megatron-LM#5090).
  • Introduces Gemma4DenseProvider — a Bridge-native model provider that works
    against clean Megatron-Core (no Gemma4-specific CLI args or
    TransformerConfig fields required in MCore).
  • Separates the text-only path (Gemma4Bridge / Gemma4ForCausalLM) from the
    VL path (Gemma4VLBridge / Gemma4ForConditionalGeneration); gemma_vl
    imports from gemma, not vice versa.
  • Provides HF → Megatron checkpoint conversion, logit parity verification, and
    training launch scripts.

Changelog

src/megatron/bridge/models/gemma/modeling_gemma4.py

Core Gemma-4 architecture:

  • Gemma4TransformerLayer — 4-norm residual structure, Phase 4 PLE residual
    block (gelu(gate(h)) × per_layer_input → proj → norm → residual), per-layer
    scalar, optional MoE block.
  • Gemma4SelfAttention — GQA with per-head q_norm/k_norm,
    heterogeneous head_dim (sliding=256, global=512), shared-KV forwarding.
  • Gemma4DenseRotaryEmbedding — Dual-theta RoPE: sliding layers use
    θ=10 000, global layers use θ=1 000 000 with partial rotation (factor=0.25).
  • wire_gemma4_kv_sharing() — Post-construction wiring of shared-KV source
    references for the last num_kv_shared_layers layers.

src/megatron/bridge/models/gemma/gemma4_provider.py

  • Gemma4DenseProvider — All-in-one dataclass provider for clean MCore:
    builds GPTModel, attaches dual RoPE, PLE modules, and wires KV sharing.
    activation_func uses fast_gelu (registered as "gelu_pytorch_tanh" in
    ACTIVATION_FUNC_MAP) so megatron_to_hf_activation can match it by
    identity — avoids "Unsupported activation function" on HF export.

src/megatron/bridge/models/gemma_vl/modeling_gemma4_vl.py

  • Gemma4VLModel — Wraps HF vision/audio towers with Megatron language
    model.
    • _compute_attention_mask returns a (B, 1, S, S) bool tensor
      (True = blocked) instead of a dict, matching unit-test expectations and
      downstream pipeline usage.
    • forward() returns logits directly in inference mode
      (labels=None, loss_mask=None) to avoid tuple AttributeError in the
      pipeline forward pass.

src/megatron/bridge/models/gemma/gemma4_bridge.py / src/megatron/bridge/models/gemma_vl/gemma4_vl_bridge.py

  • Gemma4Bridge handles Gemma4ForCausalLM; includes a load-state alias hook
    that maps self_attention_sliding/global.* checkpoint keys to
    self_attention.* model keys, silencing the strict-load warning on export.
  • Gemma4VLBridge handles Gemma4ForConditionalGeneration; respects
    GEMMA4_CONVERSION_MODE=text|vl|audio to select the appropriate provider
    and parameter mappings.

examples/models/gemma/gemma4/conversion.sh

HF → Megatron import, Megatron → HF export, and round-trip parity check.
Sets GEMMA4_CONVERSION_MODE=text to use Gemma4DenseProvider (GPTModel)
for gemma-4-E4B-it which is otherwise detected as
Gemma4ForConditionalGeneration.

examples/models/gemma/gemma4/inference.sh

Text generation from HF or Megatron checkpoint using
hf_to_megatron_generate_text.py. Sets GEMMA4_CONVERSION_MODE=text; prompts
use the Gemma 4 IT chat template for the instruction-tuned model.

examples/models/gemma/gemma4/parity_check_e4b.py

Distributed logit parity check (TP=1/2). Expected max|diff| < 3.0 (bf16).

examples/models/gemma/gemma4/README.md

Usage guide: requirements, workspace setup, conversion, inference, parity
checks, pretraining, and architecture notes.

Relation to MCore PR

This PR is the Bridge half of the split requested in
Megatron-LM #5090.
The only MCore dependency retained is the generic per_layer_inputs hook
already threaded through TransformerBlock.forward().

Validation

Validated locally with:

  • HF → Megatron import ✅
  • Megatron → HF export ✅
  • Round-trip weight verification ✅
  • Text,vision,audio modality logit parity check ✅
  • Unit tests: all 227 tests passing ✅

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
  • Reviewer: Does the PR have correct import guards for all optional libraries?

DOGEUNNKIM and others added 5 commits June 4, 2026 03:56
Implement Gemma-4 E4B architecture as a Megatron-Bridge layer spec:
- Gemma4SelfAttention: GQA with per-head qk_layernorm, v_norm, sliding-window
  causal attention, and shared-KV cache (last num_kv_shared_layers reuse K/V)
- Dual-RoPE: sliding-window layers use theta=10000, full-attention layers use
  theta=1000000 with partial_factor=0.25
- Per-Layer Embeddings (PLE): per-layer vocab embedding projected and added to
  hidden states at each transformer layer (norm → linear → add embed lookup × 1/√2)
- Gemma4TransformerLayer: 4-norm residual structure matching HF implementation
- wire_gemma4_kv_sharing(): post-construction wiring of shared-KV references
- gemma4_layer_spec / get_gemma4_layer_spec(): ModuleSpec factory for --spec flag

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add loader plugin for converting HuggingFace Gemma-4 E4B checkpoints to
Megatron format, compatible with convert.py --loader gemma4_hf:
- QKV weight fusion and layout mapping (HF separate Q/K/V → Megatron fused)
- Per-Layer Embedding (PLE) weight mapping (embed_tokens_per_layer)
- GEGLU weight interleaved TP split (gate/up interleaved per rank, not contiguous)
- Shared-KV layer detection and zero-initialization for non-source layers
- geglu_tanh=True metadata to match HF gelu_pytorch_tanh activation

Usage (from Megatron-Bridge root):
  PYTHONPATH=$PWD/src:$PWD/examples/models/gemma/gemma4:$MEGATRON_LM_ROOT/tools/checkpoint \
  python $MEGATRON_LM_ROOT/tools/checkpoint/convert.py --loader gemma4_hf --saver core

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add end-to-end verification and training scripts for Gemma-4 E4B:
- parity_check_e4b.py: distributed logit parity check between converted
  Megatron checkpoint (TP=1/2) and HuggingFace reference model; applies
  final_logit_softcapping=30.0 before comparison; expected max|diff| < 3.0
  Fix: explicitly call Bridge's wire_gemma4_kv_sharing() after model
  construction so shared-KV layers are wired with the correct class
- train_gemma4_e4b_parity.sh: launcher for parity check (torchrun, TP=2)
- train_gemma4_e4b_pipeline.sh: full pipeline (convert → parity → training)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Document Gemma-4 E4B integration covering checkpoint conversion, parity
verification, and training. Includes PYTHONPATH setup for Bridge-based
loader discovery and note on GEGLU weight TP splitting fix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add Gemma4E4BProvider to gemma4_layer_specs.py so the parity check
(and future training scripts) work against a clean Megatron-Core that
has no Gemma4-specific CLI args or TransformerConfig fields.

Provider responsibilities:
- Builds TransformerConfig with standard fields only; injects Gemma4
  fields (global_kv_channels, num_kv_shared_layers, per_layer_embed_*)
  via setattr after dataclass construction, guarding against the
  dual-RoPE ValueError added in clean MCore.
- Replaces model.rotary_pos_emb with Gemma4RotaryEmbedding (dual-theta).
- Attaches PLE modules (VocabParallelEmbedding + ColumnParallelLinear +
  Gemma4RMSNorm) and patches model.forward() to compute per_layer_inputs
  once and inject via extra_block_kwargs -> TransformerBlock threading.
- Calls wire_gemma4_kv_sharing() after construction.

Update parity_check_e4b.py:
- Remove 7 Gemma4-specific CLI flags from _build_megatron_argv().
- Replace gpt_builder / model_provider with Gemma4E4BProvider.build().
- No Megatron-LM source changes required.

Verified: TP=1 max|diff|=2.73, TP=2 max|diff|=2.94 (atol=3.0, bf16).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@DOGEUNNKIM DOGEUNNKIM marked this pull request as ready for review June 4, 2026 04:53
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic feature New capabilities, enhancements, or enablement work needs-more-tests Requires additional L0 and L1 test coverage before merge needs-review PR is ready for code review and waiting on a reviewer labels Jun 4, 2026
DOGEUNNKIM and others added 4 commits June 4, 2026 08:55
- Fix parity expected results: fp32 ~0.15 (atol=0.3), bf16 ~2.73 (atol=3.0)
- Add Gemma4E4BProvider to What's included table
- Remove stale CLI flag references (--num-kv-shared-layers, --geglu-tanh,
  --per-layer-embed-*) — these are now Gemma4E4BProvider defaults
- Update GEGLU fix section: split signaled via md.geglu=True in loader
- Add PYTHONPATH to test command
- Note clean MCore compatibility (no Gemma4-specific CLI args)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
…ain.sh

Align with Bridge examples convention (slurm_pretrain.sh pattern used by
deepseek_v4, gpt_oss, stepfun, etc.).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Comment thread src/megatron/bridge/models/gemma/gemma4_bridge.py
Comment thread examples/models/gemma/gemma4/parity_check_e4b.py
Comment thread src/megatron/bridge/models/gemma_vl/gemma4_vl_bridge.py Outdated
@yaoyu-33 yaoyu-33 requested a review from weijiac0619 June 4, 2026 22:50
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Move shared Gemma4 text modeling, providers, and bridge logic under models/gemma, keep Gemma4 VL files focused on multimodal wrappers, and add text conversion/inference examples and tests.

Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@DOGEUNNKIM

Copy link
Copy Markdown
Author

@yaoyu-33 Thanks for the review. I updated the PR based on the feedback.

  • Kept Gemma4ForCausalLM registration in Gemma4Bridge for pure text use cases.
  • Added examples/models/gemma/gemma4/conversion.sh and inference.sh.
  • Refactored pure LLM code out of models/gemma_vl.

The shared/text Gemma4 implementation now lives under models/gemma:

  • modeling_gemma4.py
  • gemma4_provider.py
  • gemma4_bridge.py

models/gemma_vl now only keeps multimodal-specific wrappers, VL providers, and VL/audio mappings, while reusing the text provider/modeling code from models/gemma.

I also fixed the Gemma4 Dense export-load warning caused by self_attention_sliding/global.* checkpoint keys versus self_attention.* model keys using a Gemma4-only load-state alias hook.

Local validation:

  • HF -> Megatron import: passed
  • Megatron -> HF export: passed
  • round-trip weight verification: passed
  • text ,vision, audio logit parity: passed

@yaoyu-33 yaoyu-33 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Structured review comments from local review.

Comment thread src/megatron/bridge/recipes/gemma/gemma4.py Outdated
Comment thread examples/models/gemma/gemma4/slurm_pretrain.sh Outdated
Comment thread examples/models/gemma/gemma4/slurm_pretrain.sh Outdated
Comment thread examples/models/gemma/gemma4/parity_check_e4b.py
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@DOGEUNNKIM

Copy link
Copy Markdown
Author

@yaoyu-33 Thanks for the thorough review! I've addressed all the comments in [cc94737]:

Copyright headers:

Added NVIDIA 2026 Apache 2.0 headers to all new non-test Python files (gemma4_bridge.py, gemma4_provider.py, modeling_gemma4_vl.py, gemma4_vl_bridge.py, gemma4.py, parity_check_e4b.py).

Hardcoded paths:

Replaced /mnt/nvme0/kdg6245 defaults in slurm_pretrain.sh with ${VAR:?'Error: set VAR to a writable directory'} guards — the script now fails explicitly if the variable is unset rather than silently using a personal path.

####torchrun → uv:
Updated TORCHRUN_BIN default to uv run python -m torch.distributed.run per repo convention.

Recipe unit test:

Added tests/unit_tests/recipes/test_gemma4_recipe.py that verifies 23 critical model config fields stay in sync between gemma4_e4b_pretrain_config() and Gemma4Bridge._build_dense_provider(), catching silent drift at test time.

Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@gautham-kollu

Copy link
Copy Markdown
Contributor

@DOGEUNNKIM gentle reminder of the merge conflicts

Signed-off-by: Dogeun Kim <82812668+DOGEUNNKIM@users.noreply.github.com>
@DOGEUNNKIM

Copy link
Copy Markdown
Author

@DOGEUNNKIM gentle reminder of the merge conflicts

@gautham-kollu Thanks for the reminder. I resolved the merge conflicts and pushed the updated changes.

@ko3n1g

ko3n1g commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

/ok to test 2715f5d

@DOGEUNNKIM

Copy link
Copy Markdown
Author

@weijiac0619

Hi, all CI checks are passing now.

Could you please review this again when you have a chance?

Thank you.

@weijiac0619

Copy link
Copy Markdown
Contributor

sure. taking a look now

@weijiac0619

Copy link
Copy Markdown
Contributor

Hi @DOGEUNNKIM , LGTM. Could you resolve the conflicts?

weijiac0619
weijiac0619 previously approved these changes Jun 15, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@DOGEUNNKIM

Copy link
Copy Markdown
Author

Hi @DOGEUNNKIM , LGTM. Could you resolve the conflicts?

@weijiac0619
The conflicts have been resolved by @yaoyu-33.
Could you please rerun the tests when you have a chance?

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:model Model implementations and HF bridge logic community-request feature New capabilities, enhancements, or enablement work needs-more-tests Requires additional L0 and L1 test coverage before merge ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants