Skip to content

mirror: feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4297

Open
ko3n1g wants to merge 25 commits into
mainfrom
ko3n1g/mirror/pr-4148
Open

mirror: feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4297
ko3n1g wants to merge 25 commits into
mainfrom
ko3n1g/mirror/pr-4148

Conversation

@ko3n1g

@ko3n1g ko3n1g commented Jun 11, 2026

Copy link
Copy Markdown
Contributor
Claude summary

Mirror of #4148 by @DOGEUNNKIM — copied into the upstream repo so the full CI pipeline runs natively (cross-fork PRs cannot trigger it).

Commits are copied verbatim with authorship preserved. Review and discussion remain on #4148.

DOGEUNNKIM and others added 20 commits June 4, 2026 03:56
Implement Gemma-4 E4B architecture as a Megatron-Bridge layer spec:
- Gemma4SelfAttention: GQA with per-head qk_layernorm, v_norm, sliding-window
  causal attention, and shared-KV cache (last num_kv_shared_layers reuse K/V)
- Dual-RoPE: sliding-window layers use theta=10000, full-attention layers use
  theta=1000000 with partial_factor=0.25
- Per-Layer Embeddings (PLE): per-layer vocab embedding projected and added to
  hidden states at each transformer layer (norm → linear → add embed lookup × 1/√2)
- Gemma4TransformerLayer: 4-norm residual structure matching HF implementation
- wire_gemma4_kv_sharing(): post-construction wiring of shared-KV references
- gemma4_layer_spec / get_gemma4_layer_spec(): ModuleSpec factory for --spec flag

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add loader plugin for converting HuggingFace Gemma-4 E4B checkpoints to
Megatron format, compatible with convert.py --loader gemma4_hf:
- QKV weight fusion and layout mapping (HF separate Q/K/V → Megatron fused)
- Per-Layer Embedding (PLE) weight mapping (embed_tokens_per_layer)
- GEGLU weight interleaved TP split (gate/up interleaved per rank, not contiguous)
- Shared-KV layer detection and zero-initialization for non-source layers
- geglu_tanh=True metadata to match HF gelu_pytorch_tanh activation

Usage (from Megatron-Bridge root):
  PYTHONPATH=$PWD/src:$PWD/examples/models/gemma/gemma4:$MEGATRON_LM_ROOT/tools/checkpoint \
  python $MEGATRON_LM_ROOT/tools/checkpoint/convert.py --loader gemma4_hf --saver core

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add end-to-end verification and training scripts for Gemma-4 E4B:
- parity_check_e4b.py: distributed logit parity check between converted
  Megatron checkpoint (TP=1/2) and HuggingFace reference model; applies
  final_logit_softcapping=30.0 before comparison; expected max|diff| < 3.0
  Fix: explicitly call Bridge's wire_gemma4_kv_sharing() after model
  construction so shared-KV layers are wired with the correct class
- train_gemma4_e4b_parity.sh: launcher for parity check (torchrun, TP=2)
- train_gemma4_e4b_pipeline.sh: full pipeline (convert → parity → training)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Document Gemma-4 E4B integration covering checkpoint conversion, parity
verification, and training. Includes PYTHONPATH setup for Bridge-based
loader discovery and note on GEGLU weight TP splitting fix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add Gemma4E4BProvider to gemma4_layer_specs.py so the parity check
(and future training scripts) work against a clean Megatron-Core that
has no Gemma4-specific CLI args or TransformerConfig fields.

Provider responsibilities:
- Builds TransformerConfig with standard fields only; injects Gemma4
  fields (global_kv_channels, num_kv_shared_layers, per_layer_embed_*)
  via setattr after dataclass construction, guarding against the
  dual-RoPE ValueError added in clean MCore.
- Replaces model.rotary_pos_emb with Gemma4RotaryEmbedding (dual-theta).
- Attaches PLE modules (VocabParallelEmbedding + ColumnParallelLinear +
  Gemma4RMSNorm) and patches model.forward() to compute per_layer_inputs
  once and inject via extra_block_kwargs -> TransformerBlock threading.
- Calls wire_gemma4_kv_sharing() after construction.

Update parity_check_e4b.py:
- Remove 7 Gemma4-specific CLI flags from _build_megatron_argv().
- Replace gpt_builder / model_provider with Gemma4E4BProvider.build().
- No Megatron-LM source changes required.

Verified: TP=1 max|diff|=2.73, TP=2 max|diff|=2.94 (atol=3.0, bf16).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
- Fix parity expected results: fp32 ~0.15 (atol=0.3), bf16 ~2.73 (atol=3.0)
- Add Gemma4E4BProvider to What's included table
- Remove stale CLI flag references (--num-kv-shared-layers, --geglu-tanh,
  --per-layer-embed-*) — these are now Gemma4E4BProvider defaults
- Update GEGLU fix section: split signaled via md.geglu=True in loader
- Add PYTHONPATH to test command
- Note clean MCore compatibility (no Gemma4-specific CLI args)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
…ain.sh

