mirror: feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check)#4297
ko3n1g wants to merge 25 commits into
Conversation
Implement Gemma-4 E4B architecture as a Megatron-Bridge layer spec:
- Gemma4SelfAttention: GQA with per-head qk_layernorm, v_norm, sliding-window causal attention, and shared-KV cache (last num_kv_shared_layers reuse K/V)
- Dual-RoPE: sliding-window layers use theta=10000, full-attention layers use theta=1000000 with partial_factor=0.25
- Per-Layer Embeddings (PLE): per-layer vocab embedding projected and added to hidden states at each transformer layer (norm → linear → add embed lookup × 1/√2)
- Gemma4TransformerLayer: 4-norm residual structure matching HF implementation
- wire_gemma4_kv_sharing(): post-construction wiring of shared-KV references
- gemma4_layer_spec / get_gemma4_layer_spec(): ModuleSpec factory for --spec flag

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
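To make the dual-RoPE bullet concrete, here is a minimal sketch of how the two layer types could derive their rotary inverse frequencies. It is not taken from the PR: `rope_inv_freq`, `inv_freq_for_layer`, and `HEAD_DIM = 256` are hypothetical, and it assumes the conventional RoPE frequency formula with `partial_factor` meaning that only that fraction of each head's channels is rotated. The actual Gemma4RotaryEmbedding may differ in how it treats the unrotated channels.

```python
def rope_inv_freq(head_dim: int, theta: float, partial_factor: float = 1.0) -> list[float]:
    """Inverse frequencies for the rotated channels of one attention head."""
    rotary_dim = int(head_dim * partial_factor)
    rotary_dim -= rotary_dim % 2  # rotation acts on pairs of channels
    return [theta ** (-2.0 * i / rotary_dim) for i in range(rotary_dim // 2)]


HEAD_DIM = 256  # hypothetical value, chosen only for illustration


def inv_freq_for_layer(is_sliding: bool, head_dim: int = HEAD_DIM) -> list[float]:
    """Pick the RoPE parameters from the commit message by layer type."""
    if is_sliding:
        # Sliding-window layers: theta=10000, all channels rotated.
        return rope_inv_freq(head_dim, theta=10_000.0)
    # Full-attention layers: theta=1000000, a quarter of the channels rotated.
    return rope_inv_freq(head_dim, theta=1_000_000.0, partial_factor=0.25)


sliding = inv_freq_for_layer(is_sliding=True)
full = inv_freq_for_layer(is_sliding=False)
print(len(sliding), len(full))  # 128 32
```

With these assumptions a full-attention layer rotates far fewer channel pairs, and its larger theta stretches the wavelengths, which is the usual motivation for a separate long-context RoPE configuration.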
Add loader plugin for converting HuggingFace Gemma-4 E4B checkpoints to Megatron format, compatible with convert.py --loader gemma4_hf:
- QKV weight fusion and layout mapping (HF separate Q/K/V → Megatron fused)
- Per-Layer Embedding (PLE) weight mapping (embed_tokens_per_layer)
- GEGLU weight interleaved TP split (gate/up interleaved per rank, not contiguous)
- Shared-KV layer detection and zero-initialization for non-source layers
- geglu_tanh=True metadata to match HF gelu_pytorch_tanh activation

Usage (from Megatron-Bridge root):

    PYTHONPATH=$PWD/src:$PWD/examples/models/gemma/gemma4:$MEGATRON_LM_ROOT/tools/checkpoint \
    python $MEGATRON_LM_ROOT/tools/checkpoint/convert.py --loader gemma4_hf --saver core

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
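The GEGLU bullet is easiest to see with labelled rows. The sketch below is an illustration only, not the loader's code: it assumes the fused weight stacks all gate rows before all up rows along the output dimension, and the helper names `split_contiguous` and `split_gate_up_per_rank` are hypothetical.

```python
def split_contiguous(rows: list[str], tp: int) -> list[list[str]]:
    """Naive split: cut the fused [gate; up] weight into tp equal chunks."""
    n = len(rows) // tp
    return [rows[r * n:(r + 1) * n] for r in range(tp)]


def split_gate_up_per_rank(gate: list[str], up: list[str], tp: int) -> list[list[str]]:
    """Give every rank its own gate shard followed by the matching up shard."""
    if len(gate) != len(up) or len(gate) % tp != 0:
        raise ValueError("gate/up must have equal length divisible by tp")
    shard = len(gate) // tp
    return [
        gate[r * shard:(r + 1) * shard] + up[r * shard:(r + 1) * shard]
        for r in range(tp)
    ]


gate = [f"g{i}" for i in range(4)]  # stand-ins for gate projection rows
up = [f"u{i}" for i in range(4)]    # stand-ins for up projection rows

naive = split_contiguous(gate + up, tp=2)
correct = split_gate_up_per_rank(gate, up, tp=2)
print(naive[0])    # ['g0', 'g1', 'g2', 'g3']
print(correct[0])  # ['g0', 'g1', 'u0', 'u1']
```

In the naive split, rank 0 holds only gate rows and rank 1 only up rows, so each rank's local `gate * up` product pairs the wrong halves. The per-rank layout keeps row `g_i` on the same rank as `u_i`.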
Add end-to-end verification and training scripts for Gemma-4 E4B:
- parity_check_e4b.py: distributed logit parity check between converted Megatron checkpoint (TP=1/2) and HuggingFace reference model; applies final_logit_softcapping=30.0 before comparison; expected max|diff| < 3.0
  Fix: explicitly call Bridge's wire_gemma4_kv_sharing() after model construction so shared-KV layers are wired with the correct class
- train_gemma4_e4b_parity.sh: launcher for parity check (torchrun, TP=2)
- train_gemma4_e4b_pipeline.sh: full pipeline (convert → parity → training)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
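The comparison described for parity_check_e4b.py can be sketched in a few lines of plain Python. This is an assumption-laden illustration rather than the script itself: `softcap`, `max_abs_diff`, and `parity_ok` are hypothetical names, and it assumes the cap is applied to raw logits on both sides before the maximum absolute difference is taken.

```python
import math


def softcap(logits: list[float], cap: float = 30.0) -> list[float]:
    """Tanh soft-capping: squashes logits smoothly into (-cap, cap)."""
    return [cap * math.tanh(x / cap) for x in logits]


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference between two logit vectors."""
    if len(a) != len(b):
        raise ValueError("logit vectors must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def parity_ok(megatron_logits: list[float], hf_logits: list[float],
              cap: float = 30.0, atol: float = 3.0) -> bool:
    """True when the capped logits agree within atol (3.0 is the bf16 bound quoted above)."""
    return max_abs_diff(softcap(megatron_logits, cap), softcap(hf_logits, cap)) < atol


print(parity_ok([1.0, 2.0], [1.5, 2.5]))  # True
print(parity_ok([100.0], [0.0]))          # False
```

Because the cap bounds every logit to within ±30, the absolute tolerance has a fixed scale to be read against, which is presumably why the cap is applied before the comparison.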
Document Gemma-4 E4B integration covering checkpoint conversion, parity verification, and training. Includes PYTHONPATH setup for Bridge-based loader discovery and note on GEGLU weight TP splitting fix. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Add Gemma4E4BProvider to gemma4_layer_specs.py so the parity check (and future training scripts) work against a clean Megatron-Core that has no Gemma4-specific CLI args or TransformerConfig fields.

Provider responsibilities:
- Builds TransformerConfig with standard fields only; injects Gemma4 fields (global_kv_channels, num_kv_shared_layers, per_layer_embed_*) via setattr after dataclass construction, guarding against the dual-RoPE ValueError added in clean MCore.
- Replaces model.rotary_pos_emb with Gemma4RotaryEmbedding (dual-theta).
- Attaches PLE modules (VocabParallelEmbedding + ColumnParallelLinear + Gemma4RMSNorm) and patches model.forward() to compute per_layer_inputs once and inject via extra_block_kwargs -> TransformerBlock threading.
- Calls wire_gemma4_kv_sharing() after construction.

Update parity_check_e4b.py:
- Remove 7 Gemma4-specific CLI flags from _build_megatron_argv().
- Replace gpt_builder / model_provider with Gemma4E4BProvider.build().
- No Megatron-LM source changes required.

Verified: TP=1 max|diff|=2.73, TP=2 max|diff|=2.94 (atol=3.0, bf16).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
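The "inject via setattr after dataclass construction" step can be illustrated with a stand-in dataclass. Nothing below is Megatron-Core or Bridge code: `BaseConfig` is a hypothetical substitute for TransformerConfig, `build_config` is a hypothetical helper, and the field values are made up.

```python
from dataclasses import dataclass, fields


@dataclass
class BaseConfig:
    """Stand-in for a config dataclass that knows nothing about Gemma4 fields."""
    num_layers: int
    hidden_size: int


def build_config(extra: dict, **standard) -> BaseConfig:
    """Construct with declared fields only, then attach model-specific extras."""
    cfg = BaseConfig(**standard)  # passing an undeclared field here would raise TypeError
    declared = {f.name for f in fields(cfg)}
    for name, value in extra.items():
        if name in declared:
            raise ValueError(f"{name} is a declared field; pass it as a keyword instead")
        setattr(cfg, name, value)  # plain instance attribute, invisible to the dataclass
    return cfg


cfg = build_config(
    {"num_kv_shared_layers": 2, "global_kv_channels": 512},  # illustrative values
    num_layers=4,
    hidden_size=64,
)
print(cfg.num_kv_shared_layers)  # 2
print(cfg)  # BaseConfig(num_layers=4, hidden_size=64)
```

The trade-off of this pattern is that the injected attributes do not appear in `repr`, equality checks, or `dataclasses.asdict`, so any code that serializes the config needs to handle them separately.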
- Fix parity expected results: fp32 ~0.15 (atol=0.3), bf16 ~2.73 (atol=3.0)
- Add Gemma4E4BProvider to What's included table
- Remove stale CLI flag references (--num-kv-shared-layers, --geglu-tanh, --per-layer-embed-*); these are now Gemma4E4BProvider defaults
- Update GEGLU fix section: split signaled via md.geglu=True in loader
- Add PYTHONPATH to test command
- Note clean MCore compatibility (no Gemma4-specific CLI args)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
…ain.sh Align with Bridge examples convention (slurm_pretrain.sh pattern used by deepseek_v4, gpt_oss, stepfun, etc.). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Move shared Gemma4 text modeling, providers, and bridge logic under models/gemma, keep Gemma4 VL files focused on multimodal wrappers, and add text conversion/inference examples and tests. Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
/ok to test 1d0f4ba
Light Code Review: PR 4297

Findings

1. Bare print() in examples/models/gemma/gemma4/parity_check_e4b.py
About 12 bare print() calls in a distributed script. Project rule: NEVER use bare print(); use logging.getLogger(__name__) or print_rank_0(). Since this is a torchrun script that runs on every rank, the output will be duplicated. Use print_rank_0() (already imported from megatron.core) or logging.

2. MoE VL provider defaults to float32 while text-only MoE uses bfloat16
gemma4_vl_bridge.py sets provider.bf16 = False / provider.params_dtype = torch.float32 for the MoE VL path, while the text-only MoE path in gemma4_bridge.py sets bf16 = True / bfloat16. The test (test_dtype_is_fp32_for_vl) confirms this is intentional, but it doubles memory usage for MoE VL users going through AutoBridge. If this is needed for parity/precision reasons, a brief code comment explaining why the VL path differs from the text path would help future readers.

3. No _compute_attention_mask handling for audio bidirectional attention
In modeling_gemma4_vl.py, _compute_attention_mask creates bidirectional blocks for image token spans via _bidirectional_block_mask, but audio token spans receive only the default causal mask. If the Gemma 4 architecture expects bidirectional self-attention within audio segments (similar to image segments), this is a logic gap. If audio is intentionally causal, a comment would clarify.

Missing Test Coverage
Suggested Test Cases
No perf tests impacted.
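For finding 1 above, one stdlib-only way to avoid duplicated output is a logging filter that drops records on non-zero ranks. This is a sketch, not the project's `print_rank_0()` helper: `RankZeroFilter` and `get_rank0_logger` are hypothetical names, and it relies on the `RANK` environment variable that torchrun sets for each worker.

```python
import logging
import os


class RankZeroFilter(logging.Filter):
    """Let records through only on global rank 0 (RANK defaults to 0 outside torchrun)."""

    def filter(self, record: logging.LogRecord) -> bool:
        return int(os.environ.get("RANK", "0")) == 0


def get_rank0_logger(name: str) -> logging.Logger:
    """Return a logger that is silent on every rank except rank 0."""
    logger = logging.getLogger(name)
    if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
        logger.addFilter(RankZeroFilter())  # attach once, even if called repeatedly
    return logger


logging.basicConfig(level=logging.INFO)
log = get_rank0_logger(__name__)
log.info("parity check: max|diff| = %.2f", 2.73)  # emitted once per job, not once per rank
```

Reading the rank lazily inside `filter()` keeps the logger usable at import time, before the process group has been initialized.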
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
Hi, the previous failure was due to the Codecov patch coverage gate, not a functional test failure. I pushed a new commit, 63ebff6, with additional Gemma4 unit coverage and parity script logging cleanup. Locally, ruff/pre-commit passed and the Gemma4 unit test suite passed with 276 tests. Could you please rerun CI/Codecov on 63ebff6?
/ok to test 63ebff6
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr>
@ko3n1g Hi, I added additional Gemma4/Gemma4-VL unit tests to improve patch coverage, especially around the previously uncovered modeling paths. The latest commit is 2715f5d. Could you please rerun the tests/CI with this commit?
Signed-off-by: Dogeun Kim <82812668+DOGEUNNKIM@users.noreply.github.com>
/ok to test 2715f5d
Hi, all CI checks are passing now. Could you please review this again when you have a chance? Thank you.
Claude summary
Mirror of #4148 by @DOGEUNNKIM — copied into the upstream repo so the full CI pipeline runs natively (cross-fork PRs cannot trigger it).
DOGEUNNKIM/Megatron-Bridge:main
Commits are copied verbatim with authorship preserved. Review and discussion remain on #4148.