Fix Gemma 4 LoRA training: vision backward NaN + audio_tower freeze leak #1052
Merged
Blaizzy merged 4 commits into Blaizzy:main on Apr 24, 2026
Conversation
This PR fixes two Gemma-4-specific bugs that together block LoRA fine-tuning
on the multimodal ``google/gemma-4-E2B-it`` checkpoint.
### 1. vision_tower produces NaN gradients on padded images
When pixel_values contain fewer real patches than ``max_patches`` (any
non-square input, or a smaller-than-max resolution), the padded query rows
in the bidirectional attention mask are entirely -inf. softmax of an
all-(-inf) row is finite in forward (it is zeroed out downstream by the
pooler) but during backward produces NaN: ``exp(-inf - -inf) = exp(NaN)``.
Minimal repro on ``google/gemma-4-E2B-it`` (rescaled image with padding):

```python
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.utils import load
from mlx_vlm.trainer.utils import freeze_model

model, _ = load("google/gemma-4-E2B-it", processor_config={"trust_remote_code": True})
freeze_model(model)
model.vision_tower.unfreeze()

# any pixel_values where num_real < max_patches, e.g. a random 960x624 tensor
pv = mx.random.normal(shape=(1, 3, 960, 624))

def f(x):
    return model.vision_tower(x).sum()

v_direct = f(pv)  # finite: forward alone is fine
val, grad = nn.value_and_grad(model.vision_tower, f)(pv)
# val is NaN; every grad leaf is NaN
```
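To confirm the failure is in the gradients themselves and not just the scalar loss, the grad tree can be flattened and scanned (a diagnostic sketch continuing the repro above; ``tree_flatten`` is from ``mlx.utils`` and this snippet is not part of the patch):

```python
from mlx.utils import tree_flatten

val, grads = nn.value_and_grad(model.vision_tower, f)(pv)
flat = tree_flatten(grads)
# count the gradient leaves that contain NaN
nan_leaves = [k for k, g in flat if mx.any(mx.isnan(g)).item()]
print(f"{len(nan_leaves)} / {len(flat)} grad leaves are NaN")
```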
Forward alone is fine (finite result), so the symptom only appears once a
training loop calls ``nn.value_and_grad``, which makes it easy to misread
as a LoRA / dtype bug.
Fix: fill masked positions with a large finite negative (``-1e4``) in the
image-embed dtype. The forward behaviour is unchanged (``exp(-1e4)``
underflows to zero, and the padded outputs are still zeroed in the pooler),
but the gradient through softmax stays finite, so LoRA / full FT on Gemma 4
works.
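A minimal sketch of the idea (the helper name and shapes here are illustrative, not the exact code in ``mlx_vlm/models/gemma4/vision.py``):

```python
import mlx.core as mx

def pad_attention_mask(num_real: int, max_patches: int, dtype=mx.bfloat16):
    # 0 where both query and key are real patches, a large finite
    # negative elsewhere. Using -1e4 in the embed dtype instead of -inf
    # keeps softmax differentiable even for fully padded query rows:
    # exp(-1e4) underflows to 0, so the forward output is unchanged,
    # but backward never evaluates exp(-inf - -inf).
    valid = mx.arange(max_patches) < num_real
    keep = valid[:, None] & valid[None, :]  # (L, L) pairwise validity
    return mx.where(keep, mx.array(0.0, dtype=dtype),
                    mx.array(-1e4, dtype=dtype))
```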
### 2. freeze_model leaks audio_tower / embed_audio / embed_vision
``freeze_model`` only covers a fixed set of top-level module names that
pre-dates the Gemma-4 audio branch. As a result every ``audio_tower.*``
weight (plus ``embed_audio`` / ``embed_vision``) is still trainable when
``get_peft_model`` hands the model to the SFT trainer. Two user-facing
consequences:
- ``save_adapter`` dumps ``model.trainable_parameters()``, so the saved
LoRA file grows from ~100 MB to ~700 MB with unrelated audio weights.
- Those weights silently receive gradients (though no real audio input
flows through them on an image-only batch), which is surprising.
Fix: include ``audio_tower`` / ``embed_audio`` / ``embed_vision`` in the
top-level freeze list. Some audio sub-modules (notably
``AudioRelativePositionEmbedding`` in Gemma 4, whose ``_inv_timescales``
is a plain buffer) crash ``freeze()``, so fall back to walking leaf
modules on error.
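A sketch of the fallback shape (hedged: ``freeze_with_fallback`` is an illustrative name, and the real helper in ``mlx_vlm/trainer/utils.py`` may differ in detail):

```python
import mlx.nn as nn

def freeze_with_fallback(module: nn.Module):
    try:
        module.freeze()
    except Exception:
        # Plain buffers such as AudioRelativePositionEmbedding._inv_timescales
        # can make a recursive freeze() raise; walk the module tree instead,
        # freezing each module's own parameters and skipping the ones that
        # still refuse.
        for _, sub in module.named_modules():
            try:
                sub.freeze(recurse=False)
            except Exception:
                pass
```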
### Verified
With both fixes applied, Gemma 4 E2B LoRA fine-tuning on Unsplash / Pexels
photos runs to completion on an M4 Max (500 iters, batch 1, bf16):
loss 10.6 → 1.1, adapter file ~110 MB per checkpoint, no NaN.
Blaizzy requested changes on Apr 24, 2026

Author (Contributor): Trimmed the comments in b01c289 — kept only the non-obvious WHY (the NaN-in-backward rationale on the mask, and the freeze() fallback). PTAL.
### Scope

Two surgical changes; no API changes:

- ``mlx_vlm/models/gemma4/vision.py`` — mask fill value
- ``mlx_vlm/trainer/utils.py`` — freeze-list additions + leaf-walk fallback

Context: I maintain a separate repo where this sequence (LoRA-train Gemma 4 → fuse → CoreML bundle → iPhone) is productized; happy to split into two PRs if you prefer.