Align with Bridge examples convention (slurm_pretrain.sh pattern used by
deepseek_v4, gpt_oss, stepfun, etc.).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Move shared Gemma4 text modeling, providers, and bridge logic under models/gemma, keep Gemma4 VL files focused on multimodal wrappers, and add text conversion/inference examples and tests.

Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@copy-pr-bot

copy-pr-bot Bot commented Jun 11, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g

ko3n1g commented Jun 11, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 1d0f4ba

@claude

claude Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Light Code Review — PR 4297

Findings

1. Bare print() in examples/models/gemma/gemma4/parity_check_e4b.py

About 12 bare print() calls in a distributed script. Project rule: NEVER use bare print() — use logging.getLogger(name) or print_rank_0(). Since this is a torchrun script that runs on every rank, the output will be duplicated. Use print_rank_0() (already imported from megatron.core) or logging.

2. MoE VL provider defaults to float32 while text-only MoE uses bfloat16

gemma4_vl_bridge.py sets provider.bf16 = False / provider.params_dtype = torch.float32 for the MoE VL path, while the text-only MoE path in gemma4_bridge.py sets bf16 = True / bfloat16. The test (test_dtype_is_fp32_for_vl) confirms this is intentional, but it doubles memory usage for MoE VL users going through AutoBridge. If this is needed for parity/precision reasons, a brief code comment explaining why the VL path differs from the text path would help future readers.

3. No _compute_attention_mask handling for audio bidirectional attention

In modeling_gemma4_vl.py, _compute_attention_mask creates bidirectional blocks for image token spans via _bidirectional_block_mask, but audio token spans receive only the default causal mask. If the Gemma 4 architecture expects bidirectional self-attention within audio segments (similar to image segments), this is a logic gap. If audio is intentionally causal, a comment would clarify.

Missing Test Coverage

  • _conversion_mode() dispatch logic: Gemma4VLBridge._conversion_mode() reads GEMMA4_CONVERSION_MODE and routes to text/vl/audio. No unit test exercises this dispatch or verifies the env var fallback.
  • _gemma4_text_conversion_mode exception safety: The context manager in gemma4.py sets and restores GEMMA4_CONVERSION_MODE. No test verifies the env var is restored when an exception occurs inside the context.
  • _keep_hf_precision_buffers_in_fp32(): New helper in modeling_gemma4_vl.py that keeps inv_freq / inv_timescales buffers in fp32. No unit test verifies this hook works correctly.
  • Audio feature scatter path: Gemma4VLModel.forward() now handles input_features (audio) via _scatter_modality_features, but no test exercises this branch.

Suggested Test Cases

  • tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py::TestGemma4VLBridgeConversionMode::test_conversion_mode_returns_text_when_env_set
  • tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py::TestGemma4VLBridgeConversionMode::test_conversion_mode_returns_vl_by_default
  • tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py::TestGemma4VLBridgeConversionMode::test_conversion_mode_audio_dispatch
  • tests/unit_tests/recipes/gemma/test_gemma4_recipe.py::test_e4b_pretrain_config_restores_env_on_exception
  • tests/unit_tests/models/gemma_vl/test_gemma4_vl_provider.py::TestGemma4DenseVLProvider::test_keep_hf_precision_buffers_in_fp32
  • tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py::TestGemma4VLBridgeMappingRegistry::test_dense_vl_audio_tower_replicated_mappings

No perf tests impacted.

@yaoyu-33 yaoyu-33 added feature New capabilities, enhancements, or enablement work needs-review PR is ready for code review and waiting on a reviewer labels Jun 11, 2026
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@DOGEUNNKIM

Copy link
Copy Markdown

@ko3n1g

Hi, the previous failure was due to the Codecov patch coverage gate, not a functional test failure.

I pushed a new commit, 63ebff6, with additional Gemma4 unit coverage and parity script logging cleanup. Locally, ruff/pre-commit passed and the Gemma4 unit test suite passed with 276 tests.

Could you please rerun CI/Codecov on 63ebff6?

@ko3n1g

ko3n1g commented Jun 11, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 63ebff6

Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@DOGEUNNKIM

DOGEUNNKIM commented Jun 11, 2026

Copy link
Copy Markdown

@ko3nlg

Hi, I added additional Gemma4/Gemma4-VL unit tests to improve patch coverage, especially around the previously uncovered modeling paths.

The latest commit is 2715f5d. Could you please rerun the tests/CI with this commit?

Signed-off-by: Dogeun Kim <82812668+DOGEUNNKIM@users.noreply.github.com>
@yaoyu-33 yaoyu-33 added ready-to-merge PR is approved, current, and only waiting for CI to pass before merge and removed needs-review PR is ready for code review and waiting on a reviewer labels Jun 12, 2026
@yaoyu-33

Copy link
Copy Markdown
Contributor

/ok to test 2715f5d

@DOGEUNNKIM

Copy link
Copy Markdown

@yaoyu-33

Hi, all CI checks are passing now.

Could you please review this again when you have a chance?

Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:model Model implementations and HF bridge logic feature New capabilities, enhancements, or enablement work ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants